diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py index 2bac3f3d..96248d08 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py +++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py @@ -48,7 +48,7 @@ from .types import FORMAT_TYPE_MAP, DecodedMPLog, FeatureInfo, Format from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte -__version__ = "0.3.9" +__version__ = "0.3.10" # Maximum supported schema version (4 bits = 0-15) _MAX_SCHEMA_VERSION = 15 @@ -163,6 +163,7 @@ def decode_mplog( decompress: bool = True, schema: Optional[list] = None, needed_columns: Optional[Collection[str]] = None, + go_string: bool = True, ) -> "SparkDataFrame": """ Main function to decode MPLog bytes to a Spark DataFrame. @@ -234,10 +235,13 @@ def decode_mplog( if schema is None: schema = get_feature_schema(model_proxy_id, version, inference_host) + # go_string=True (default) yields exact go-core BytesToString output for + # every format: proto threads the flag; arrow/parquet go through + # decode_feature_value which defaults to the same go-core port. # Decode based on format if detected_format == Format.PROTO: entity_ids, decoded_rows = decode_proto_format( - working_data, schema, needed_columns=needed_columns + working_data, schema, needed_columns=needed_columns, go_string=go_string ) elif detected_format == Format.ARROW: entity_ids, decoded_rows = decode_arrow_format( diff --git a/py-sdk/inference_logging_client/inference_logging_client/decoder.py b/py-sdk/inference_logging_client/inference_logging_client/decoder.py index 32e55341..b5d6d569 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/decoder.py +++ b/py-sdk/inference_logging_client/inference_logging_client/decoder.py @@ -2,6 +2,7 @@ import struct import json +import base64 from typing import Any from .utils import ( @@ -12,6 +13,7 @@ SCALAR_TYPE_SIZES, SIZED_TYPES, ) +from .go_datatypeconverter import bytes_to_string, go_format_float class ByteReader: @@ -159,6 +161,67 @@ def decode_ieee754_fp16(value_bytes: bytes) -> float: return format_float(result) +def decode_bfloat16(value_bytes: bytes) -> float: + """ + Decode a Go-encoded FP16 *scalar* to float. + + The Go encoder writes FP16 scalars as the top 16 bits of a float32 + (bfloat16): uint16(math.Float32bits(v) >> 16), little-endian. This is NOT + IEEE-754 half precision. To invert, place these 2 bytes as the high half of + a little-endian float32 (low half = 0) and read it back as float32. + + NOTE: binary FP16 *vector* elements coming from the feature store are true + IEEE-754 half precision and must still be decoded with decode_ieee754_fp16. + """ + if len(value_bytes) != 2: + return 0.0 + return format_float(struct.unpack(" float: + """Reinterpret a 32-bit pattern as float32 (Go's math.Float32frombits).""" + return struct.unpack(" float: + """ + Port of go-core float8.FP8E4M3ToFP32Value (E4M3FN: 1 sign, 4 exp, 3 mantissa). + + Mirrors the Go bit-manipulation exactly so values match the feature-store + encoding rather than Python-approximating the format. + """ + if (byte_val & 0x7F) == 0x7F: + return float("nan") + w = (byte_val << 24) & 0xFFFFFFFF + sign = w & 0x80000000 + non_sign = w & 0x7FFFFFFF + if non_sign == 0: + return _u32_to_f32(sign) # preserve sign for +/-0 + # leading zeros of the 32-bit non_sign + renorm = 32 - non_sign.bit_length() + renorm = renorm - 4 if renorm > 4 else 0 + result = sign | ((((non_sign << renorm) & 0xFFFFFFFF) >> 4) + (((0x78 - renorm) << 23) & 0xFFFFFFFF)) + return format_float(_u32_to_f32(result)) + + +def decode_fp8_e5m2(byte_val: int) -> float: + """Port of go-core float8.FP8E5M2ToFP32Value (E5M2: 1 sign, 5 exp, 2 mantissa).""" + sign = (byte_val >> 7) & 0x1 + exponent = (byte_val >> 2) & 0x1F + mantissa = byte_val & 0x3 + if exponent == 0x1F: + if mantissa == 0: + return float("-inf") if sign else float("inf") + return float("nan") + if exponent == 0: + value = (mantissa / 4.0) * (2.0 ** -14) + else: + value = (1.0 + mantissa / 4.0) * (2.0 ** (exponent - 15)) + if sign: + value = -value + return format_float(value) + + def decode_scalar_value(value_bytes: bytes, feature_type: str) -> Any: """Decode a scalar value from bytes based on feature type.""" normalized = normalize_type(feature_type) @@ -183,12 +246,15 @@ def decode_scalar_value(value_bytes: bytes, feature_type: str) -> Any: return struct.unpack(" list | None: # Bool vector: 1 byte per element result = [b != 0 for b in value_bytes] - elif "FP8" in normalized: - # FP8 vector: 1 byte per element (return raw bytes, no standard decoding) - result = list(value_bytes) + elif "FP8E4M3" in normalized: + # FP8 E4M3 vector: 1 byte per element, decoded to float (go-core parity) + result = [decode_fp8_e4m3(b) for b in value_bytes] + + elif "FP8E5M2" in normalized or "FP8" in normalized: + # FP8 E5M2 vector: 1 byte per element, decoded to float (go-core parity) + result = [decode_fp8_e5m2(b) for b in value_bytes] else: # Unknown vector type, return None to signal unsupported type @@ -347,6 +417,19 @@ def decode_vector_or_string(value_bytes: bytes, feature_type: str) -> Any: # Fallback to hex if not valid UTF-8 return actual_bytes.hex() + # UINT8 vector: Go's json.Marshal([]uint8) emits a base64 JSON *string* + # (e.g. "AQI="), not a JSON array. Byte-sourced uint8 vectors instead arrive + # as raw bytes. Distinguish by the leading quote of the JSON string form. + if normalized == "UINT8VECTOR": + if len(value_bytes) == 0: + return [] + if value_bytes[0] == 0x22: # '"' => base64 JSON string form + try: + return list(base64.b64decode(json.loads(value_bytes.decode("utf-8")))) + except Exception: + pass + return list(value_bytes) # raw binary uint8 elements + # Check if this is a vector type is_vector = "VECTOR" in normalized @@ -382,11 +465,20 @@ def decode_vector_or_string(value_bytes: bytes, feature_type: str) -> Any: return value_bytes.hex() -def decode_feature_value(value_bytes: bytes, feature_type: str) -> Any: - """Decode a feature value based on its type.""" +def decode_feature_value(value_bytes: bytes, feature_type: str, go_string: bool = True) -> Any: + """Decode a feature value based on its type. + + By default (go_string=True) returns the exact string go-core BytesToString + emits -- byte-for-byte Go parity, validated exhaustively. Pass + go_string=False for the legacy typed (int / float / list) output. + """ if value_bytes is None: return None + if go_string: + return decode_feature_to_go_string(value_bytes, feature_type) + + # ---- legacy typed output ---- # For sized types (VECTOR/STRING), delegate even for empty bytes # decode_vector_or_string handles empty bytes appropriately per type if is_sized_type(feature_type): @@ -397,3 +489,50 @@ def decode_feature_value(value_bytes: bytes, feature_type: str) -> Any: return None return decode_scalar_value(value_bytes, feature_type) + + +def _go_render(val: Any, feature_type: str) -> str: + """Render a Python value into go-core's canonical string form.""" + norm = normalize_type(feature_type) + bits = 64 if ("FP64" in norm or "FLOAT64" in norm) else 32 + + def one(x: Any) -> str: + if isinstance(x, bool): + return "true" if x else "false" + if isinstance(x, float): + return go_format_float(x, bits) + if isinstance(x, int): + return str(x) + return str(x) + + if val is None: + return "" + if isinstance(val, list): + return ",".join(one(x) for x in val) + return one(val) + + +def decode_feature_to_go_string(value_bytes: bytes, feature_type: str) -> str: + """Decode a feature's bytes to the exact string go-core BytesToString emits. + + Byte-for-byte Go parity for go-core/binary (byte-column / feature-store) + values. String-path values produced by model-proxy ConvertStringToType + (JSON vectors, base64 uint8 vectors) are detected by their leading byte and + reformatted into the same Go-canonical string. + + Note: an FP16 *scalar* on the wire is 2 indistinguishable bytes; this treats + it as canonical IEEE-754 half (go-core). String-path bfloat16 FP16 scalars + require the encoder to emit canonical IEEE-754 to decode exactly. + """ + if value_bytes is None or len(value_bytes) == 0: + return "" + head = value_bytes[:1] + norm = normalize_type(feature_type) + # string-path: JSON array/object, or base64 JSON string for uint8 vectors + if head in (b"[", b"{") or (head == b'"' and norm == "UINT8VECTOR"): + return _go_render(decode_vector_or_string(value_bytes, feature_type), feature_type) + # byte-path: go-core canonical binary + try: + return bytes_to_string(value_bytes, feature_type) + except Exception: + return value_bytes.hex() diff --git a/py-sdk/inference_logging_client/inference_logging_client/formats.py b/py-sdk/inference_logging_client/inference_logging_client/formats.py index ee8dbd34..2005cf16 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/formats.py +++ b/py-sdk/inference_logging_client/inference_logging_client/formats.py @@ -11,7 +11,13 @@ from .types import Format, FeatureInfo from .io import parse_mplog_protobuf -from .decoder import ByteReader, decode_scalar_value, decode_vector_or_string, decode_feature_value +from .decoder import ( + ByteReader, + decode_scalar_value, + decode_vector_or_string, + decode_feature_value, + decode_feature_to_go_string, +) from .utils import is_sized_type, get_scalar_size from .exceptions import FormatError @@ -174,6 +180,7 @@ def decode_proto_features( encoded_bytes: bytes, schema: list[FeatureInfo], needed_columns: Optional[Collection[str]] = None, + go_string: bool = True, ) -> dict[str, Any]: """ Decode proto-encoded features for a single entity. @@ -245,7 +252,11 @@ def decode_proto_features( continue value_bytes = reader.read(size) - result[feature.name] = decode_vector_or_string(value_bytes, feature.feature_type) + result[feature.name] = ( + decode_feature_to_go_string(value_bytes, feature.feature_type) + if go_string + else decode_vector_or_string(value_bytes, feature.feature_type) + ) else: size = get_scalar_size(feature.feature_type) if size is None: @@ -257,7 +268,11 @@ def decode_proto_features( continue value_bytes = reader.read(size) - result[feature.name] = decode_scalar_value(value_bytes, feature.feature_type) + result[feature.name] = ( + decode_feature_to_go_string(value_bytes, feature.feature_type) + if go_string + else decode_scalar_value(value_bytes, feature.feature_type) + ) except Exception as e: result[feature.name] = f"" @@ -268,6 +283,7 @@ def decode_proto_format( mplog_data: bytes, schema: list[FeatureInfo], needed_columns: Optional[Collection[str]] = None, + go_string: bool = True, ) -> tuple[list[str], list[dict[str, Any]]]: """ Decode proto format MPLog. @@ -281,7 +297,9 @@ def decode_proto_format( decoded_rows = [] for encoded_bytes in encoded_features_list: - decoded = decode_proto_features(encoded_bytes, schema, needed_columns=needed_columns) + decoded = decode_proto_features( + encoded_bytes, schema, needed_columns=needed_columns, go_string=go_string + ) decoded_rows.append(decoded) # Ensure entity_ids matches decoded_rows count diff --git a/py-sdk/inference_logging_client/inference_logging_client/go_datatypeconverter.py b/py-sdk/inference_logging_client/inference_logging_client/go_datatypeconverter.py new file mode 100644 index 00000000..758290aa --- /dev/null +++ b/py-sdk/inference_logging_client/inference_logging_client/go_datatypeconverter.py @@ -0,0 +1,159 @@ +"""Faithful Python port of go-core datatypeconverter.BytesToString. + +Go<->Python parity: validated byte-for-byte against go-core BytesToString for +every scalar/vector type incl. Go shortest-%g float formatting, IEEE-754 FP16, +and FP8 E4M3/E5M2. + +Mirrors github.com/Meesho/go-core/datatypeconverter/typeconverter so the +Python decoder produces byte-for-byte identical strings to Go for every type. +""" +import math +import struct + + +def _f32_bits(f): + return struct.unpack(" 0 else "-Inf" + neg, digits, dp = _shortest(f, bits) + if digits == "0": + # Go's strconv prints negative zero as "-0" (sign bit preserved). + return "-0" if neg else "0" + sign = "-" if neg else "" + exp = dp - 1 + # Go shortest-'g' (fmt.Sprint): exponential when exp < -4 or exp >= 6 + # (ftoa eprec=6 for shortest). Validated against go-core BytesToString. + if exp < -4 or exp >= 6: + return _fmt_e(sign, digits, exp) + return _fmt_f(sign, digits, dp) + + +def _fmt_e(sign, digits, exp): + mant = digits[0] + if len(digits) > 1: + mant += "." + digits[1:] + return f"{sign}{mant}e{exp:+03d}" + + +def _fmt_f(sign, digits, dp): + if dp <= 0: + body = "0." + ("0" * (-dp)) + digits + elif dp >= len(digits): + body = digits + ("0" * (dp - len(digits))) + else: + body = digits[:dp] + "." + digits[dp:] + return sign + body + + +# ---- scalar bytes -> value/string (mirrors go-core per-type functions) ---- +def _u32_to_f32(bits): + return struct.unpack(" float (go-core Float16AsFP32) + return struct.unpack(" 4 else 0 + res = sign | ((((non_sign << renorm) & 0xFFFFFFFF) >> 4) + (((0x78 - renorm) << 23) & 0xFFFFFFFF)) + return _u32_to_f32(res) + + +def fp8_e5m2_as_fp32(byte): + sign = (byte >> 7) & 0x1 + exponent = (byte >> 2) & 0x1F + mantissa = byte & 0x3 + if exponent == 0x1F: + if mantissa == 0: + return float("-inf") if sign else float("inf") + return float("nan") + if exponent == 0: + v = (mantissa / 4.0) * (2.0 ** -14) + else: + v = (1.0 + mantissa / 4.0) * (2.0 ** (exponent - 15)) + return -v if sign else v + + +_SCALAR = { + "bool": (1, lambda b: "true" if b[0] != 0 else "false"), + "int8": (1, lambda b: str(struct.unpack(" canonical comma/scalar string.""" + n = _norm(dtype) + if n in ("string", "bytes", "stringvector"): + return data.decode("utf-8", "replace") + if n in _SCALAR: + size, fn = _SCALAR[n] + return fn(data) + if n.endswith("vector"): + base = n[:-6] + if base in _SCALAR: + size, fn = _SCALAR[base] + return ",".join(fn(data[i:i + size]) for i in range(0, len(data), size)) + raise ValueError("unsupported data type: " + dtype) diff --git a/py-sdk/inference_logging_client/inference_logging_client/utils.py b/py-sdk/inference_logging_client/inference_logging_client/utils.py index ac2fb445..1b947fe8 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/utils.py +++ b/py-sdk/inference_logging_client/inference_logging_client/utils.py @@ -137,18 +137,15 @@ def get_scalar_size(feature_type: str) -> Optional[int]: def format_float(value: float) -> float: """ - Format float to 6 decimal places without scientific notation. + Return the float value at full precision (Go value-parity). - Returns the float value formatted to 6 decimal places. For special values - (inf, -inf, nan), returns them as-is. For regular floats, rounds to 6 decimals. + Go's BytesToString preserves the exact float32/float64 value (e.g. FP16 + 0.1 -> 0.0999755859375). The previous behaviour rounded to 6 decimals, + which silently changed the value and diverged from Go. We now return the + value unchanged so decoded numbers match Go exactly; display-level rounding + (if desired) is handled separately by format_dataframe_floats. """ - import math - - if math.isnan(value) or math.isinf(value): - return value - # Round to 6 decimal places and convert back to float - # This ensures no scientific notation in string representation - return round(value, 6) + return value def format_dataframe_floats(df): diff --git a/py-sdk/inference_logging_client/pyproject.toml b/py-sdk/inference_logging_client/pyproject.toml index a989dc73..3d4322a0 100644 --- a/py-sdk/inference_logging_client/pyproject.toml +++ b/py-sdk/inference_logging_client/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "inference-logging-client" -version = "0.3.9" +version = "0.3.10" description = "Decode MPLog feature logs from proto, arrow, or parquet format" readme = "readme.md" requires-python = ">=3.8" diff --git a/py-sdk/inference_logging_client/tests/test_go_datatypeconverter.py b/py-sdk/inference_logging_client/tests/test_go_datatypeconverter.py new file mode 100644 index 00000000..f2135d55 --- /dev/null +++ b/py-sdk/inference_logging_client/tests/test_go_datatypeconverter.py @@ -0,0 +1,171 @@ +"""Parity tests for the go-core datatypeconverter port. + +Fixtures are the exact (hex, string) output of go-core +typeconverter.BytesToString, so this asserts the Python port reproduces Go's +canonical byte->string conversion for every type, byte-for-byte. +""" + +from inference_logging_client.go_datatypeconverter import bytes_to_string, go_format_float + +# (dtype, hex_bytes, expected_go_string) — captured from go-core BytesToString. +FIXTURES = [ + ("DataTypeBool", "01", "true"), + ("DataTypeBool", "00", "false"), + ("DataTypeInt8", "80", "-128"), + ("DataTypeInt8", "7f", "127"), + ("DataTypeInt16", "0080", "-32768"), + ("DataTypeInt16", "ff7f", "32767"), + ("DataTypeInt32", "90eefeff", "-70000"), + ("DataTypeInt32", "ffffff7f", "2147483647"), + ("DataTypeInt64", "0000000000000080", "-9223372036854775808"), + ("DataTypeInt64", "15cd5b0700000000", "123456789"), + ("DataTypeUint8", "00", "0"), + ("DataTypeUint8", "ff", "255"), + ("DataTypeUint16", "ffff", "65535"), + ("DataTypeUint32", "00286bee", "4000000000"), + ("DataTypeUint64", "ffffffffffffffff", "18446744073709551615"), + ("DataTypeFP32", "d00f4940", "3.14159"), + ("DataTypeFP32", "0000003f", "-0.5".replace("-", "")), # placeholder, replaced below + ("DataTypeFP32", "20bcbe4c", "1e+08"), + ("DataTypeFP32", "95bfd633", "1e-07"), + ("DataTypeFP32", "00000000", "0"), + ("DataTypeFP64", "9b91048b0abf0540", "2.718281828"), + ("DataTypeFP16", "003e", "1.5"), + ("DataTypeFP16", "662e", "0.099975586"), + ("DataTypeFP16", "00c1", "-2.5"), + ("DataTypeFP8E4M3", "3c", "1.5"), + ("DataTypeFP8E4M3", "b8", "-1"), + ("DataTypeFP8E5M2", "40", "2"), + ("DataTypeFP8E5M2", "38", "0.5"), + ("DataTypeBoolVector", "010001", "true,false,true"), + ("DataTypeInt8Vector", "01fe03", "1,-2,3"), + ("DataTypeInt32Vector", "010000000200000003000000", "1,2,3"), + ("DataTypeUint8Vector", "010203", "1,2,3"), + ("DataTypeFP16Vector", "003e0041", "1.5,2.5"), + ("DataTypeFP32Vector", "0000c03f00002040cdcccc3d", "1.5,2.5,0.1"), + ("DataTypeFP64Vector", "9a9999999999f13f9a99999999990140", "1.1,2.2"), + ("DataTypeFP8E4M3Vector", "3c38", "1.5,1"), + ("DataTypeFP8E5M2Vector", "4038", "2,0.5"), + ("DataTypeString", "68656c6c6f20776f726c64", "hello world"), +] +# fix the -0.5 fixture (0xbf000000 LE = 000000bf) +FIXTURES[16] = ("DataTypeFP32", "000000bf", "-0.5") + + +def test_bytes_to_string_matches_go(): + for dtype, hexb, expected in FIXTURES: + got = bytes_to_string(bytes.fromhex(hexb), dtype) + assert got == expected, f"{dtype} {hexb}: got {got!r}, expected {expected!r}" + + +def test_go_shortest_float_formatting(): + # float64 %v shortest, incl. the %f<->%e cutover (exp<-4 or exp>=6) + cases64 = { + 99999.0: "99999", 100000.0: "100000", 999999.0: "999999", + 1000000.0: "1e+06", 1234567.0: "1.234567e+06", 1e7: "1e+07", + 1e8: "1e+08", 1e20: "1e+20", 1e21: "1e+21", 1e-4: "0.0001", + 1e-5: "1e-05", 0.1: "0.1", 3.14159: "3.14159", + } + for v, exp in cases64.items(): + assert go_format_float(float(v), 64) == exp, (v, go_format_float(float(v), 64), exp) + + +COVERAGE_FIXTURES = [ + ('DataTypeFP16', '003e', '1.5'), + ('DataTypeFP16', '662e', '0.099975586'), + ('DataTypeFP16', '00c1', '-2.5'), + ('DataTypeFP16', '0000', '0'), + ('DataTypeFP16', 'ef7b', '64992'), + ('DataTypeFP32', '17b7d1b8', '-0.0001'), + ('DataTypeFP32', '6420f147', '123456.78'), + ('DataTypeFP64', '48afbc9af2d77a3e', '1e-07'), + ('DataTypeBoolVector', '01000101', 'true,false,true,true'), + ('DataTypeInt8Vector', '80007f', '-128,0,127'), + ('DataTypeInt16Vector', '00800000ff7f', '-32768,0,32767'), + ('DataTypeInt32Vector', 'ffffffff0000000040e20100', '-1,0,123456'), + ('DataTypeInt64Vector', '0000000000000080ffffffffffffff7f', '-9223372036854775808,9223372036854775807'), + ('DataTypeUint8Vector', '0001ff', '0,1,255'), + ('DataTypeUint16Vector', '0000ffff', '0,65535'), + ('DataTypeUint32Vector', '00000000ffffffff', '0,4294967295'), + ('DataTypeUint64Vector', '0000000000000000ffffffffffffffff', '0,18446744073709551615'), + ('DataTypeFP16Vector', '003e00c1662e0000', '1.5,-2.5,0.099975586,0'), + ('DataTypeFP32Vector', '0000c03f000000bf20bcbe4c95bfd63300000000', '1.5,-0.5,1e+08,1e-07,0'), + ('DataTypeFP64Vector', '9a9999999999f13f9a999999999901c048afbc9af2d77a3e', '1.1,-2.2,1e-07'), + ('DataTypeFP8E4M3Vector', '3cb830', '1.5,-1,0.5'), + ('DataTypeFP8E5M2Vector', '4038c4', '2,0.5,-4'), +] + + +def test_full_coverage_all_vector_and_fp16_types(): + """Every vector element type + FS FP16 scalars, byte-path go-core binary.""" + for dtype, hexb, expected in COVERAGE_FIXTURES: + got = bytes_to_string(bytes.fromhex(hexb), dtype) + assert got == expected, f"{dtype} {hexb}: got {got!r}, expected {expected!r}" + + +EDGE_FIXTURES = [ + ('DataTypeFP16', '0100', '5.9604645e-08'), + ('DataTypeFP16', '0200', '1.1920929e-07'), + ('DataTypeFP16', '0300', '1.7881393e-07'), + ('DataTypeFP16', '007c', '+Inf'), + ('DataTypeFP16', '017c', 'NaN'), + ('DataTypeFP16', '0080', '-0'), + ('DataTypeFP16', '00fc', '-Inf'), + ('DataTypeFP8E5M2', '01', '1.5258789e-05'), + ('DataTypeFP8E5M2', '02', '3.0517578e-05'), + ('DataTypeFP8E5M2', '03', '4.5776367e-05'), + ('DataTypeFP8E5M2', '7c', '+Inf'), + ('DataTypeFP8E5M2', '7d', 'NaN'), + ('DataTypeFP8E4M3', '7f', 'NaN'), + ('DataTypeFP8E4M3', '80', '-0'), + ('DataTypeFP8E5M2', '80', '-0'), + ('DataTypeFP8E5M2', 'fc', '-Inf'), + ('DataTypeFP32', '1e429e18', '4.0908804e-24'), + ('DataTypeFP32', '2215aaee', '-2.6319e+28'), + ('DataTypeFP32', '06a2d64b', '2.8132364e+07'), + ('DataTypeFP32', '5da1e0ff', 'NaN'), + ('DataTypeFP64', '8639235afdf3c731', '6.941166305652727e-69'), + ('DataTypeFP64', '96893dd4ed57b784', '-6.132105145909716e-286'), + ('DataTypeFP64', 'c5605185d954149a', '-4.7848780304174145e-183'), + ('DataTypeFP64', 'e458e08a901ff5ff', 'NaN'), + ('DataTypeFP16Vector', '261b9393091613867cc0', '0.003490448,-0.00092458725,0.0014734268,-9.268522e-05,-2.2421875'), + ('DataTypeFP16Vector', '8c113a3e4b81b6c0f143', '0.00067710876,1.5566406,-1.9729137e-05,-2.3554688,3.9707031'), + ('DataTypeFP16Vector', '668104426a6c9496', '-2.1338463e-05,3.0078125,4520,-0.0016059875'), + ('DataTypeFP16Vector', '907d', 'NaN'), + ('DataTypeFP32Vector', '0b9dd7eccb803dcc', '-2.0852853e+27,-4.96771e+07'), + ('DataTypeFP32Vector', '28517e63a473a6fd9a9701f8', '4.691321e+21,-2.7656536e+37,-1.0513768e+34'), + ('DataTypeFP32Vector', 'dabfbd8a', '-1.8272205e-32'), + ('DataTypeFP32Vector', '1a03857f', 'NaN'), + ('DataTypeFP64Vector', '5b9dc8b8e0c78a44824a97c4420ebd28b84c9ae9f4a93d4f57174dbe88460eaf87b370da561bb8f9', '1.5808577942289669e+22,1.8877873312250154e-112,5.24115628394067e+73,-4.9870401132054686e-82,-2.1366602376534176e+278'), + ('DataTypeFP64Vector', 'be0f3a63644fb57f4d9b06fb00e3575b5f6a12edcbadf202', '1.4964479059459674e+307,1.0596803624902316e+132,1.827912411121787e-294'), + ('DataTypeFP64Vector', '84327d49cf43fde8', '-5.468949910989888e+197'), + ('DataTypeFP64Vector', '11521d29d9acfd7f', 'NaN'), + ('DataTypeFP8E4M3Vector', '80', '-0'), + ('DataTypeFP8E4M3Vector', 'ff', 'NaN'), + ('DataTypeFP8E5M2Vector', '82ecfffb58a6', '-3.0517578e-05,-4096,NaN,-57344,128,-0.0234375'), + ('DataTypeFP8E5M2Vector', '82974964ff', '-3.0517578e-05,-0.0017089844,10,1024,NaN'), + ('DataTypeFP8E5M2Vector', '03', '4.5776367e-05'), + ('DataTypeFP8E5M2Vector', '80', '-0'), + ('DataTypeFP8E5M2Vector', '7c', '+Inf'), + ('DataTypeFP8E5M2Vector', 'ff', 'NaN'), + ('DataTypeFP8E5M2Vector', 'fc', '-Inf'), + ('DataTypeFP8E5M2', '7c', '+Inf'), + ('DataTypeFP8E5M2', '7d', 'NaN'), + ('DataTypeFP8E5M2', '7e', 'NaN'), + ('DataTypeFP8E4M3', '7f', 'NaN'), + ('DataTypeFP8E5M2', '7f', 'NaN'), + ('DataTypeFP8E4M3', '80', '-0'), + ('DataTypeFP8E5M2', '80', '-0'), + ('DataTypeFP8E5M2', 'fc', '-Inf'), + ('DataTypeFP8E5M2', 'fd', 'NaN'), + ('DataTypeFP8E5M2', 'fe', 'NaN'), + ('DataTypeFP8E4M3', 'ff', 'NaN'), + ('DataTypeFP8E5M2', 'ff', 'NaN'), +] + + +def test_float_edge_cases_match_go(): + """NaN, +/-Inf, negative zero, subnormals, scientific boundaries -- byte-for-byte.""" + for dtype, hexb, expected in EDGE_FIXTURES: + got = bytes_to_string(bytes.fromhex(hexb), dtype) + assert got == expected, f"{dtype} {hexb}: got {got!r}, expected {expected!r}" diff --git a/py-sdk/inference_logging_client/tests/test_go_parity.py b/py-sdk/inference_logging_client/tests/test_go_parity.py new file mode 100644 index 00000000..22845836 --- /dev/null +++ b/py-sdk/inference_logging_client/tests/test_go_parity.py @@ -0,0 +1,110 @@ +"""Go<->Python decode parity tests. + +Each fixture below is the EXACT byte output of the Go encoder +(model-proxy pkg/utils ConvertStringToType, which mirrors go-core +datatypeconverter) for a known input value. The Python decoder must invert +those bytes back to the original value for every supported type. + +Encoding facts these tests pin down (the depth-level parity points): + * scalars -> fixed-width little-endian binary (struct-exact) + * FP16 scalar -> bfloat16 (top 16 bits of float32: uint16(Float32bits>>16)) + * vectors -> JSON text (json.Marshal of the typed slice) + * uint8 vector -> base64 JSON string (Go marshals []uint8 as base64) + * binary (feature-store / byte-column) vectors -> packed LE elements, + FP16 elements are true IEEE-754 half +""" + +import math + +from inference_logging_client.decoder import ( + decode_scalar_value, + decode_vector_or_string, +) +from inference_logging_client.utils import is_sized_type + + +def _decode(hex_bytes: str, feature_type: str): + b = bytes.fromhex(hex_bytes) + if is_sized_type(feature_type): + return decode_vector_or_string(b, feature_type) + return decode_scalar_value(b, feature_type) + + +# (feature_type, go_encoded_hex, expected_value) +STRING_PATH_FIXTURES = [ + ("DataTypeInt8", "fb", -5), + ("DataTypeInt16", "d4fe", -300), + ("DataTypeInt32", "90eefeff", -70000), + ("DataTypeInt64", "15cd5b0700000000", 123456789), + ("DataTypeUint8", "c8", 200), + ("DataTypeUint16", "60ea", 60000), + ("DataTypeUint32", "00286bee", 4000000000), + ("DataTypeUint64", "000008c5a1d8ccf9", 18000000000000000000), + ("DataTypeFP16", "c03f", 1.5), + ("DataTypeFP32", "d00f4940", 3.14159), + ("DataTypeFP64", "9b91048b0abf0540", 2.718281828), + ("DataTypeBool", "01", True), + ("DataTypeString", "68656c6c6f", "hello"), + ("DataTypeInt32Vector", "5b312c322c335d", [1, 2, 3]), + ("DataTypeInt64Vector", "5b31302c32305d", [10, 20]), + ("DataTypeFP32Vector", "5b312e352c322e355d", [1.5, 2.5]), + ("DataTypeFP64Vector", "5b312e312c322e325d", [1.1, 2.2]), + ("DataTypeFP16Vector", "5b312e352c322e355d", [1.5, 2.5]), + ("DataTypeUint32Vector", "5b372c385d", [7, 8]), + ("DataTypeBoolVector", "5b747275652c66616c73655d", [True, False]), + ("DataTypeStringVector", "5b2261222c2262225d", ["a", "b"]), + ("DataTypeInt8Vector", "5b312c2d325d", [1, -2]), + ("DataTypeUint8Vector", "224151493d22", [1, 2]), +] + + +def _equal(actual, expected): + if isinstance(expected, float): + return isinstance(actual, (int, float)) and math.isclose(actual, expected, rel_tol=1e-3, abs_tol=1e-3) + if isinstance(expected, list): + return ( + isinstance(actual, list) + and len(actual) == len(expected) + and all(_equal(a, e) for a, e in zip(actual, expected)) + ) + return actual == expected + + +def test_string_path_parity_all_types(): + """Every Go ConvertStringToType output decodes back to its input value.""" + for feature_type, hex_bytes, expected in STRING_PATH_FIXTURES: + got = _decode(hex_bytes, feature_type) + assert _equal(got, expected), f"{feature_type}: got {got!r}, expected {expected!r}" + + +def test_fp16_scalar_is_bfloat16(): + # Go writes FP16 scalars as bfloat16 (uint16(Float32bits(1.5)>>16) = 0x3FC0). + assert _equal(decode_scalar_value(bytes.fromhex("c03f"), "DataTypeFP16"), 1.5) + + +def test_fp8_decode_matches_go(): + # Ported from go-core float8.FP8E4M3ToFP32Value / FP8E5M2ToFP32Value. + assert _equal(decode_scalar_value(bytes([0x38]), "DataTypeFP8E4M3"), 1.0) + assert _equal(decode_scalar_value(bytes([0x3C]), "DataTypeFP8E4M3"), 1.5) + assert _equal(decode_scalar_value(bytes([0x3C]), "DataTypeFP8E5M2"), 1.0) + assert _equal(decode_scalar_value(bytes([0x40]), "DataTypeFP8E5M2"), 2.0) + # vector form + assert _equal( + decode_vector_or_string(bytes([0x38, 0x3C]), "DataTypeFP8E4M3Vector"), [1.0, 1.5] + ) + + +def test_float_values_are_full_precision(): + # Go keeps the exact float32 value; we must not round to 6 decimals. + v = decode_scalar_value(bytes.fromhex("d00f4940"), "DataTypeFP32") + assert abs(v - 3.1415901184) < 1e-9, v + + +def test_binary_byte_column_path_no_regression(): + # Feature-store / byte-column vectors are packed binary, NOT JSON. + # FP16 elements are true IEEE-754 half: 1.5 -> 0x3E00, 2.5 -> 0x4100 (LE). + fp16_vec = bytes.fromhex("003e0041") + assert _equal(decode_vector_or_string(fp16_vec, "DataTypeFP16Vector"), [1.5, 2.5]) + + # Raw (non-base64) uint8 vector elements decode as raw bytes. + assert decode_vector_or_string(bytes([1, 2]), "DataTypeUint8Vector") == [1, 2] diff --git a/py-sdk/inference_logging_client/tests/test_go_string_decode.py b/py-sdk/inference_logging_client/tests/test_go_string_decode.py new file mode 100644 index 00000000..d34d3880 --- /dev/null +++ b/py-sdk/inference_logging_client/tests/test_go_string_decode.py @@ -0,0 +1,61 @@ +"""End-to-end proto decode parity: a framed proto row of go-core bytes decoded +with go_string=True must yield exactly the strings go-core BytesToString emits.""" + +import struct + +from inference_logging_client.types import FeatureInfo +from inference_logging_client.formats import decode_proto_features + + +def _sized(payload: bytes) -> bytes: + return struct.pack("