Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 55 additions & 8 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,59 @@
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16le", "utf-16be"])


def _validate_utf16_wchar_compatibility(
encoding: str, wchar_type: int, context: str = "SQL_WCHAR"
) -> None:
"""
Validates UTF-16 encoding compatibility with SQL_WCHAR.

Centralizes the validation logic to eliminate duplication across setencoding/setdecoding.

Args:
encoding: The encoding string (already normalized to lowercase)
wchar_type: The SQL_WCHAR constant value to check against
context: Context string for error messages ('SQL_WCHAR', 'SQL_WCHAR ctype', etc.)

Raises:
ProgrammingError: If encoding is incompatible with SQL_WCHAR
"""
if encoding == "utf-16":
# UTF-16 with BOM is rejected due to byte order ambiguity
logger.warning("utf-16 with BOM rejected for %s", context)
raise ProgrammingError(
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
ddbc_error=(
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
),
)
elif encoding not in UTF16_ENCODINGS:
# Non-UTF-16 encodings are not supported with SQL_WCHAR
logger.warning(
"Non-UTF-16 encoding %s attempted with %s", sanitize_user_input(encoding), context
)

# Generate context-appropriate error messages
if "ctype" in context:
driver_error = f"SQL_WCHAR ctype only supports UTF-16 encodings"
ddbc_context = "SQL_WCHAR ctype"
else:
driver_error = f"SQL_WCHAR only supports UTF-16 encodings"
ddbc_context = "SQL_WCHAR"

raise ProgrammingError(
driver_error=driver_error,
ddbc_error=(
f"Cannot use encoding '{encoding}' with {ddbc_context}. "
f"SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
),
)


# Note: "utf-16" with BOM is NOT included as it's problematic for SQL_WCHAR
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16le", "utf-16be"])


