Skip to content

Commit 600c113

Browse files
committed
Resolving comments
1 parent 751b0b8 commit 600c113

File tree

3 files changed

+292
-93
lines changed

3 files changed

+292
-93
lines changed

mssql_python/connection.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,49 @@
1313
import weakref
1414
import re
1515
import codecs
16+
from functools import lru_cache
1617
from mssql_python.cursor import Cursor
17-
from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, log
18+
from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, sanitize_user_input, log
1819
from mssql_python import ddbc_bindings
1920
from mssql_python.pooling import PoolingManager
2021
from mssql_python.exceptions import InterfaceError, ProgrammingError
2122
from mssql_python.auth import process_connection_string
2223
from mssql_python.constants import ConstantsDDBC
2324

25+
# UTF-16 encoding variants that should use SQL_WCHAR by default
26+
UTF16_ENCODINGS = frozenset([
27+
'utf-16',
28+
'utf-16le',
29+
'utf-16be'
30+
])
31+
32+
# Cache for encoding validation to improve performance
33+
# Using a simple dict instead of lru_cache for module-level caching
34+
_ENCODING_VALIDATION_CACHE = {}
35+
_CACHE_MAX_SIZE = 100 # Limit cache size to prevent memory bloat
36+
37+
38+
@lru_cache(maxsize=128)
39+
def _validate_encoding(encoding: str) -> bool:
40+
"""
41+
Cached encoding validation using codecs.lookup().
42+
43+
Args:
44+
encoding (str): The encoding name to validate.
45+
46+
Returns:
47+
bool: True if encoding is valid, False otherwise.
48+
49+
Note:
50+
Uses LRU cache to avoid repeated expensive codecs.lookup() calls.
51+
Cache size is limited to 128 entries which should cover most use cases.
52+
"""
53+
try:
54+
codecs.lookup(encoding)
55+
return True
56+
except LookupError:
57+
return False
58+
2459

2560
class Connection:
2661
"""
@@ -181,7 +216,7 @@ def setencoding(self, encoding=None, ctype=None):
181216
encoding that converts text to bytes. If None, defaults to 'utf-16le'.
182217
ctype (int, optional): The C data type to use when passing data:
183218
SQL_CHAR or SQL_WCHAR. If not provided, SQL_WCHAR is used for
184-
"utf-16", "utf-16le", and "utf-16be". SQL_CHAR is used for all other encodings.
219+
UTF-16 variants (see UTF16_ENCODINGS constant). SQL_CHAR is used for all other encodings.
185220
186221
Returns:
187222
None
@@ -199,33 +234,38 @@ def setencoding(self, encoding=None, ctype=None):
199234
"""
200235
if self._closed:
201236
raise InterfaceError(
202-
driver_error="Cannot set encoding on closed connection",
203-
ddbc_error="Cannot set encoding on closed connection",
237+
driver_error="Connection is closed",
238+
ddbc_error="Connection is closed",
204239
)
205240

206241
# Set default encoding if not provided
207242
if encoding is None:
208243
encoding = 'utf-16le'
209244

210-
# Validate encoding
211-
try:
212-
codecs.lookup(encoding)
213-
except LookupError:
245+
# Validate encoding using cached validation for better performance
246+
if not _validate_encoding(encoding):
247+
# Log the sanitized encoding for security
248+
log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding)))
214249
raise ProgrammingError(
215-
driver_error=f"Unknown encoding: {encoding}",
250+
driver_error=f"Unsupported encoding: {encoding}",
216251
ddbc_error=f"The encoding '{encoding}' is not supported by Python",
217252
)
218253

254+
# Normalize encoding to lowercase for consistency
255+
encoding = encoding.lower()
256+
219257
# Set default ctype based on encoding if not provided
220258
if ctype is None:
221-
if encoding.lower() in ('utf-16', 'utf-16le', 'utf-16be'):
259+
if encoding in UTF16_ENCODINGS:
222260
ctype = ConstantsDDBC.SQL_WCHAR.value
223261
else:
224262
ctype = ConstantsDDBC.SQL_CHAR.value
225263

226264
# Validate ctype
227265
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
228266
if ctype not in valid_ctypes:
267+
# Log the sanitized ctype for security
268+
log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype)))
229269
raise ProgrammingError(
230270
driver_error=f"Invalid ctype: {ctype}",
231271
ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})",
@@ -237,7 +277,9 @@ def setencoding(self, encoding=None, ctype=None):
237277
'ctype': ctype
238278
}
239279

240-
log('info', "Text encoding set to %s with ctype %s", encoding, ctype)
280+
# Log with sanitized values for security
281+
log('info', "Text encoding set to %s with ctype %s",
282+
sanitize_user_input(encoding), sanitize_user_input(str(ctype)))
241283

242284
def getencoding(self):
243285
"""
@@ -246,11 +288,20 @@ def getencoding(self):
246288
Returns:
247289
dict: A dictionary containing 'encoding' and 'ctype' keys.
248290
291+
Raises:
292+
InterfaceError: If the connection is closed.
293+
249294
Example:
250295
settings = cnxn.getencoding()
251296
print(f"Current encoding: {settings['encoding']}")
252297
print(f"Current ctype: {settings['ctype']}")
253298
"""
299+
if self._closed:
300+
raise InterfaceError(
301+
driver_error="Connection is closed",
302+
ddbc_error="Connection is closed",
303+
)
304+
254305
return self._encoding_settings.copy()
255306

256307
def cursor(self) -> Cursor:

mssql_python/helpers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,34 @@ def sanitize_connection_string(conn_str: str) -> str:
128128
return re.sub(r"(Pwd\s*=\s*)[^;]*", r"\1***", conn_str, flags=re.IGNORECASE)
129129

130130

131+
def sanitize_user_input(user_input: str, max_length: int = 50) -> str:
132+
"""
133+
Sanitize user input for safe logging by removing control characters,
134+
limiting length, and ensuring safe characters only.
135+
136+
Args:
137+
user_input (str): The user input to sanitize.
138+
max_length (int): Maximum length of the sanitized output.
139+
140+
Returns:
141+
str: The sanitized string safe for logging.
142+
"""
143+
if not isinstance(user_input, str):
144+
return "<non-string>"
145+
146+
# Remove control characters and non-printable characters
147+
import re
148+
# Allow alphanumeric, dash, underscore, and dot (common in encoding names)
149+
sanitized = re.sub(r'[^\w\-\.]', '', user_input)
150+
151+
# Limit length to prevent log flooding
152+
if len(sanitized) > max_length:
153+
sanitized = sanitized[:max_length] + "..."
154+
155+
# Return placeholder if nothing remains after sanitization
156+
return sanitized if sanitized else "<invalid>"
157+
158+
131159
def log(level: str, message: str, *args) -> None:
132160
"""
133161
Universal logging helper that gets a fresh logger instance.

0 commit comments

Comments
 (0)