Skip to content

Commit 8f1618a

Browse files
committed
FEAT: Adding setdecoding()
1 parent 600c113 commit 8f1618a

File tree

4 files changed

+594
-3
lines changed

4 files changed

+594
-3
lines changed

mssql_python/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
# Export specific constants for setencoding() and setdecoding()
5151
SQL_CHAR = ConstantsDDBC.SQL_CHAR.value
5252
SQL_WCHAR = ConstantsDDBC.SQL_WCHAR.value
53+
SQL_WMETADATA = -99
5354

5455
# GLOBALS
5556
# Read-Only

mssql_python/connection.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from mssql_python.auth import process_connection_string
2323
from mssql_python.constants import ConstantsDDBC
2424

25+
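# SQL_WMETADATA: pseudo SQL type accepted by setdecoding()/getdecoding() to configure metadata (column name) decoding; keep the value in sync with the one exported from mssql_python/__init__.py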
26+
SQL_WMETADATA = -99 # Special flag for column name decoding
27+
2528
# UTF-16 encoding variants that should use SQL_WCHAR by default
2629
UTF16_ENCODINGS = frozenset([
2730
'utf-16',
@@ -74,6 +77,8 @@ class Connection:
7477
rollback() -> None:
7578
close() -> None:
7679
setencoding(encoding=None, ctype=None) -> None:
80+
setdecoding(sqltype, encoding=None, ctype=None) -> None:
81+
getdecoding(sqltype) -> dict:
7782
"""
7883

7984
def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, **kwargs) -> None:
@@ -108,6 +113,22 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef
108113
'ctype': ConstantsDDBC.SQL_WCHAR.value
109114
}
110115

116+
# Initialize decoding settings with Python 3 defaults
117+
self._decoding_settings = {
118+
ConstantsDDBC.SQL_CHAR.value: {
119+
'encoding': 'utf-8',
120+
'ctype': ConstantsDDBC.SQL_CHAR.value
121+
},
122+
ConstantsDDBC.SQL_WCHAR.value: {
123+
'encoding': 'utf-16le',
124+
'ctype': ConstantsDDBC.SQL_WCHAR.value
125+
},
126+
SQL_WMETADATA: {
127+
'encoding': 'utf-16le',
128+
'ctype': ConstantsDDBC.SQL_WCHAR.value
129+
}
130+
}
131+
111132
# Check if the connection string contains authentication parameters
112133
# This is important for processing the connection string correctly.
113134
# If authentication is specified, it will be processed to handle
@@ -304,6 +325,147 @@ def getencoding(self):
304325

305326
return self._encoding_settings.copy()
306327

328+
def setdecoding(self, sqltype, encoding=None, ctype=None):
    """
    Set the text decoding used when reading SQL_CHAR and SQL_WCHAR from the database.

    The validated settings are stored per sqltype in ``self._decoding_settings``.
    In Python 3, all text is Unicode (str), so this primarily affects the encoding
    used to decode bytes from the database.

    Args:
        sqltype (int): The SQL type being configured: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA.
            SQL_WMETADATA is a special flag for configuring column name decoding.
        encoding (str, optional): The Python encoding to use when decoding the data.
            If None, defaults to 'utf-8' for SQL_CHAR and 'utf-16le' otherwise.
            The stored value is lower-cased.
        ctype (int, optional): The C data type to request from SQLGetData:
            SQL_CHAR or SQL_WCHAR. If None, SQL_WCHAR is chosen when the encoding
            is one of UTF16_ENCODINGS, SQL_CHAR otherwise.

    Returns:
        None

    Raises:
        ProgrammingError: If the sqltype, encoding, or ctype is invalid
            (including an encoding that is not a string).
        InterfaceError: If the connection is closed.

    Example:
        # Configure SQL_CHAR to use UTF-8 decoding
        cnxn.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8')

        # Configure column metadata decoding
        cnxn.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16le')

        # Use explicit ctype
        cnxn.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_WCHAR)
    """
    if self._closed:
        raise InterfaceError(
            driver_error="Connection is closed",
            ddbc_error="Connection is closed",
        )

    # Validate sqltype
    valid_sqltypes = (
        ConstantsDDBC.SQL_CHAR.value,
        ConstantsDDBC.SQL_WCHAR.value,
        SQL_WMETADATA,
    )
    if sqltype not in valid_sqltypes:
        log('warning', "Invalid sqltype attempted: %s", sanitize_user_input(str(sqltype)))
        raise ProgrammingError(
            driver_error=f"Invalid sqltype: {sqltype}",
            ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})",
        )

    # Set default encoding based on sqltype if not provided
    if encoding is None:
        if sqltype == ConstantsDDBC.SQL_CHAR.value:
            encoding = 'utf-8'  # Default for SQL_CHAR in Python 3
        else:  # SQL_WCHAR or SQL_WMETADATA
            encoding = 'utf-16le'  # Default for SQL_WCHAR in Python 3

    # Validate encoding. The isinstance check comes first so that a non-string
    # value is reported as ProgrammingError (as documented) instead of leaking
    # a TypeError/AttributeError from the validator or from .lower() below.
    if not isinstance(encoding, str) or not _validate_encoding(encoding):
        log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding)))
        raise ProgrammingError(
            driver_error=f"Unsupported encoding: {encoding}",
            ddbc_error=f"The encoding '{encoding}' is not supported by Python",
        )

    # Normalize encoding to lowercase for consistency
    encoding = encoding.lower()

    # Set default ctype based on encoding if not provided
    if ctype is None:
        if encoding in UTF16_ENCODINGS:
            ctype = ConstantsDDBC.SQL_WCHAR.value
        else:
            ctype = ConstantsDDBC.SQL_CHAR.value

    # Validate ctype
    valid_ctypes = (ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value)
    if ctype not in valid_ctypes:
        log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype)))
        raise ProgrammingError(
            driver_error=f"Invalid ctype: {ctype}",
            ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})",
        )

    # Store the decoding settings for the specified sqltype
    self._decoding_settings[sqltype] = {
        'encoding': encoding,
        'ctype': ctype,
    }

    # Log with sanitized values for security
    sqltype_name = {
        ConstantsDDBC.SQL_CHAR.value: "SQL_CHAR",
        ConstantsDDBC.SQL_WCHAR.value: "SQL_WCHAR",
        SQL_WMETADATA: "SQL_WMETADATA",
    }.get(sqltype, str(sqltype))

    log('info', "Text decoding set for %s to %s with ctype %s",
        sqltype_name, sanitize_user_input(encoding), sanitize_user_input(str(ctype)))
429+
430+
def getdecoding(self, sqltype):
    """
    Return the decoding configuration currently stored for one SQL type.

    Args:
        sqltype (int): Which setting to look up: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA.

    Returns:
        dict: A shallow copy of the stored entry, holding the 'encoding' and
            'ctype' keys, so callers cannot mutate the connection's own settings.

    Raises:
        ProgrammingError: If the sqltype is invalid.
        InterfaceError: If the connection is closed.

    Example:
        settings = cnxn.getdecoding(mssql_python.SQL_CHAR)
        print(f"SQL_CHAR encoding: {settings['encoding']}")
        print(f"SQL_CHAR ctype: {settings['ctype']}")
    """
    # A closed connection no longer answers configuration queries.
    if self._closed:
        raise InterfaceError(
            driver_error="Connection is closed",
            ddbc_error="Connection is closed",
        )

    accepted = (
        ConstantsDDBC.SQL_CHAR.value,
        ConstantsDDBC.SQL_WCHAR.value,
        SQL_WMETADATA,
    )

    # Happy path first: a recognised sqltype yields a copy of its entry.
    if sqltype in accepted:
        stored = self._decoding_settings[sqltype]
        return stored.copy()

    raise ProgrammingError(
        driver_error=f"Invalid sqltype: {sqltype}",
        ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})",
    )
468+
307469
def cursor(self) -> Cursor:
308470
"""
309471
Return a new Cursor object using the connection.

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def pytest_configure(config):
1818

1919
@pytest.fixture(scope='session')
def conn_str():
    """
    Provide the database connection string for the test session.

    The value is read from the DB_CONNECTION_STRING environment variable so
    that no machine-specific server name or credentials are committed to the
    repository. Returns None when the variable is not set.
    """
    return os.getenv('DB_CONNECTION_STRING')
2323

2424
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)