Skip to content

Commit 909a653

Browse files
committed
fixed test crash
1 parent fa7d718 commit 909a653

File tree

2 files changed

+121
-24
lines changed

2 files changed

+121
-24
lines changed

mssql_python/cursor.py

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,60 @@ def _parse_time(self, param):
185185
continue
186186
return None
187187

188+
# def _get_numeric_data(self, param):
189+
# """
190+
# Get the data for a numeric parameter.
191+
192+
# Args:
193+
# param: The numeric parameter.
194+
195+
# Returns:
196+
# numeric_data: A NumericData struct containing
197+
# the numeric data.
198+
# """
199+
# decimal_as_tuple = param.as_tuple()
200+
# num_digits = len(decimal_as_tuple.digits)
201+
# exponent = decimal_as_tuple.exponent
202+
203+
# # Calculate the SQL precision & scale
204+
# # precision = no. of significant digits
205+
# # scale = no. digits after decimal point
206+
# if exponent >= 0:
207+
# # digits=314, exp=2 ---> '31400' --> precision=5, scale=0
208+
# precision = num_digits + exponent
209+
# scale = 0
210+
# elif (-1 * exponent) <= num_digits:
211+
# # digits=3140, exp=-3 ---> '3.140' --> precision=4, scale=3
212+
# precision = num_digits
213+
# scale = exponent * -1
214+
# else:
215+
# # digits=3140, exp=-5 ---> '0.03140' --> precision=5, scale=5
216+
# # TODO: double check the precision calculation here with SQL documentation
217+
# precision = exponent * -1
218+
# scale = exponent * -1
219+
220+
# # TODO: Revisit this check, do we want this restriction?
221+
# if precision > 15:
222+
# raise ValueError(
223+
# "Precision of the numeric value is too high - "
224+
# + str(param)
225+
# + ". Should be less than or equal to 15"
226+
# )
227+
# Numeric_Data = ddbc_bindings.NumericData
228+
# numeric_data = Numeric_Data()
229+
# numeric_data.scale = scale
230+
# numeric_data.precision = precision
231+
# numeric_data.sign = 1 if decimal_as_tuple.sign == 0 else 0
232+
# # strip decimal point from param & convert the significant digits to integer
233+
# # Ex: 12.34 ---> 1234
234+
# val = str(param)
235+
# if "." in val or "-" in val:
236+
# val = val.replace(".", "")
237+
# val = val.replace("-", "")
238+
# val = int(val)
239+
# numeric_data.val = val
240+
# return numeric_data
241+
188242
def _get_numeric_data(self, param):
189243
"""
190244
Get the data for a numeric parameter.
@@ -197,7 +251,8 @@ def _get_numeric_data(self, param):
197251
the numeric data.
198252
"""
199253
decimal_as_tuple = param.as_tuple()
200-
num_digits = len(decimal_as_tuple.digits)
254+
digits_tuple = decimal_as_tuple.digits
255+
num_digits = len(digits_tuple)
201256
exponent = decimal_as_tuple.exponent
202257

203258
# Calculate the SQL precision & scale
@@ -217,12 +272,11 @@ def _get_numeric_data(self, param):
217272
precision = exponent * -1
218273
scale = exponent * -1
219274

220-
# TODO: Revisit this check, do we want this restriction?
221-
if precision > 15:
275+
if precision > 38:
222276
raise ValueError(
223277
"Precision of the numeric value is too high - "
224278
+ str(param)
225-
+ ". Should be less than or equal to 15"
279+
+ ". Should be less than or equal to 38"
226280
)
227281
Numeric_Data = ddbc_bindings.NumericData
228282
numeric_data = Numeric_Data()
@@ -231,12 +285,26 @@ def _get_numeric_data(self, param):
231285
numeric_data.sign = 1 if decimal_as_tuple.sign == 0 else 0
232286
# strip decimal point from param & convert the significant digits to integer
233287
# Ex: 12.34 ---> 1234
234-
val = str(param)
235-
if "." in val or "-" in val:
236-
val = val.replace(".", "")
237-
val = val.replace("-", "")
238-
val = int(val)
239-
numeric_data.val = val
288+
int_str = ''.join(str(d) for d in digits_tuple)
289+
if exponent > 0:
290+
int_str = int_str + ('0' * exponent)
291+
elif exponent < 0:
292+
if -exponent > num_digits:
293+
int_str = ('0' * (-exponent - num_digits)) + int_str
294+
295+
if int_str == '':
296+
int_str = '0'
297+
298+
# Convert decimal base-10 string to python int, then to 16 little-endian bytes
299+
big_int = int(int_str)
300+
byte_array = bytearray(16) # SQL_MAX_NUMERIC_LEN
301+
for i in range(16):
302+
byte_array[i] = big_int & 0xFF
303+
big_int >>= 8
304+
if big_int == 0:
305+
break
306+
307+
numeric_data.val = bytes(byte_array)
240308
return numeric_data
241309

242310
def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
@@ -309,7 +377,27 @@ def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
309377
)
310378

