Commit 6111b7b

fix: Prevent silent conversion of float array to int (#96)
#### Relevant issue or PR

N/A

#### Description of changes

Tesseract `Arrays` are intended to prevent conversion from float to int due to the following line in `runtime.array_encoding.coerce_shape_dtype`:

```python
if not np.can_cast(arr.dtype, expected_dtype, casting="same_kind"):
```

However, this does not work as intended due to casting happening at an earlier stage in `runtime.array_encoding.python_to_array`:

```python
arr = np.asarray(val, dtype=expected_dtype, order="C")
```

This PR removes the initial casting by changing the `dtype` kwarg to `None` (note that this is its default so we could remove entirely if preferred). Unfortunately, `asarray` does not offer a `casting` kwarg so we can't pass `"same_kind"` any earlier.

#### Testing done

- [x] Manually tested validation errors raised when trying to read json with float arrays into an int array:

```
╭─ Error ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Invalid value for ApplyInputSchema: 4 validation errors for ApplyInputSchema                                                          │
│ inputs.a.`chain[EncodedArrayModel__any__int32__DIFFERENTIABLE,function-plain[functools.partial(<function decode_array at              │
│ 0x10542a980>, expected_shape=(None,), expected_dtype='int32')()]]`                                                                    │
│ Input should be a valid dictionary or instance of EncodedArrayModel__any__int32__DIFFERENTIABLE [type=model_type, input_value=[1.1,   │
│ 2.5, 3.9], input_type=list]                                                                                                           │
│ For further information visit https://errors.pydantic.dev/2.10/v/model_type                                                           │
│ inputs.a.`function-plain[functools.partial(<function python_to_array at 0x10542aca0>, expected_shape=(None,),                         │
│ expected_dtype='int32')()]`                                                                                                           │
│ Value error, Dtype mismatch: float64 cannot be cast to int32 [type=value_error, input_value=[1.1, 2.5, 3.9], input_type=list]         │
│ For further information visit https://errors.pydantic.dev/2.10/v/value_error                                                          │
│ inputs.b.`chain[EncodedArrayModel__any__int32__DIFFERENTIABLE,function-plain[functools.partial(<function decode_array at              │
│ 0x10542a980>, expected_shape=(None,), expected_dtype='int32')()]]`                                                                    │
│ Input should be a valid dictionary or instance of EncodedArrayModel__any__int32__DIFFERENTIABLE [type=model_type, input_value=[4.1,   │
│ 5.5, 6.9], input_type=list]                                                                                                           │
│ For further information visit https://errors.pydantic.dev/2.10/v/model_type                                                           │
│ inputs.b.`function-plain[functools.partial(<function python_to_array at 0x10542aca0>, expected_shape=(None,),                         │
│ expected_dtype='int32')()]`                                                                                                           │
│ Value error, Dtype mismatch: float64 cannot be cast to int32 [type=value_error, input_value=[4.1, 5.5, 6.9], input_type=list]         │
│ For further information visit https://errors.pydantic.dev/2.10/v/value_error                                                          │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
```

- [x] Manually tested that validation errors are now raised when converting from float to int

```python
>>> import numpy as np
>>> from tesseract_core.runtime import Array, Int32
>>> from pydantic import TypeAdapter
>>> TypeAdapter(Array[..., 'int32']).validate_python(2.5*np.ones((3,3)))
Traceback (most recent call last):
  File "<python-input-3>", line 1, in <module>
    TypeAdapter(Array[..., 'int32']).validate_python(2.5*np.ones((3,3)))
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/pydantic/type_adapter.py", line 412, in validate_python
    return self.validator.validate_python(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        object,
        ^^^^^^^
    ...<3 lines>...
        allow_partial=experimental_allow_partial,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
pydantic_core._pydantic_core.ValidationError: 2 validation errors for json-or-python[json=chain[EncodedArrayModel__anyrank__int32__noflags,function-plain[functools.partial(<function decode_array at 0x105d6a340>, expected_shape=Ellipsis, expected_dtype='int32')()]],python=union[chain[EncodedArrayModel__anyrank__int32__noflags,function-plain[functools.partial(<function decode_array at 0x105d6a340>, expected_shape=Ellipsis, expected_dtype='int32')()]],function-plain[functools.partial(<function python_to_array at 0x105d6a660>, expected_shape=Ellipsis, expected_dtype='int32')()]]]
`chain[EncodedArrayModel__anyrank__int32__noflags,function-plain[functools.partial(<function decode_array at 0x105d6a340>, expected_shape=Ellipsis, expected_dtype='int32')()]]`
  Input should be a valid dictionary or instance of EncodedArrayModel__anyrank__int32__noflags [type=model_type, input_value=array([[2.5, 2.5, 2.5], ... [2.5, 2.5, 2.5]]), input_type=ndarray]
    For further information visit https://errors.pydantic.dev/2.10/v/model_type
`function-plain[functools.partial(<function python_to_array at 0x105d6a660>, expected_shape=Ellipsis, expected_dtype='int32')()]`
  Value error, Dtype mismatch: float64 cannot be cast to int32 [type=value_error, input_value=array([[2.5, 2.5, 2.5], ... [2.5, 2.5, 2.5]]), input_type=ndarray]
    For further information visit https://errors.pydantic.dev/2.10/v/value_error
```

- [x] CI passes

#### License

- [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract/LICENSE).
- [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line.

<details>
<summary><b>Developer Certificate of Origin</b></summary>

```text
Developer Certificate of Origin
Version 1.1

Copyright (C) 2004, 2006 The Linux Foundation and its contributors.

Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

Developer's Certificate of Origin 1.1

By making a contribution to this project, I certify that:

(a) The contribution was created in whole or in part by me and I
    have the right to submit it under the open source license
    indicated in the file; or

(b) The contribution is based upon previous work that, to the best
    of my knowledge, is covered under an appropriate open source
    license and I have the right under that license to submit that
    work with modifications, whether created in whole or in part
    by me, under the same open source license (unless I am
    permitted to submit under a different license), as indicated
    in the file; or

(c) The contribution was provided directly to me by some other
    person who certified (a), (b) or (c) and I have not modified
    it.

(d) I understand and agree that this project and the contribution
    are public and that a record of the contribution (including all
    personal information I submit with it, including my sign-off) is
    maintained indefinitely and may be redistributed consistent with
    this project or the open source license(s) involved.
```

</details>

Signed-off-by: Jonathan Brodrick <jonathan.brodrick@simulation.science>

---------

Co-authored-by: Dion Häfner <dion.haefner@simulation.science>
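
The interaction described in the commit message can be reproduced with plain NumPy. The snippet below is only an illustrative sketch: it does not call any Tesseract function, and the variable names are made up for the example. It shows that passing `dtype` to `np.asarray` drops fractional parts before the `same_kind` check ever sees a float dtype.

```python
import numpy as np

values = [1.1, 2.5, 3.9]

# Passing dtype up front converts immediately: fractional parts are dropped silently.
eager = np.asarray(values, dtype="int32", order="C")

# The array is already int32 at this point, so the later check passes trivially.
check_after_eager_cast = np.can_cast(eager.dtype, "int32", casting="same_kind")

# Letting NumPy infer the dtype keeps float64, and the same check now rejects it.
inferred = np.asarray(values, order="C")
check_after_inference = np.can_cast(inferred.dtype, "int32", casting="same_kind")

print(eager, check_after_eager_cast)
print(inferred.dtype, check_after_inference)
```

Deferring the conversion is what lets the existing `can_cast` guard do its job, since it compares the dtype NumPy inferred from the data rather than the dtype that was requested.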
1 parent 1578ddb commit 6111b7b

3 files changed (+92, -11 lines)

tesseract_core/runtime/array_encoding.py

Lines changed: 17 additions & 6 deletions
@@ -356,11 +356,14 @@ def python_to_array(
     val: Any, expected_shape: ShapeType, expected_dtype: Optional[str]
 ) -> ArrayLike:
     """Convert a Python object to a NumPy array."""
-    try:
-        arr = np.asarray(val, dtype=expected_dtype, order="C")
-    except TypeError as exc:
-        raise ValueError(f"Could not convert {val} to NumPy array") from exc
-    return _coerce_shape_dtype(arr, expected_shape, expected_dtype)
+    val = np.asarray(val, order="C")
+    if not np.issubdtype(val.dtype, np.number) and not np.issubdtype(
+        val.dtype, np.bool_
+    ):
+        raise ValueError(
+            f"Could not convert object to numeric NumPy array (got dtype: {val.dtype})"
+        )
+    return _coerce_shape_dtype(val, expected_shape, expected_dtype)


 def decode_array(
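
The new guard in `python_to_array` accepts any numeric or boolean dtype and rejects everything else. The following is a minimal sketch of that test in isolation, assuming only NumPy; `is_numeric_or_bool` is a hypothetical helper written for this illustration and is not part of Tesseract.

```python
import numpy as np


def is_numeric_or_bool(obj) -> bool:
    """Hypothetical helper mirroring the dtype test added to python_to_array."""
    dtype = np.asarray(obj).dtype
    # np.bool_ is not a subtype of np.number, so it needs its own check.
    return bool(np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_))


accepted = [[1, 2, 3], [1.5, 2.5], [True, False], 1 + 2j]
rejected = [["a", "b", "c"], object(), None]

for candidate in accepted + rejected:
    print(repr(candidate), is_numeric_or_bool(candidate))
```

Strings become a Unicode dtype and arbitrary Python objects become `object` dtype, so both fall outside `np.number` and would reach the "Could not convert object" error path.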
@@ -381,7 +384,15 @@ def decode_array(

     # keep checking for "raw" for backwards compat
     elif val.data.encoding in {"json", "raw"}:
-        data = np.asarray(val.data.buffer, dtype=val.dtype).reshape(val.shape)
+        data = np.asarray(val.data.buffer).reshape(val.shape)
+        if np.issubdtype(data.dtype, np.floating) and np.issubdtype(
+            val.dtype, np.integer
+        ):
+            if np.any(data % 1):
+                raise ValueError(
+                    f"Expected integer data, but got floating point data: {data}"
+                )
+        data = data.astype(val.dtype, casting="unsafe")

     else:
         # Unreachable
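
The `decode_array` change allows float buffers for integer targets only when every element is integral. A rough standalone sketch of that logic follows. `cast_json_buffer` is a hypothetical function written for this example, and it assumes the final `astype` call applies to every buffer rather than only to float input.

```python
import numpy as np


def cast_json_buffer(buffer, shape, dtype):
    """Hypothetical stand-in for the "json" branch of decode_array."""
    data = np.asarray(buffer).reshape(shape)
    target = np.dtype(dtype)
    if np.issubdtype(data.dtype, np.floating) and np.issubdtype(target, np.integer):
        # A nonzero remainder means at least one element has a fractional part.
        if np.any(data % 1):
            raise ValueError(
                f"Expected integer data, but got floating point data: {data}"
            )
    # Assumption: the cast is unconditional, so int and float buffers both end up as `target`.
    return data.astype(target, casting="unsafe")


# Integral floats are converted without loss.
ok = cast_json_buffer([1.0, 2.0, 3.0, 4.0], (2, 2), "int32")
print(ok, ok.dtype)

# A fractional value is rejected instead of being truncated.
try:
    cast_json_buffer([1.0, 2.5], (2,), "int32")
except ValueError as exc:
    error_message = str(exc)
    print(error_message)
```

The `"unsafe"` cast is harmless here because the remainder check has already ruled out fractional values for the float-to-integer case.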

tests/runtime_tests/test_core.py

Lines changed: 3 additions & 5 deletions
@@ -408,7 +408,7 @@ def test_ad_endpoint_bad_tangent(testmodule, endpoint_name, failure_mode):
             msg = "String should match pattern"
         elif failure_mode == "invalid":
             tangent_vector = {k: "ahoy" for k in ad_inp}
-            msg = "Got invalid dtype"
+            msg = "Could not convert object"

         inputs = {
             "inputs": test_input,
@@ -427,7 +427,7 @@ def test_ad_endpoint_bad_tangent(testmodule, endpoint_name, failure_mode):
             msg = "String should match pattern"
         elif failure_mode == "invalid":
             cotangent_vector = {k: "ahoy" for k in ad_out}
-            msg = "Got invalid dtype"
+            msg = "Could not convert object"

         inputs = {
             "inputs": test_input,
@@ -436,8 +436,6 @@ def test_ad_endpoint_bad_tangent(testmodule, endpoint_name, failure_mode):
             "cotangent_vector": cotangent_vector,
         }

-    with pytest.raises(ValidationError) as excinfo:
+    with pytest.raises(ValidationError, match=msg):
         inputs = EndpointSchema.model_validate(inputs)
         endpoint_func(inputs)
-
-    assert msg in str(excinfo.value)

tests/runtime_tests/test_schema_types.py

Lines changed: 72 additions & 0 deletions
@@ -447,3 +447,75 @@ class MyLazyModel(BaseModel):
     assert len(allowed_types) == 2
     assert allowed_types[0]["type"] == "array"
     assert allowed_types[1]["type"] == "string"
+
+
+def test_dtype_casting():
+    json_payload_str = MyModel(
+        array_int=arr_int,
+        array_float=arr_float,
+        array_bool=arr_bool,
+        scalar_int=scalar_int,
+    ).model_dump_json()
+
+    # Base case: proper int data (should work fine)
+    json_payload = json.loads(json_payload_str)
+    json_payload["array_int"]["data"] = {
+        "buffer": arr_int.flatten().tolist(),
+        "encoding": "json",
+    }
+    res = MyModel.model_validate(json_payload)
+    assert np.array_equal(res.array_int, arr_int)
+
+    # Case 1: floats in JSON array w/o fractional parts (should work fine)
+    json_payload = json.loads(json_payload_str)
+    json_payload["array_int"]["data"] = {
+        "buffer": arr_int.astype(float).flatten().tolist(),
+        "encoding": "json",
+    }
+    res = MyModel.model_validate(json_payload)
+    assert np.array_equal(res.array_int, arr_int)
+
+    # Case 2: floats in JSON array w/ fractional parts (should raise)
+    json_payload = json.loads(json_payload_str)
+    json_payload["array_int"]["data"] = {
+        "buffer": (arr_int.astype(float) + 1e-6).flatten().tolist(),
+        "encoding": "json",
+    }
+    with pytest.raises(ValidationError, match="Expected integer data"):
+        MyModel.model_validate(json_payload)
+
+    # Case 3: pass NumPy array directly (should work fine)
+    json_payload = json.loads(json_payload_str)
+    json_payload["array_int"] = arr_int
+    res = MyModel.model_validate(json_payload)
+    assert np.array_equal(res.array_int, arr_int)
+
+    # Case 4: pass NumPy array with incompatible dtype (should raise)
+    json_payload = json.loads(json_payload_str)
+    json_payload["array_int"] = arr_int.astype(np.float32)
+    with pytest.raises(ValidationError, match="cannot be cast"):
+        MyModel.model_validate(json_payload)
+
+    # Case 5: pass JSON data directly (should work fine)
+    json_payload = json.loads(json_payload_str)
+    json_payload["array_int"] = arr_int.tolist()
+    res = MyModel.model_validate(json_payload)
+    assert np.array_equal(res.array_int, arr_int)
+
+    # Case 6: pass JSON data with incompatible dtype (should raise)
+    json_payload = json.loads(json_payload_str)
+    json_payload["array_int"] = arr_int.astype(np.float32).tolist()
+    with pytest.raises(ValidationError, match="cannot be cast"):
+        MyModel.model_validate(json_payload)
+
+    # Case 7: Pass non-numeric data (should raise)
+    json_payload = json.loads(json_payload_str)
+    json_payload["array_int"] = ["a", "b", "c"]
+    with pytest.raises(ValidationError, match="Could not convert object"):
+        MyModel.model_validate(json_payload)
+
+    # Case 8: Pass non-numeric Python object (should raise)
+    json_payload = json.loads(json_payload_str)
+    json_payload["array_int"] = object()
+    with pytest.raises(ValidationError, match="Could not convert object"):
+        MyModel.model_validate(json_payload)
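
One detail the test cases bring out is that the two input paths treat integral floats differently: an encoded `"json"` buffer of whole-number floats is accepted (Case 1), while a plain list of the same values is rejected on dtype alone (Case 6). The sketch below reproduces the underlying NumPy checks without Tesseract. `arr_int` here is a made-up stand-in for the test fixture, and the mapping of checks to code paths is inferred from the diff rather than confirmed against the full source.

```python
import numpy as np

arr_int = np.arange(6, dtype="int32")  # hypothetical stand-in for the test fixture

# Encoded "json" buffer path (Cases 1 and 2): floats pass only if they are integral.
exact = np.asarray(arr_int.astype(float).tolist())
shifted = np.asarray((arr_int.astype(float) + 1e-6).tolist())
exact_has_fraction = bool(np.any(exact % 1))
shifted_has_fraction = bool(np.any(shifted % 1))

# Plain list path (Cases 5 and 6): only the inferred dtype is compared.
from_int_list = np.asarray(arr_int.tolist())
from_float_list = np.asarray(arr_int.astype(np.float32).tolist())
int_list_ok = np.can_cast(from_int_list.dtype, "int32", casting="same_kind")
float_list_ok = np.can_cast(from_float_list.dtype, "int32", casting="same_kind")

print(exact_has_fraction, shifted_has_fraction)
print(from_int_list.dtype, int_list_ok)
print(from_float_list.dtype, float_list_ok)
```

Under this reading, the buffer path inspects values while the plain path inspects only dtypes, which is why a list such as `[0.0, 1.0, 2.0]` would fail with "cannot be cast" even though no information would be lost.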