def _validate_utf16_wchar_compatibility(
encoding: str, wchar_type: int, context: str = "SQL_WCHAR"
) -> None:
Expand Down Expand Up @@ -293,12 +346,8 @@ def __init__(

# Initialize encoding/decoding settings lock for thread safety
# This lock protects both _encoding_settings and _decoding_settings dictionaries
# from concurrent modification. We use a simple Lock (not RLock) because:
# - Write operations (setencoding/setdecoding) replace the entire dict atomically
# - Read operations (getencoding/getdecoding) return a copy, so they're safe
# - No recursive locking is needed in our usage pattern
# This is more performant than RLock for the multiple-readers-single-writer pattern
self._encoding_lock = threading.Lock()
# to prevent race conditions when multiple threads are reading/writing encoding settings
self._encoding_lock = threading.RLock() # RLock allows recursive locking

# Initialize search escape character
self._searchescape = None
Expand Down Expand Up @@ -563,7 +612,6 @@ def getencoding(self) -> Dict[str, Union[str, int]]:

Note:
This method is thread-safe and can be called from multiple threads concurrently.
Returns a copy of the settings to prevent external modification.
"""
if self._closed:
raise InterfaceError(
Expand Down Expand Up @@ -730,7 +778,6 @@ def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:

Note:
This method is thread-safe and can be called from multiple threads concurrently.
Returns a copy of the settings to prevent external modification.
"""
if self._closed:
raise InterfaceError(
Expand Down
3 changes: 3 additions & 0 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ class ConstantsDDBC(Enum):
# Reset Connection Constants
SQL_RESET_CONNECTION_YES = 1

# Query Timeout Constants
SQL_ATTR_QUERY_TIMEOUT = 0


class GetInfoConstants(Enum):
"""
Expand Down
35 changes: 21 additions & 14 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,13 +681,34 @@ def _initialize_cursor(self) -> None:
Initialize the DDBC statement handle.
"""
self._allocate_statement_handle()
self._set_timeout()

def _allocate_statement_handle(self) -> None:
"""
Allocate the DDBC statement handle.
"""
self.hstmt = self._connection._conn.alloc_statement_handle()

def _set_timeout(self) -> None:
"""
Set the query timeout attribute on the statement handle.
This is called once when the cursor is created and after any handle reallocation.
Following pyodbc's approach for better performance.
"""
if self._timeout > 0:
logger.debug("_set_timeout: Setting query timeout=%d seconds", self._timeout)
try:
timeout_value = int(self._timeout)
ret = ddbc_bindings.DDBCSQLSetStmtAttr(
self.hstmt,
ddbc_sql_const.SQL_ATTR_QUERY_TIMEOUT.value,
timeout_value,
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
logger.debug("Query timeout set to %d seconds", timeout_value)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to set query timeout: %s", str(e))

def _reset_cursor(self) -> None:
"""
Reset the DDBC statement handle.
Expand Down Expand Up @@ -1216,20 +1237,6 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state
encoding_settings = self._get_encoding_settings()

# Apply timeout if set (non-zero)
if self._timeout > 0:
logger.debug("execute: Setting query timeout=%d seconds", self._timeout)
try:
timeout_value = int(self._timeout)
ret = ddbc_bindings.DDBCSQLSetStmtAttr(
self.hstmt,
ddbc_sql_const.SQL_ATTR_QUERY_TIMEOUT.value,
timeout_value,
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
logger.debug("Set query timeout to %d seconds", timeout_value)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to set query timeout: %s", str(e))

logger.debug("execute: Creating parameter type list")
param_info = ddbc_bindings.ParamInfo
parameters_type = []
Expand Down
34 changes: 22 additions & 12 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
std::string encodedStr;

if (py::isinstance<py::str>(param)) {
// Encode Unicode string using the specified encoding
// Encode Unicode string using the specified encoding (like pyodbc does)
try {
py::object encoded = param.attr("encode")(charEncoding, "strict");
encodedStr = encoded.cast<std::string>();
Expand Down Expand Up @@ -1741,7 +1741,7 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle,
offset += len;
}
} else if (matchedInfo->paramCType == SQL_C_CHAR) {
// Encode the string using the specified encoding
// Encode the string using the specified encoding (like pyodbc does)
std::string encodedStr;
try {
if (py::isinstance<py::str>(pyObj)) {
Expand Down Expand Up @@ -2043,7 +2043,7 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params,

if (py::isinstance<py::str>(columnValues[i])) {
// Use Python's codec system to encode the string with specified
// encoding
// encoding (like pyodbc does)
try {
py::object encoded =
columnValues[i].attr("encode")(charEncoding, "strict");
Expand Down Expand Up @@ -2836,10 +2836,9 @@ py::object FetchLobColumnData(SQLHSTMT hStmt, SQLUSMALLINT colIndex, SQLSMALLINT
return py::bytes(buffer.data(), buffer.size());
}

// For SQL_C_CHAR data, decode using the specified encoding
// Create py::bytes once to avoid double allocation
py::bytes raw_bytes(buffer.data(), buffer.size());
// For SQL_C_CHAR data, decode using the specified encoding (like pyodbc does)
try {
py::bytes raw_bytes(buffer.data(), buffer.size());
py::object decoded = raw_bytes.attr("decode")(charEncoding, "strict");
LOG("FetchLobColumnData: Decoded narrow string with '%s' - %zu bytes -> %zu chars for "
"column %d",
Expand All @@ -2849,7 +2848,7 @@ py::object FetchLobColumnData(SQLHSTMT hStmt, SQLUSMALLINT colIndex, SQLSMALLINT
LOG_ERROR("FetchLobColumnData: Failed to decode with '%s' for column %d: %s",
charEncoding.c_str(), colIndex, e.what());
// Return raw bytes as fallback
return raw_bytes;
return py::bytes(buffer.data(), buffer.size());
}
}

Expand Down Expand Up @@ -2916,10 +2915,10 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
if (numCharsInData < dataBuffer.size()) {
// SQLGetData will null-terminate the data
// Use Python's codec system to decode bytes with specified encoding
// Create py::bytes once to avoid double allocation
py::bytes raw_bytes(reinterpret_cast<char*>(dataBuffer.data()),
static_cast<size_t>(dataLen));
// (like pyodbc does)
try {
py::bytes raw_bytes(reinterpret_cast<char*>(dataBuffer.data()),
static_cast<size_t>(dataLen));
py::object decoded =
raw_bytes.attr("decode")(charEncoding, "strict");
row.append(decoded);
Expand All @@ -2931,6 +2930,8 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
"SQLGetData: Failed to decode CHAR column %d with '%s': %s",
i, charEncoding.c_str(), e.what());
// Return raw bytes as fallback
py::bytes raw_bytes(reinterpret_cast<char*>(dataBuffer.data()),
static_cast<size_t>(dataLen));
row.append(raw_bytes);
}
} else {
Expand Down Expand Up @@ -4395,8 +4396,17 @@ PYBIND11_MODULE(ddbc_bindings, m) {
"Set the decimal separator character");
m.def(
"DDBCSQLSetStmtAttr",
[](SqlHandlePtr stmt, SQLINTEGER attr, SQLPOINTER value) {
return SQLSetStmtAttr_ptr(stmt->get(), attr, value, 0);
[](SqlHandlePtr stmt, SQLINTEGER attr, py::object value) {
SQLPOINTER ptr_value;
if (py::isinstance<py::int_>(value)) {
// For integer attributes like SQL_ATTR_QUERY_TIMEOUT
ptr_value =
reinterpret_cast<SQLPOINTER>(static_cast<SQLULEN>(value.cast<int64_t>()));
} else {
// For pointer attributes
ptr_value = value.cast<SQLPOINTER>();
}
return SQLSetStmtAttr_ptr(stmt->get(), attr, ptr_value, 0);
},
"Set statement attributes");
m.def("DDBCSQLGetTypeInfo", &SQLGetTypeInfo_Wrapper,
Expand Down
Loading
Loading