311379
if isinstance(param, decimal.Decimal):
312-
# Detect MONEY / SMALLMONEY range
380+
# First check precision limit for all decimal values
381+
decimal_as_tuple = param.as_tuple()
382+
digits_tuple = decimal_as_tuple.digits
383+
num_digits = len(digits_tuple)
384+
exponent = decimal_as_tuple.exponent
385+
386+
# Calculate the SQL precision (same logic as _get_numeric_data)
387+
if exponent >= 0:
388+
precision = num_digits + exponent
389+
elif (-1 * exponent) <= num_digits:
390+
precision = num_digits
391+
else:
392+
precision = exponent * -1
393+
394+
if precision > 38:
395+
raise ValueError(
396+
f"Precision of the numeric value is too high. "
397+
f"The maximum precision supported by SQL Server is 38, but got {precision}."
398+
)
399+
400+
# Detect MONEY / SMALLMONEY range
313401
if SMALLMONEY_MIN <= param <= SMALLMONEY_MAX:
314402
# smallmoney
315403
parameters_list[i] = str(param)

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#define SQL_SS_TIMESTAMPOFFSET (-155)
2222
#define SQL_C_SS_TIMESTAMPOFFSET (0x4001)
2323
#define MAX_DIGITS_IN_NUMERIC 64
24+
#define SQL_MAX_NUMERIC_LEN 16
2425

2526
#define STRINGIFY_FOR_CASE(x) \
2627
case x: \
@@ -118,12 +119,18 @@ struct NumericData {
118119
SQLCHAR precision;
119120
SQLSCHAR scale;
120121
SQLCHAR sign; // 1=pos, 0=neg
121-
std::uint64_t val; // 123.45 -> 12345
122+
std::string val; // 123.45 -> 12345
122123

123-
NumericData() : precision(0), scale(0), sign(0), val(0) {}
124+
NumericData() : precision(0), scale(0), sign(0), val(SQL_MAX_NUMERIC_LEN, '\0') {}
124125

125-
NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, std::uint64_t value)
126-
: precision(precision), scale(scale), sign(sign), val(value) {}
126+
NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, const std::string& valueBytes)
127+
: precision(precision), scale(scale), sign(sign), val(SQL_MAX_NUMERIC_LEN, '\0') {
128+
if (valueBytes.size() > SQL_MAX_NUMERIC_LEN) {
129+
throw std::runtime_error("NumericData valueBytes size exceeds SQL_MAX_NUMERIC_LEN (16)");
130+
}
131+
// Copy binary data to buffer, remaining bytes stay zero-padded
132+
std::memcpy(&val[0], valueBytes.data(), valueBytes.size());
133+
}
127134
};
128135

129136
// Struct to hold the DateTimeOffset structure
@@ -619,9 +626,10 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
619626
decimalPtr->sign = decimalParam.sign;
620627
// Convert the integer decimalParam.val to char array
621628
std::memset(static_cast<void*>(decimalPtr->val), 0, sizeof(decimalPtr->val));
622-
std::memcpy(static_cast<void*>(decimalPtr->val),
623-
reinterpret_cast<char*>(&decimalParam.val),
624-
sizeof(decimalParam.val));
629+
size_t copyLen = std::min(decimalParam.val.size(), sizeof(decimalPtr->val));
630+
if (copyLen > 0) {
631+
std::memcpy(decimalPtr->val, decimalParam.val.data(), copyLen);
632+
}
625633
dataPtr = static_cast<void*>(decimalPtr);
626634
break;
627635
}
@@ -2112,15 +2120,16 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt,
21122120
throw std::runtime_error(MakeParamMismatchErrorStr(info.paramCType, paramIndex));
21132121
}
21142122
NumericData decimalParam = element.cast<NumericData>();
2115-
LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%lld",
2116-
i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val);
2123+
LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%s",
2124+
i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val.c_str());
21172125
numericArray[i].precision = decimalParam.precision;
21182126
numericArray[i].scale = decimalParam.scale;
21192127
numericArray[i].sign = decimalParam.sign;
21202128
std::memset(numericArray[i].val, 0, sizeof(numericArray[i].val));
2121-
std::memcpy(numericArray[i].val,
2122-
reinterpret_cast<const char*>(&decimalParam.val),
2123-
std::min(sizeof(decimalParam.val), sizeof(numericArray[i].val)));
2129+
size_t copyLen = std::min(decimalParam.val.size(), sizeof(numericArray[i].val));
2130+
if (copyLen > 0) {
2131+
std::memcpy(numericArray[i].val, decimalParam.val.data(), copyLen);
2132+
}
21242133
strLenOrIndArray[i] = sizeof(SQL_NUMERIC_STRUCT);
21252134
}
21262135
dataPtr = numericArray;
@@ -3869,7 +3878,7 @@ PYBIND11_MODULE(ddbc_bindings, m) {
38693878
// Define numeric data class
38703879
py::class_<NumericData>(m, "NumericData")
38713880
.def(py::init<>())
3872-
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, std::uint64_t>())
3881+
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, const std::string&>())
38733882
.def_readwrite("precision", &NumericData::precision)
38743883
.def_readwrite("scale", &NumericData::scale)
38753884
.def_readwrite("sign", &NumericData::sign)

0 commit comments

Comments
 (0)