Skip to content

Commit 4ad47ca

Browse files
committed
Resolving comments
1 parent 6583708 commit 4ad47ca

File tree

4 files changed

+958
-592
lines changed

4 files changed

+958
-592
lines changed

mssql_python/connection.py

Lines changed: 63 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,54 @@
5757
# Note: "utf-16" with BOM is NOT included as it's problematic for SQL_WCHAR
5858
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16le", "utf-16be"])
5959

60-
# Valid encoding characters (alphanumeric, dash, underscore only)
61-
import string
6260

63-
VALID_ENCODING_CHARS: frozenset[str] = frozenset(string.ascii_letters + string.digits + "-_")
61+
def _validate_utf16_wchar_compatibility(
    encoding: str, wchar_type: int, context: str = "SQL_WCHAR"
) -> None:
    """
    Reject encodings that cannot be used together with SQL_WCHAR.

    Shared by setencoding/setdecoding so both report the same errors.
    Returns normally only when ``encoding`` is a member of UTF16_ENCODINGS.

    Args:
        encoding: The encoding name. Compared verbatim, so it must already be
            normalized to lowercase by the caller.
        wchar_type: The SQL type / ctype value the caller matched against SQL_WCHAR.
            Kept for call-site symmetry; this function does not read it, so callers
            must only invoke it once they know the type is SQL_WCHAR.
        context: Label used in the log messages ('SQL_WCHAR', 'SQL_WCHAR ctype',
            'SQL_WCHAR sqltype'). A label containing 'ctype' also selects the
            ctype-specific wording of the raised error.

    Raises:
        ProgrammingError: If ``encoding`` is 'utf-16' (BOM variant) or is not one of
            UTF16_ENCODINGS.
    """
    # 'utf-16' is checked first: it gets a dedicated message pointing the user at the
    # explicit-endianness variants instead of the generic "not UTF-16" error.
    if encoding == "utf-16":
        logger.warning("utf-16 with BOM rejected for %s", context)
        raise ProgrammingError(
            driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
            ddbc_error=(
                "Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
                "Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
            ),
        )

    if encoding in UTF16_ENCODINGS:
        # Explicit-endianness UTF-16: compatible, nothing to report.
        return

    # Anything else is not UTF-16. The encoding is user-supplied, so it is sanitized
    # before being written to the log.
    logger.warning(
        "Non-UTF-16 encoding %s attempted with %s", sanitize_user_input(encoding), context
    )

    # Word the error for the kind of SQL_WCHAR usage the caller was validating.
    ddbc_context = "SQL_WCHAR ctype" if "ctype" in context else "SQL_WCHAR"
    raise ProgrammingError(
        driver_error=ddbc_context + " only supports UTF-16 encodings",
        ddbc_error=(
            f"Cannot use encoding '{encoding}' with {ddbc_context}. "
            "SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
        ),
    )
64108

65109

