def _numpy_to_native(obj, default=None):
    """Convert a numpy scalar or array to its Python-native equivalent.

    numpy is an optional dependency, so it is imported lazily inside the
    function instead of at module import time.

    Args:
        obj: Any object. Only ``np.ndarray`` and ``np.generic`` (numpy scalar)
            instances are converted.
        default: Value returned when ``obj`` is not a numpy type or when numpy
            is not installed. It defaults to ``None``, which callers use to
            decide whether to fall back to the default serializer behavior.
            Pass a unique sentinel object when a genuine conversion result of
            ``None`` must be told apart from "not a numpy object"
            (for example ``np.datetime64("NaT").item()`` is ``None``).

    Returns:
        ``obj.tolist()`` for arrays, ``obj.item()`` for numpy scalars, and
        ``default`` for everything else.
    """
    try:
        import numpy as np
    except ImportError:
        # Without numpy installed there is nothing to convert.
        return default

    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return default
def _build_numpy_lattice(np):
    """Build a small lattice using numpy-typed scalar values throughout.

    Args:
        np: The imported ``numpy`` module, passed in so that callers can
            obtain it through ``pytest.importorskip``.

    Returns:
        A ``pals.BeamLine`` named ``"line_np"`` holding one quadrupole and one
        octupole whose parameters are ``np.float64``, ``np.float32`` and
        ``np.int64`` scalars.
    """
    quad = pals.Quadrupole(
        name="q_np",
        length=np.float64(0.061),
        MagneticMultipoleP=pals.MagneticMultipoleParameters(
            Bn1=np.float64(-26.0), Bs1=np.float32(0.5), Kn0=np.int64(-1)
        ),
    )
    oct_ = pals.Octupole(
        name="o_np",
        length=np.float64(0.25),
        ElectricMultipoleP=pals.ElectricMultipoleParameters(
            En3=np.float64(0.75), Es3=np.float32(0.125)
        ),
    )
    return pals.BeamLine(name="line_np", line=[quad, oct_])


def test_yaml_roundtrip_with_numpy():
    """Regression test for issue #67: writing YAML with numpy-typed values
    must not produce !!python/object tags, and round-tripping must yield
    Python-native floats with the correct numeric values."""
    np = pytest.importorskip("numpy")

    line = _build_numpy_lattice(np)
    yaml_file = "numpy_roundtrip.pals.yaml"
    try:
        # Write inside the try block so the cleanup below also runs when
        # serialization fails after the file has been created.
        line.to_file(yaml_file)
        with open(yaml_file, "r") as f:
            text = f.read()

        # The bug symptom: YAML contains opaque numpy object tags.
        assert "!!python/object" not in text, (
            f"YAML output still contains unsafe numpy object tags:\n{text}"
        )
        assert "numpy" not in text, f"YAML output still references numpy:\n{text}"

        loaded = pals.BeamLine.from_file(yaml_file)
        loaded_quad = loaded.line[0]
        assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
        assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
        assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5
        assert loaded_quad.MagneticMultipoleP.Kn0 == -1

        loaded_oct = loaded.line[1]
        assert loaded_oct.ElectricMultipoleP.En3 == 0.75
        assert type(loaded_oct.ElectricMultipoleP.En3) is float
        # float32 input: 0.125 is exactly representable, so equality is safe.
        assert loaded_oct.ElectricMultipoleP.Es3 == 0.125
    finally:
        if os.path.exists(yaml_file):
            os.remove(yaml_file)


def test_json_roundtrip_with_numpy():
    """JSON path also needs to handle numpy values cleanly (defense-in-depth)."""
    np = pytest.importorskip("numpy")

    line = _build_numpy_lattice(np)
    json_file = "numpy_roundtrip.pals.json"
    try:
        # Write inside the try block so the cleanup below also runs when
        # serialization fails after the file has been created.
        line.to_file(json_file)

        loaded = pals.BeamLine.from_file(json_file)
        loaded_quad = loaded.line[0]
        assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
        assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
        # np.float64 subclasses Python float, so the json module can encode
        # it without help. The float32 and int64 values below are the ones
        # that need the numpy-aware conversion, so check them explicitly.
        assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5
        assert loaded_quad.MagneticMultipoleP.Kn0 == -1

        loaded_oct = loaded.line[1]
        assert loaded_oct.ElectricMultipoleP.En3 == 0.75
        assert loaded_oct.ElectricMultipoleP.Es3 == 0.125
    finally:
        if os.path.exists(json_file):
            os.remove(json_file)