Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ def _get_numeric_data(self, param):
the numeric data.
"""
decimal_as_tuple = param.as_tuple()
num_digits = len(decimal_as_tuple.digits)
digits_tuple = decimal_as_tuple.digits
num_digits = len(digits_tuple)
exponent = decimal_as_tuple.exponent

# Calculate the SQL precision & scale
Expand All @@ -215,12 +216,11 @@ def _get_numeric_data(self, param):
precision = exponent * -1
scale = exponent * -1

# TODO: Revisit this check, do we want this restriction?
if precision > 15:
if precision > 38:
raise ValueError(
"Precision of the numeric value is too high - "
+ str(param)
+ ". Should be less than or equal to 15"
+ ". Should be less than or equal to 38"
)
Numeric_Data = ddbc_bindings.NumericData
numeric_data = Numeric_Data()
Expand All @@ -229,12 +229,26 @@ def _get_numeric_data(self, param):
numeric_data.sign = 1 if decimal_as_tuple.sign == 0 else 0
# strip decimal point from param & convert the significant digits to integer
# Ex: 12.34 ---> 1234
val = str(param)
if "." in val or "-" in val:
val = val.replace(".", "")
val = val.replace("-", "")
val = int(val)
numeric_data.val = val
int_str = ''.join(str(d) for d in digits_tuple)
if exponent > 0:
int_str = int_str + ('0' * exponent)
elif exponent < 0:
if -exponent > num_digits:
int_str = ('0' * (-exponent - num_digits)) + int_str

if int_str == '':
int_str = '0'

# Convert decimal base-10 string to python int, then to 16 little-endian bytes
big_int = int(int_str)
byte_array = bytearray(16) # SQL_MAX_NUMERIC_LEN
for i in range(16):
byte_array[i] = big_int & 0xFF
big_int >>= 8
if big_int == 0:
break

numeric_data.val = bytes(byte_array)
return numeric_data

def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
Expand Down
38 changes: 21 additions & 17 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define SQL_SS_TIMESTAMPOFFSET (-155)
#define SQL_C_SS_TIMESTAMPOFFSET (0x4001)
#define MAX_DIGITS_IN_NUMERIC 64
#define SQL_MAX_NUMERIC_LEN 16

#define STRINGIFY_FOR_CASE(x) \
case x: \
Expand Down Expand Up @@ -56,12 +57,12 @@
SQLCHAR precision;
SQLSCHAR scale;
SQLCHAR sign; // 1=pos, 0=neg
std::uint64_t val; // 123.45 -> 12345
std::string val; // 123.45 -> 12345

NumericData() : precision(0), scale(0), sign(0), val(0) {}
NumericData() : precision(0), scale(0), sign(0), val(SQL_MAX_NUMERIC_LEN, '\0') {}

NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, std::uint64_t value)
: precision(precision), scale(scale), sign(sign), val(value) {}
NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, const std::string& valueBytes)
: precision(precision), scale(scale), sign(sign), val(valueBytes) {}
};

// Struct to hold the DateTimeOffset structure
Expand Down Expand Up @@ -557,9 +558,10 @@
decimalPtr->sign = decimalParam.sign;
// Convert the integer decimalParam.val to char array
std::memset(static_cast<void*>(decimalPtr->val), 0, sizeof(decimalPtr->val));
std::memcpy(static_cast<void*>(decimalPtr->val),
reinterpret_cast<char*>(&decimalParam.val),
sizeof(decimalParam.val));
size_t copyLen = std::min(decimalParam.val.size(), sizeof(decimalPtr->val));
if (copyLen > 0) {
std::memcpy(decimalPtr->val, decimalParam.val.data(), copyLen);
}
dataPtr = static_cast<void*>(decimalPtr);
break;
}
Expand Down Expand Up @@ -2050,15 +2052,17 @@
throw std::runtime_error(MakeParamMismatchErrorStr(info.paramCType, paramIndex));
}
NumericData decimalParam = element.cast<NumericData>();
LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%lld",
i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val);
numericArray[i].precision = decimalParam.precision;
numericArray[i].scale = decimalParam.scale;
numericArray[i].sign = decimalParam.sign;
std::memset(numericArray[i].val, 0, sizeof(numericArray[i].val));
std::memcpy(numericArray[i].val,
reinterpret_cast<const char*>(&decimalParam.val),
std::min(sizeof(decimalParam.val), sizeof(numericArray[i].val)));
LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%s",
i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val.c_str());
SQL_NUMERIC_STRUCT& target = numericArray[i];
std::memset(&target, 0, sizeof(SQL_NUMERIC_STRUCT));
target.precision = decimalParam.precision;
target.scale = decimalParam.scale;
target.sign = decimalParam.sign;
size_t copyLen = std::min(decimalParam.val.size(), sizeof(target.val));
if (copyLen > 0) {
std::memcpy(target.val, decimalParam.val.data(), copyLen);
}
strLenOrIndArray[i] = sizeof(SQL_NUMERIC_STRUCT);
}
dataPtr = numericArray;
Expand Down Expand Up @@ -3794,7 +3798,7 @@
// Define numeric data class
py::class_<NumericData>(m, "NumericData")
.def(py::init<>())
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, std::uint64_t>())
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, const std::string&>())
.def_readwrite("precision", &NumericData::precision)
.def_readwrite("scale", &NumericData::scale)
.def_readwrite("sign", &NumericData::sign)
Expand Down
Loading