|
22 | 22 | from mssql_python.auth import process_connection_string |
23 | 23 | from mssql_python.constants import ConstantsDDBC |
24 | 24 |
|
| 25 | +# Add SQL_WMETADATA constant for metadata decoding configuration |
| 26 | +SQL_WMETADATA = -99 # Special flag for column name decoding |
| 27 | + |
25 | 28 | # UTF-16 encoding variants that should use SQL_WCHAR by default |
26 | 29 | UTF16_ENCODINGS = frozenset([ |
27 | 30 | 'utf-16', |
@@ -74,6 +77,8 @@ class Connection: |
74 | 77 | rollback() -> None: |
75 | 78 | close() -> None: |
76 | 79 | setencoding(encoding=None, ctype=None) -> None: |
| 80 | + setdecoding(sqltype, encoding=None, ctype=None) -> None: |
| 81 | + getdecoding(sqltype) -> dict: |
77 | 82 | """ |
78 | 83 |
|
79 | 84 | def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, **kwargs) -> None: |
@@ -108,6 +113,22 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef |
108 | 113 | 'ctype': ConstantsDDBC.SQL_WCHAR.value |
109 | 114 | } |
110 | 115 |
|
| 116 | + # Initialize decoding settings with Python 3 defaults |
| 117 | + self._decoding_settings = { |
| 118 | + ConstantsDDBC.SQL_CHAR.value: { |
| 119 | + 'encoding': 'utf-8', |
| 120 | + 'ctype': ConstantsDDBC.SQL_CHAR.value |
| 121 | + }, |
| 122 | + ConstantsDDBC.SQL_WCHAR.value: { |
| 123 | + 'encoding': 'utf-16le', |
| 124 | + 'ctype': ConstantsDDBC.SQL_WCHAR.value |
| 125 | + }, |
| 126 | + SQL_WMETADATA: { |
| 127 | + 'encoding': 'utf-16le', |
| 128 | + 'ctype': ConstantsDDBC.SQL_WCHAR.value |
| 129 | + } |
| 130 | + } |
| 131 | + |
111 | 132 | # Check if the connection string contains authentication parameters |
112 | 133 | # This is important for processing the connection string correctly. |
113 | 134 | # If authentication is specified, it will be processed to handle |
@@ -304,6 +325,147 @@ def getencoding(self): |
304 | 325 |
|
305 | 326 | return self._encoding_settings.copy() |
306 | 327 |
|
| 328 | + def setdecoding(self, sqltype, encoding=None, ctype=None): |
| 329 | + """ |
| 330 | + Sets the text decoding used when reading SQL_CHAR and SQL_WCHAR from the database. |
| 331 | + |
| 332 | + This method configures how text data is decoded when reading from the database. |
| 333 | + In Python 3, all text is Unicode (str), so this primarily affects the encoding |
| 334 | + used to decode bytes from the database. |
| 335 | + |
| 336 | + Args: |
| 337 | + sqltype (int): The SQL type being configured: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA. |
| 338 | + SQL_WMETADATA is a special flag for configuring column name decoding. |
| 339 | + encoding (str, optional): The Python encoding to use when decoding the data. |
| 340 | + If None, uses default encoding based on sqltype. |
| 341 | + ctype (int, optional): The C data type to request from SQLGetData: |
| 342 | + SQL_CHAR or SQL_WCHAR. If None, uses default based on encoding. |
| 343 | + |
| 344 | + Returns: |
| 345 | + None |
| 346 | + |
| 347 | + Raises: |
| 348 | + ProgrammingError: If the sqltype, encoding, or ctype is invalid. |
| 349 | + InterfaceError: If the connection is closed. |
| 350 | + |
| 351 | + Example: |
| 352 | + # Configure SQL_CHAR to use UTF-8 decoding |
| 353 | + cnxn.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') |
| 354 | + |
| 355 | + # Configure column metadata decoding |
| 356 | + cnxn.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16le') |
| 357 | + |
| 358 | + # Use explicit ctype |
| 359 | + cnxn.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_WCHAR) |
| 360 | + """ |
| 361 | + if self._closed: |
| 362 | + raise InterfaceError( |
| 363 | + driver_error="Connection is closed", |
| 364 | + ddbc_error="Connection is closed", |
| 365 | + ) |
| 366 | + |
| 367 | + # Validate sqltype |
| 368 | + valid_sqltypes = [ |
| 369 | + ConstantsDDBC.SQL_CHAR.value, |
| 370 | + ConstantsDDBC.SQL_WCHAR.value, |
| 371 | + SQL_WMETADATA |
| 372 | + ] |
| 373 | + if sqltype not in valid_sqltypes: |
| 374 | + log('warning', "Invalid sqltype attempted: %s", sanitize_user_input(str(sqltype))) |
| 375 | + raise ProgrammingError( |
| 376 | + driver_error=f"Invalid sqltype: {sqltype}", |
| 377 | + ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})", |
| 378 | + ) |
| 379 | + |
| 380 | + # Set default encoding based on sqltype if not provided |
| 381 | + if encoding is None: |
| 382 | + if sqltype == ConstantsDDBC.SQL_CHAR.value: |
| 383 | + encoding = 'utf-8' # Default for SQL_CHAR in Python 3 |
| 384 | + else: # SQL_WCHAR or SQL_WMETADATA |
| 385 | + encoding = 'utf-16le' # Default for SQL_WCHAR in Python 3 |
| 386 | + |
| 387 | + # Validate encoding using cached validation for better performance |
| 388 | + if not _validate_encoding(encoding): |
| 389 | + log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding))) |
| 390 | + raise ProgrammingError( |
| 391 | + driver_error=f"Unsupported encoding: {encoding}", |
| 392 | + ddbc_error=f"The encoding '{encoding}' is not supported by Python", |
| 393 | + ) |
| 394 | + |
| 395 | + # Normalize encoding to lowercase for consistency |
| 396 | + encoding = encoding.lower() |
| 397 | + |
| 398 | + # Set default ctype based on encoding if not provided |
| 399 | + if ctype is None: |
| 400 | + if encoding in UTF16_ENCODINGS: |
| 401 | + ctype = ConstantsDDBC.SQL_WCHAR.value |
| 402 | + else: |
| 403 | + ctype = ConstantsDDBC.SQL_CHAR.value |
| 404 | + |
| 405 | + # Validate ctype |
| 406 | + valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value] |
| 407 | + if ctype not in valid_ctypes: |
| 408 | + log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype))) |
| 409 | + raise ProgrammingError( |
| 410 | + driver_error=f"Invalid ctype: {ctype}", |
| 411 | + ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})", |
| 412 | + ) |
| 413 | + |
| 414 | + # Store the decoding settings for the specified sqltype |
| 415 | + self._decoding_settings[sqltype] = { |
| 416 | + 'encoding': encoding, |
| 417 | + 'ctype': ctype |
| 418 | + } |
| 419 | + |
| 420 | + # Log with sanitized values for security |
| 421 | + sqltype_name = { |
| 422 | + ConstantsDDBC.SQL_CHAR.value: "SQL_CHAR", |
| 423 | + ConstantsDDBC.SQL_WCHAR.value: "SQL_WCHAR", |
| 424 | + SQL_WMETADATA: "SQL_WMETADATA" |
| 425 | + }.get(sqltype, str(sqltype)) |
| 426 | + |
| 427 | + log('info', "Text decoding set for %s to %s with ctype %s", |
| 428 | + sqltype_name, sanitize_user_input(encoding), sanitize_user_input(str(ctype))) |
| 429 | + |
| 430 | + def getdecoding(self, sqltype): |
| 431 | + """ |
| 432 | + Gets the current text decoding settings for the specified SQL type. |
| 433 | + |
| 434 | + Args: |
| 435 | + sqltype (int): The SQL type to get settings for: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA. |
| 436 | + |
| 437 | + Returns: |
| 438 | + dict: A dictionary containing 'encoding' and 'ctype' keys for the specified sqltype. |
| 439 | + |
| 440 | + Raises: |
| 441 | + ProgrammingError: If the sqltype is invalid. |
| 442 | + InterfaceError: If the connection is closed. |
| 443 | + |
| 444 | + Example: |
| 445 | + settings = cnxn.getdecoding(mssql_python.SQL_CHAR) |
| 446 | + print(f"SQL_CHAR encoding: {settings['encoding']}") |
| 447 | + print(f"SQL_CHAR ctype: {settings['ctype']}") |
| 448 | + """ |
| 449 | + if self._closed: |
| 450 | + raise InterfaceError( |
| 451 | + driver_error="Connection is closed", |
| 452 | + ddbc_error="Connection is closed", |
| 453 | + ) |
| 454 | + |
| 455 | + # Validate sqltype |
| 456 | + valid_sqltypes = [ |
| 457 | + ConstantsDDBC.SQL_CHAR.value, |
| 458 | + ConstantsDDBC.SQL_WCHAR.value, |
| 459 | + SQL_WMETADATA |
| 460 | + ] |
| 461 | + if sqltype not in valid_sqltypes: |
| 462 | + raise ProgrammingError( |
| 463 | + driver_error=f"Invalid sqltype: {sqltype}", |
| 464 | + ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})", |
| 465 | + ) |
| 466 | + |
| 467 | + return self._decoding_settings[sqltype].copy() |
| 468 | + |
307 | 469 | def cursor(self) -> Cursor: |
308 | 470 | """ |
309 | 471 | Return a new Cursor object using the connection. |
|
0 commit comments