Skip to content

Commit 3e54515

Browse files
committed
working
1 parent 200c35b commit 3e54515

File tree

3 files changed

+68
-26
lines changed

3 files changed

+68
-26
lines changed

main.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,21 @@
55

66
setup_logging('stdout')
77

8-
conn_str = os.getenv("DB_CONNECTION_STRING")
8+
# conn_str = os.getenv("DB_CONNECTION_STRING")
9+
conn_str = "Server=Saumya;DATABASE=master;UID=sa;PWD=HappyPass1234;Trust_Connection=yes;TrustServerCertificate=yes;"
10+
911
conn = connect(conn_str)
1012

1113
# conn.autocommit = True
1214

1315
cursor = conn.cursor()
14-
cursor.execute("SELECT database_id, name from sys.databases;")
15-
rows = cursor.fetchall()
16-
17-
for row in rows:
18-
print(f"Database ID: {row[0]}, Name: {row[1]}")
19-
20-
cursor.close()
21-
conn.close()
16+
cursor.execute("DROP TABLE IF EXISTS test_decimal")
17+
cursor.execute("CREATE TABLE test_decimal (val DECIMAL(38, 10))")
18+
cursor.execute("INSERT INTO test_decimal (val) VALUES (?)", (decimal.Decimal('1234567890.1234567890'),))
19+
cursor.commit()
20+
print("Inserted value")
21+
cursor.execute("SELECT val FROM test_decimal")
22+
row = cursor.fetchone()
23+
print(f"Fetched value: {row[0]}")
24+
print(f"Type of fetched value: {type(row[0])}")
25+
assert row[0] == decimal.Decimal('1234567890.1234567890')

mssql_python/cursor.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ def _get_numeric_data(self, param):
195195
the numeric data.
196196
"""
197197
decimal_as_tuple = param.as_tuple()
198-
num_digits = len(decimal_as_tuple.digits)
198+
digits_tuple = decimal_as_tuple.digits
199+
num_digits = len(digits_tuple)
199200
exponent = decimal_as_tuple.exponent
200201

201202
# Calculate the SQL precision & scale
@@ -216,11 +217,11 @@ def _get_numeric_data(self, param):
216217
scale = exponent * -1
217218

218219
# TODO: Revisit this check, do we want this restriction?
219-
if precision > 15:
220+
if precision > 38:
220221
raise ValueError(
221222
"Precision of the numeric value is too high - "
222223
+ str(param)
223-
+ ". Should be less than or equal to 15"
224+
+ ". Should be less than or equal to 38"
224225
)
225226
Numeric_Data = ddbc_bindings.NumericData
226227
numeric_data = Numeric_Data()
@@ -229,12 +230,32 @@ def _get_numeric_data(self, param):
229230
numeric_data.sign = 1 if decimal_as_tuple.sign == 0 else 0
230231
# strip decimal point from param & convert the significant digits to integer
231232
# Ex: 12.34 ---> 1234
232-
val = str(param)
233-
if "." in val or "-" in val:
234-
val = val.replace(".", "")
235-
val = val.replace("-", "")
236-
val = int(val)
237-
numeric_data.val = val
233+
int_str = ''.join(str(d) for d in digits_tuple)
234+
235+
# Apply exponent to get the unscaled integer string
236+
if exponent > 0:
237+
int_str = int_str + ('0' * exponent)
238+
elif exponent < 0:
239+
# if exponent negative and abs(exponent) > num_digits we padded precision above
240+
# for the integer representation we pad leading zeros
241+
if -exponent > num_digits:
242+
int_str = ('0' * (-exponent - num_digits)) + int_str
243+
244+
# Edge: if int_str becomes empty (Decimal('0')), make "0"
245+
if int_str == '':
246+
int_str = '0'
247+
248+
# Convert decimal base-10 string -> python int, then to 16 little-endian bytes
249+
big_int = int(int_str) # Python big int is arbitrary precision
250+
byte_array = bytearray(16) # SQL_MAX_NUMERIC_LEN
251+
for i in range(16):
252+
byte_array[i] = big_int & 0xFF
253+
big_int >>= 8
254+
if big_int == 0:
255+
break
256+
257+
# numeric_data.val should be bytes (pybindable). Ensure a bytes object of length 16.
258+
numeric_data.val = bytes(byte_array)
238259
return numeric_data
239260

240261
def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#define SQL_SS_TIMESTAMPOFFSET (-155)
2222
#define SQL_C_SS_TIMESTAMPOFFSET (0x4001)
2323
#define MAX_DIGITS_IN_NUMERIC 64
24+
#define SQL_MAX_NUMERIC_LEN 16
2425

2526
#define STRINGIFY_FOR_CASE(x) \
2627
case x: \
@@ -56,12 +57,16 @@ struct NumericData {
5657
SQLCHAR precision;
5758
SQLSCHAR scale;
5859
SQLCHAR sign; // 1=pos, 0=neg
59-
std::uint64_t val; // 123.45 -> 12345
60+
std::string val; // 123.45 -> 12345
6061

61-
NumericData() : precision(0), scale(0), sign(0), val(0) {}
62+
NumericData() : precision(0), scale(0), sign(0), val(SQL_MAX_NUMERIC_LEN, '\0') {}
6263

63-
NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, std::uint64_t value)
64-
: precision(precision), scale(scale), sign(sign), val(value) {}
64+
NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, const std::string& valueBytes)
65+
: precision(precision), scale(scale), sign(sign) {
66+
val = valueBytes;
67+
// Ensure val is always exactly SQL_MAX_NUMERIC_LEN bytes
68+
val.resize(SQL_MAX_NUMERIC_LEN, '\0');
69+
}
6570
};
6671

6772
// Struct to hold the DateTimeOffset structure
@@ -557,9 +562,21 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
557562
decimalPtr->sign = decimalParam.sign;
558563
// Convert the integer decimalParam.val to char array
559564
std::memset(static_cast<void*>(decimalPtr->val), 0, sizeof(decimalPtr->val));
560-
std::memcpy(static_cast<void*>(decimalPtr->val),
561-
reinterpret_cast<char*>(&decimalParam.val),
562-
sizeof(decimalParam.val));
565+
// std::memcpy(static_cast<void*>(decimalPtr->val),
566+
// reinterpret_cast<char*>(&decimalParam.val),
567+
// sizeof(decimalParam.val));
568+
size_t src_len = decimalParam.val.size();
569+
if (src_len > sizeof(decimalPtr->val)) {
570+
// Defensive: should never happen if Python side ensures 16 bytes; but guard anyway
571+
ThrowStdException("Numeric value byte buffer too large for SQL_NUMERIC_STRUCT (paramIndex " + std::to_string(paramIndex) + ")");
572+
}
573+
if (src_len > 0) {
574+
std::memcpy(static_cast<void*>(decimalPtr->val),
575+
static_cast<const void*>(decimalParam.val.data()),
576+
src_len);
577+
}
578+
//print the data received from python
579+
LOG("Numeric parameter val bytes: {}", decimalPtr->val);
563580
dataPtr = static_cast<void*>(decimalPtr);
564581
break;
565582
}
@@ -3794,7 +3811,7 @@ PYBIND11_MODULE(ddbc_bindings, m) {
37943811
// Define numeric data class
37953812
py::class_<NumericData>(m, "NumericData")
37963813
.def(py::init<>())
3797-
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, std::uint64_t>())
3814+
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, const std::string&>())
37983815
.def_readwrite("precision", &NumericData::precision)
37993816
.def_readwrite("scale", &NumericData::scale)
38003817
.def_readwrite("sign", &NumericData::sign)

0 commit comments

Comments
 (0)