Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from .types import FORMAT_TYPE_MAP, DecodedMPLog, FeatureInfo, Format
from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte

__version__ = "0.3.9"
__version__ = "0.3.10"

# Maximum supported schema version (4 bits = 0-15)
_MAX_SCHEMA_VERSION = 15
Expand Down Expand Up @@ -163,6 +163,7 @@ def decode_mplog(
decompress: bool = True,
schema: Optional[list] = None,
needed_columns: Optional[Collection[str]] = None,
go_string: bool = True,
) -> "SparkDataFrame":
"""
Main function to decode MPLog bytes to a Spark DataFrame.
Expand Down Expand Up @@ -234,10 +235,13 @@ def decode_mplog(
if schema is None:
schema = get_feature_schema(model_proxy_id, version, inference_host)

# go_string=True (default) yields exact go-core BytesToString output for
# every format: proto threads the flag; arrow/parquet go through
# decode_feature_value which defaults to the same go-core port.
# Decode based on format
if detected_format == Format.PROTO:
entity_ids, decoded_rows = decode_proto_format(
working_data, schema, needed_columns=needed_columns
working_data, schema, needed_columns=needed_columns, go_string=go_string
)
elif detected_format == Format.ARROW:
entity_ids, decoded_rows = decode_arrow_format(
Expand Down
157 changes: 148 additions & 9 deletions py-sdk/inference_logging_client/inference_logging_client/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import struct
import json
import base64
from typing import Any

from .utils import (
Expand All @@ -12,6 +13,7 @@
SCALAR_TYPE_SIZES,
SIZED_TYPES,
)
from .go_datatypeconverter import bytes_to_string, go_format_float


class ByteReader:
Expand Down Expand Up @@ -159,6 +161,67 @@ def decode_ieee754_fp16(value_bytes: bytes) -> float:
return format_float(result)


def decode_bfloat16(value_bytes: bytes) -> float:
"""
Decode a Go-encoded FP16 *scalar* to float.

The Go encoder writes FP16 scalars as the top 16 bits of a float32
(bfloat16): uint16(math.Float32bits(v) >> 16), little-endian. This is NOT
IEEE-754 half precision. To invert, place these 2 bytes as the high half of
a little-endian float32 (low half = 0) and read it back as float32.