66110
def _validate_encoding(encoding: str) -> bool:
@@ -78,14 +122,18 @@ def _validate_encoding(encoding: str) -> bool:
78122
Cache size is limited to 128 entries which should cover most use cases.
79123
Also validates that encoding name only contains safe characters.
80124
"""
81-
# First check for dangerous characters (security validation)
82-
if not all(c in VALID_ENCODING_CHARS for c in encoding):
125+
# Basic security checks - prevent obvious attacks
126+
if not encoding or not isinstance(encoding, str):
83127
return False
84128

85129
# Check length limit (prevent DOS)
86130
if len(encoding) > 100:
87131
return False
88132

133+
# Prevent null bytes and control characters that could cause issues
134+
if "\x00" in encoding or any(ord(c) < 32 and c not in "\t\n\r" for c in encoding):
135+
return False
136+
89137
# Then check if it's a valid Python codec
90138
try:
91139
codecs.lookup(encoding)
@@ -450,18 +498,9 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
450498
encoding = encoding.casefold()
451499
logger.debug("setencoding: Encoding normalized to %s", encoding)
452500

453-
# Reject 'utf-16' with BOM for SQL_WCHAR (ambiguous byte order)
454-
if encoding == "utf-16" and ctype == ConstantsDDBC.SQL_WCHAR.value:
455-
logger.warning(
456-
"utf-16 with BOM rejected for SQL_WCHAR",
457-
)
458-
raise ProgrammingError(
459-
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
460-
ddbc_error=(
461-
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
462-
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
463-
),
464-
)
501+
# Early validation if ctype is already specified as SQL_WCHAR
502+
if ctype == ConstantsDDBC.SQL_WCHAR.value:
503+
_validate_utf16_wchar_compatibility(encoding, ctype, "SQL_WCHAR")
465504

466505
# Set default ctype based on encoding if not provided
467506
if ctype is None:
@@ -488,28 +527,9 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
488527
),
489528
)
490529

491-
# Validate that SQL_WCHAR ctype only used with UTF-16 encodings (not utf-16 with BOM)
530+
# Final validation: SQL_WCHAR ctype only supports UTF-16 encodings (without BOM)
492531
if ctype == ConstantsDDBC.SQL_WCHAR.value:
493-
if encoding == "utf-16":
494-
raise ProgrammingError(
495-
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
496-
ddbc_error=(
497-
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
498-
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
499-
),
500-
)
501-
elif encoding not in UTF16_ENCODINGS:
502-
logger.warning(
503-
"Non-UTF-16 encoding %s attempted with SQL_WCHAR ctype",
504-
sanitize_user_input(encoding),
505-
)
506-
raise ProgrammingError(
507-
driver_error=f"SQL_WCHAR only supports UTF-16 encodings",
508-
ddbc_error=(
509-
f"Cannot use encoding '{encoding}' with SQL_WCHAR. "
510-
f"SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
511-
),
512-
)
532+
_validate_utf16_wchar_compatibility(encoding, ctype, "SQL_WCHAR")
513533

514534
# Store the encoding settings (thread-safe with lock)
515535
with self._encoding_lock:
@@ -633,32 +653,9 @@ def setdecoding(
633653
# Normalize encoding to lowercase for consistency
634654
encoding = encoding.lower()
635655

636-
# Reject 'utf-16' with BOM for SQL_WCHAR (ambiguous byte order)
637-
if sqltype == ConstantsDDBC.SQL_WCHAR.value and encoding == "utf-16":
638-
logger.warning(
639-
"utf-16 with BOM rejected for SQL_WCHAR",
640-
)
641-
raise ProgrammingError(
642-
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
643-
ddbc_error=(
644-
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
645-
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
646-
),
647-
)
648-
649-
# Validate SQL_WCHAR only supports UTF-16 encodings (SQL_WMETADATA is more flexible)
650-
if sqltype == ConstantsDDBC.SQL_WCHAR.value and encoding not in UTF16_ENCODINGS:
651-
logger.warning(
652-
"Non-UTF-16 encoding %s attempted with SQL_WCHAR sqltype",
653-
sanitize_user_input(encoding),
654-
)
655-
raise ProgrammingError(
656-
driver_error=f"SQL_WCHAR only supports UTF-16 encodings",
657-
ddbc_error=(
658-
f"Cannot use encoding '{encoding}' with SQL_WCHAR. "
659-
f"SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
660-
),
661-
)
656+
# Validate SQL_WCHAR encoding compatibility
657+
if sqltype == ConstantsDDBC.SQL_WCHAR.value:
658+
_validate_utf16_wchar_compatibility(encoding, sqltype, "SQL_WCHAR sqltype")
662659

663660
# SQL_WMETADATA can use any valid encoding (UTF-8, UTF-16, etc.)
664661
# No restriction needed here - let users configure as needed
@@ -685,28 +682,9 @@ def setdecoding(
685682
),
686683
)
687684

688-
# Validate that SQL_WCHAR ctype only used with UTF-16 encodings (not utf-16 with BOM)
685+
# Validate SQL_WCHAR ctype encoding compatibility
689686
if ctype == ConstantsDDBC.SQL_WCHAR.value:
690-
if encoding == "utf-16":
691-
raise ProgrammingError(
692-
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
693-
ddbc_error=(
694-
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
695-
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
696-
),
697-
)
698-
elif encoding not in UTF16_ENCODINGS:
699-
logger.warning(
700-
"Non-UTF-16 encoding %s attempted with SQL_WCHAR ctype",
701-
sanitize_user_input(encoding),
702-
)
703-
raise ProgrammingError(
704-
driver_error=f"SQL_WCHAR ctype only supports UTF-16 encodings",
705-
ddbc_error=(
706-
f"Cannot use encoding '{encoding}' with SQL_WCHAR ctype. "
707-
f"SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
708-
),
709-
)
687+
_validate_utf16_wchar_compatibility(encoding, ctype, "SQL_WCHAR ctype")
710688

711689
# Store the decoding settings for the specified sqltype (thread-safe with lock)
712690
with self._encoding_lock:

mssql_python/cursor.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -297,21 +297,33 @@ def _get_encoding_settings(self):
297297
298298
Returns:
299299
dict: A dictionary with 'encoding' and 'ctype' keys, or default settings if not available
300+
301+
Raises:
302+
OperationalError, DatabaseError: If there are unexpected database connection issues
303+
that indicate a broken connection state. These should not be silently ignored
304+
as they can lead to data corruption or inconsistent behavior.
300305
"""
301306
if hasattr(self._connection, "getencoding"):
302307
try:
303308
return self._connection.getencoding()
304309
except (OperationalError, DatabaseError) as db_error:
305-
# Only catch database-related errors, not programming errors
306-
from mssql_python.helpers import log
307-
308-
log(
309-
"warning",
310-
f"Failed to get encoding settings from connection due to database error: {db_error}",
310+
# Log the error for debugging but re-raise for fail-fast behavior
311+
# Silently returning defaults can lead to data corruption and hard-to-debug issues
312+
logger.error(
313+
"Failed to get encoding settings from connection due to database error: %s. "
314+
"This indicates a broken connection state that should not be ignored.",
315+
db_error,
311316
)
312-
return {"encoding": "utf-16le", "ctype": ddbc_sql_const.SQL_WCHAR.value}
317+
# Re-raise to fail fast - users should know their connection is broken
318+
raise
319+
except Exception as unexpected_error:
320+
# Handle other unexpected errors (connection closed, programming errors, etc.)
321+
logger.error("Unexpected error getting encoding settings: %s", unexpected_error)
322+
# Re-raise unexpected errors as well
323+
raise
313324

314325
# Return default encoding settings if getencoding is not available
326+
# This is the only case where defaults are appropriate (method doesn't exist)
315327
return {"encoding": "utf-16le", "ctype": ddbc_sql_const.SQL_WCHAR.value}
316328

317329
def _get_decoding_settings(self, sql_type):
@@ -323,22 +335,35 @@ def _get_decoding_settings(self, sql_type):
323335
324336
Returns:
325337
Dictionary containing the decoding settings.
338+
339+
Raises:
340+
OperationalError, DatabaseError: If there are unexpected database connection issues
341+
that indicate a broken connection state. These should not be silently ignored
342+
as they can lead to data corruption or inconsistent behavior.
326343
"""
327344
try:
328345
# Get decoding settings from connection for this SQL type
329346
return self._connection.getdecoding(sql_type)
330347
except (OperationalError, DatabaseError) as db_error:
331-
# Only handle expected database-related errors
332-
from mssql_python.helpers import log
333-
334-
log(
335-
"warning",
336-
f"Failed to get decoding settings for SQL type {sql_type} due to database error: {db_error}",
348+
# Log the error for debugging but re-raise for fail-fast behavior
349+
# Silently returning defaults can lead to data corruption and hard-to-debug issues
350+
logger.error(
351+
"Failed to get decoding settings for SQL type %s due to database error: %s. "
352+
"This indicates a broken connection state that should not be ignored.",
353+
sql_type,
354+
db_error,
337355
)
338-
if sql_type == ddbc_sql_const.SQL_WCHAR.value:
339-
return {"encoding": "utf-16le", "ctype": ddbc_sql_const.SQL_WCHAR.value}
340-
else:
341-
return {"encoding": "utf-8", "ctype": ddbc_sql_const.SQL_CHAR.value}
356+
# Re-raise to fail fast - users should know their connection is broken
357+
raise
358+
except Exception as unexpected_error:
359+
# Handle other unexpected errors (connection closed, programming errors, etc.)
360+
logger.error(
361+
"Unexpected error getting decoding settings for SQL type %s: %s",
362+
sql_type,
363+
unexpected_error,
364+
)
365+
# Re-raise unexpected errors as well
366+
raise
342367

343368
def _map_sql_type( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-return-statements,too-many-branches
344369
self,

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,7 +1811,8 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle,
18111811

18121812
SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params,
18131813
const std::vector<ParamInfo>& paramInfos, size_t paramSetSize,
1814-
std::vector<std::shared_ptr<void>>& paramBuffers) {
1814+
std::vector<std::shared_ptr<void>>& paramBuffers,
1815+
const std::string& charEncoding = "utf-8") {
18151816
LOG("BindParameterArray: Starting column-wise array binding - "
18161817
"param_count=%zu, param_set_size=%zu",
18171818
columnwise_params.size(), paramSetSize);
@@ -2013,8 +2014,8 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params,
20132014
case SQL_C_CHAR:
20142015
case SQL_C_BINARY: {
20152016
LOG("BindParameterArray: Binding SQL_C_CHAR/BINARY array - "
2016-
"param_index=%d, count=%zu, column_size=%zu",
2017-
paramIndex, paramSetSize, info.columnSize);
2017+
"param_index=%d, count=%zu, column_size=%zu, encoding='%s'",
2018+
paramIndex, paramSetSize, info.columnSize, charEncoding.c_str());
20182019
char* charArray = AllocateParamBufferArray<char>(
20192020
tempBuffers, paramSetSize * (info.columnSize + 1));
20202021
strLenOrIndArray = AllocateParamBufferArray<SQLLEN>(tempBuffers, paramSetSize);
@@ -2024,18 +2025,45 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params,
20242025
std::memset(charArray + i * (info.columnSize + 1), 0,
20252026
info.columnSize + 1);
20262027
} else {
2027-
std::string str = columnValues[i].cast<std::string>();
2028-
if (str.size() > info.columnSize) {
2028+
std::string encodedStr;
2029+
2030+
if (py::isinstance<py::str>(columnValues[i])) {
2031+
// Use Python's codec system to encode the string with specified
2032+
// encoding (like pyodbc does)
2033+
try {
2034+
py::object encoded =
2035+
columnValues[i].attr("encode")(charEncoding, "strict");
2036+
encodedStr = encoded.cast<std::string>();
2037+
LOG("BindParameterArray: param[%d] row[%zu] SQL_C_CHAR - "
2038+
"Encoded with '%s', "
2039+
"size=%zu bytes",
2040+
paramIndex, i, charEncoding.c_str(), encodedStr.size());
2041+
} catch (const py::error_already_set& e) {
2042+
LOG_ERROR("BindParameterArray: param[%d] row[%zu] SQL_C_CHAR - "
2043+
"Failed to encode "
2044+
"with '%s': %s",
2045+
paramIndex, i, charEncoding.c_str(), e.what());
2046+
throw std::runtime_error(
2047+
std::string("Failed to encode parameter ") +
2048+
std::to_string(paramIndex) + " row " + std::to_string(i) +
2049+
" with encoding '" + charEncoding + "': " + e.what());
2050+
}
2051+
} else {
2052+
// bytes/bytearray - use as-is (already encoded)
2053+
encodedStr = columnValues[i].cast<std::string>();
2054+
}
2055+
2056+
if (encodedStr.size() > info.columnSize) {
20292057
LOG("BindParameterArray: String/binary too "
20302058
"long - param_index=%d, row=%zu, size=%zu, "
20312059
"max=%zu",
2032-
paramIndex, i, str.size(), info.columnSize);
2060+
paramIndex, i, encodedStr.size(), info.columnSize);
20332061
ThrowStdException("Input exceeds column size at index " +
20342062
std::to_string(i));
20352063
}
2036-
std::memcpy(charArray + i * (info.columnSize + 1), str.c_str(),
2037-
str.size());
2038-
strLenOrIndArray[i] = static_cast<SQLLEN>(str.size());
2064+
std::memcpy(charArray + i * (info.columnSize + 1), encodedStr.c_str(),
2065+
encodedStr.size());
2066+
strLenOrIndArray[i] = static_cast<SQLLEN>(encodedStr.size());
20392067
}
20402068
}
20412069
LOG("BindParameterArray: SQL_C_CHAR/BINARY bound - "
@@ -2471,10 +2499,11 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::wst
24712499

24722500
if (!hasDAE) {
24732501
LOG("SQLExecuteMany: Using array binding (non-DAE) - calling "
2474-
"BindParameterArray");
2502+
"BindParameterArray with encoding '%s'",
2503+
charEncoding.c_str());
24752504
std::vector<std::shared_ptr<void>> paramBuffers;
2476-
// TODO: Pass charEncoding to BindParameterArray when it's updated to support encoding
2477-
rc = BindParameterArray(hStmt, columnwise_params, paramInfos, paramSetSize, paramBuffers);
2505+
rc = BindParameterArray(hStmt, columnwise_params, paramInfos, paramSetSize, paramBuffers,
2506+
charEncoding);
24782507
if (!SQL_SUCCEEDED(rc)) {
24792508
LOG("SQLExecuteMany: BindParameterArray failed - rc=%d", rc);
24802509
return rc;
@@ -2500,7 +2529,7 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, const std::wst
25002529

25012530
std::vector<std::shared_ptr<void>> paramBuffers;
25022531
rc = BindParameters(hStmt, rowParams, const_cast<std::vector<ParamInfo>&>(paramInfos),
2503-
paramBuffers);
2532+
paramBuffers, charEncoding);
25042533
if (!SQL_SUCCEEDED(rc)) {
25052534
LOG("SQLExecuteMany: BindParameters failed for row %zu - rc=%d", rowIndex, rc);
25062535
return rc;

0 commit comments

Comments
 (0)