NOTE: binary FP16 *vector* elements coming from the feature store are true
IEEE-754 half precision and must still be decoded with decode_ieee754_fp16.
"""
if len(value_bytes) != 2:
return 0.0
return format_float(struct.unpack("<f", b"\x00\x00" + value_bytes)[0])


def _u32_to_f32(bits: int) -> float:
"""Reinterpret a 32-bit pattern as float32 (Go's math.Float32frombits)."""
return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def decode_fp8_e4m3(byte_val: int) -> float:
"""
Port of go-core float8.FP8E4M3ToFP32Value (E4M3FN: 1 sign, 4 exp, 3 mantissa).

Mirrors the Go bit-manipulation exactly so values match the feature-store
encoding rather than Python-approximating the format.
"""
if (byte_val & 0x7F) == 0x7F:
return float("nan")
w = (byte_val << 24) & 0xFFFFFFFF
sign = w & 0x80000000
non_sign = w & 0x7FFFFFFF
if non_sign == 0:
return _u32_to_f32(sign) # preserve sign for +/-0
# leading zeros of the 32-bit non_sign
renorm = 32 - non_sign.bit_length()
renorm = renorm - 4 if renorm > 4 else 0
result = sign | ((((non_sign << renorm) & 0xFFFFFFFF) >> 4) + (((0x78 - renorm) << 23) & 0xFFFFFFFF))
return format_float(_u32_to_f32(result))


def decode_fp8_e5m2(byte_val: int) -> float:
"""Port of go-core float8.FP8E5M2ToFP32Value (E5M2: 1 sign, 5 exp, 2 mantissa)."""
sign = (byte_val >> 7) & 0x1
exponent = (byte_val >> 2) & 0x1F
mantissa = byte_val & 0x3
if exponent == 0x1F:
if mantissa == 0:
return float("-inf") if sign else float("inf")
return float("nan")
if exponent == 0:
value = (mantissa / 4.0) * (2.0 ** -14)
else:
value = (1.0 + mantissa / 4.0) * (2.0 ** (exponent - 15))
if sign:
value = -value
return format_float(value)


def decode_scalar_value(value_bytes: bytes, feature_type: str) -> Any:
"""Decode a scalar value from bytes based on feature type."""
normalized = normalize_type(feature_type)
Expand All @@ -183,12 +246,15 @@ def decode_scalar_value(value_bytes: bytes, feature_type: str) -> Any:
return struct.unpack("<I", value_bytes)[0]
elif normalized in {"UINT64", "U64"}:
return struct.unpack("<Q", value_bytes)[0]
elif normalized in {"FP8E5M2", "FP8E4M3"}:
return value_bytes[0] # Return raw byte
elif normalized in {"FP8E5M2"}:
return decode_fp8_e5m2(value_bytes[0])
elif normalized in {"FP8E4M3"}:
return decode_fp8_e4m3(value_bytes[0])
elif normalized in {"FP16", "FLOAT16", "F16"}:
# IEEE 754 half-precision (FP16)
# Go encodes FP16 scalars as bfloat16 (top 16 bits of float32),
# not IEEE-754 half precision. See decode_bfloat16.
if len(value_bytes) == 2:
return decode_ieee754_fp16(value_bytes)
return decode_bfloat16(value_bytes)
return None
elif normalized in {"FP32", "FLOAT32", "F32", "FLOAT"}:
result = struct.unpack("<f", value_bytes)[0]
Expand Down Expand Up @@ -299,9 +365,13 @@ def decode_binary_vector(value_bytes: bytes, feature_type: str) -> list | None:
# Bool vector: 1 byte per element
result = [b != 0 for b in value_bytes]

elif "FP8" in normalized:
# FP8 vector: 1 byte per element (return raw bytes, no standard decoding)
result = list(value_bytes)
elif "FP8E4M3" in normalized:
# FP8 E4M3 vector: 1 byte per element, decoded to float (go-core parity)
result = [decode_fp8_e4m3(b) for b in value_bytes]

elif "FP8E5M2" in normalized or "FP8" in normalized:
# FP8 E5M2 vector: 1 byte per element, decoded to float (go-core parity)
result = [decode_fp8_e5m2(b) for b in value_bytes]

else:
# Unknown vector type, return None to signal unsupported type
Expand Down Expand Up @@ -347,6 +417,19 @@ def decode_vector_or_string(value_bytes: bytes, feature_type: str) -> Any:
# Fallback to hex if not valid UTF-8
return actual_bytes.hex()

# UINT8 vector: Go's json.Marshal([]uint8) emits a base64 JSON *string*
# (e.g. "AQI="), not a JSON array. Byte-sourced uint8 vectors instead arrive
# as raw bytes. Distinguish by the leading quote of the JSON string form.
if normalized == "UINT8VECTOR":
if len(value_bytes) == 0:
return []
if value_bytes[0] == 0x22: # '"' => base64 JSON string form
try:
return list(base64.b64decode(json.loads(value_bytes.decode("utf-8"))))
except Exception:
pass
return list(value_bytes) # raw binary uint8 elements

# Check if this is a vector type
is_vector = "VECTOR" in normalized

Expand Down Expand Up @@ -382,11 +465,20 @@ def decode_vector_or_string(value_bytes: bytes, feature_type: str) -> Any:
return value_bytes.hex()


def decode_feature_value(value_bytes: bytes, feature_type: str) -> Any:
"""Decode a feature value based on its type."""
def decode_feature_value(value_bytes: bytes, feature_type: str, go_string: bool = True) -> Any:
"""Decode a feature value based on its type.

By default (go_string=True) returns the exact string go-core BytesToString
emits -- byte-for-byte Go parity, validated exhaustively. Pass
go_string=False for the legacy typed (int / float / list) output.
"""
if value_bytes is None:
return None

if go_string:
return decode_feature_to_go_string(value_bytes, feature_type)

# ---- legacy typed output ----
# For sized types (VECTOR/STRING), delegate even for empty bytes
# decode_vector_or_string handles empty bytes appropriately per type
if is_sized_type(feature_type):
Expand All @@ -397,3 +489,50 @@ def decode_feature_value(value_bytes: bytes, feature_type: str) -> Any:
return None

return decode_scalar_value(value_bytes, feature_type)


def _go_render(val: Any, feature_type: str) -> str:
"""Render a Python value into go-core's canonical string form."""
norm = normalize_type(feature_type)
bits = 64 if ("FP64" in norm or "FLOAT64" in norm) else 32

def one(x: Any) -> str:
if isinstance(x, bool):
return "true" if x else "false"
if isinstance(x, float):
return go_format_float(x, bits)
if isinstance(x, int):
return str(x)
return str(x)

if val is None:
return ""
if isinstance(val, list):
return ",".join(one(x) for x in val)
return one(val)


def decode_feature_to_go_string(value_bytes: bytes, feature_type: str) -> str:
"""Decode a feature's bytes to the exact string go-core BytesToString emits.

Byte-for-byte Go parity for go-core/binary (byte-column / feature-store)
values. String-path values produced by model-proxy ConvertStringToType
(JSON vectors, base64 uint8 vectors) are detected by their leading byte and
reformatted into the same Go-canonical string.

Note: an FP16 *scalar* on the wire is 2 indistinguishable bytes; this treats
it as canonical IEEE-754 half (go-core). String-path bfloat16 FP16 scalars
require the encoder to emit canonical IEEE-754 to decode exactly.
"""
if value_bytes is None or len(value_bytes) == 0:
return ""
head = value_bytes[:1]
norm = normalize_type(feature_type)
# string-path: JSON array/object, or base64 JSON string for uint8 vectors
if head in (b"[", b"{") or (head == b'"' and norm == "UINT8VECTOR"):
return _go_render(decode_vector_or_string(value_bytes, feature_type), feature_type)
# byte-path: go-core canonical binary
try:
return bytes_to_string(value_bytes, feature_type)
except Exception:
return value_bytes.hex()
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@

from .types import Format, FeatureInfo
from .io import parse_mplog_protobuf
from .decoder import ByteReader, decode_scalar_value, decode_vector_or_string, decode_feature_value
from .decoder import (
ByteReader,
decode_scalar_value,
decode_vector_or_string,
decode_feature_value,
decode_feature_to_go_string,
)
from .utils import is_sized_type, get_scalar_size
from .exceptions import FormatError

Expand Down Expand Up @@ -174,6 +180,7 @@ def decode_proto_features(
encoded_bytes: bytes,
schema: list[FeatureInfo],
needed_columns: Optional[Collection[str]] = None,
go_string: bool = True,
) -> dict[str, Any]:
"""
Decode proto-encoded features for a single entity.
Expand Down Expand Up @@ -245,7 +252,11 @@ def decode_proto_features(
continue

value_bytes = reader.read(size)
result[feature.name] = decode_vector_or_string(value_bytes, feature.feature_type)
result[feature.name] = (
decode_feature_to_go_string(value_bytes, feature.feature_type)
if go_string
else decode_vector_or_string(value_bytes, feature.feature_type)
)
else:
size = get_scalar_size(feature.feature_type)
if size is None:
Expand All @@ -257,7 +268,11 @@ def decode_proto_features(
continue

value_bytes = reader.read(size)
result[feature.name] = decode_scalar_value(value_bytes, feature.feature_type)
result[feature.name] = (
decode_feature_to_go_string(value_bytes, feature.feature_type)
if go_string
else decode_scalar_value(value_bytes, feature.feature_type)
)
except Exception as e:
result[feature.name] = f"<decode_error: {e}>"

Expand All @@ -268,6 +283,7 @@ def decode_proto_format(
mplog_data: bytes,
schema: list[FeatureInfo],
needed_columns: Optional[Collection[str]] = None,
go_string: bool = True,
) -> tuple[list[str], list[dict[str, Any]]]:
"""
Decode proto format MPLog.
Expand All @@ -281,7 +297,9 @@ def decode_proto_format(

decoded_rows = []
for encoded_bytes in encoded_features_list:
decoded = decode_proto_features(encoded_bytes, schema, needed_columns=needed_columns)
decoded = decode_proto_features(
encoded_bytes, schema, needed_columns=needed_columns, go_string=go_string
)
decoded_rows.append(decoded)

# Ensure entity_ids matches decoded_rows count
Expand Down
Loading