Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ classifiers = [
]

[project.optional-dependencies]
test = ["pytest"]
test = ["pytest", "numpy"]

[project.urls]
Documentation = "https://pals-project.readthedocs.io"
Expand Down
67 changes: 65 additions & 2 deletions src/pals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ def load_file_to_dict(filename: str) -> dict:
return pals_data


def _numpy_to_native(obj):
"""Convert a numpy scalar/array to its Python-native equivalent.

Returns ``None`` when the object is not a numpy type or when numpy is not
installed; callers use that to decide whether to fall back to the default
serializer behavior. numpy is an optional dependency.
"""
try:
import numpy as np
except ImportError:
return None

if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.generic):
return obj.item()
return None


def store_dict_to_file(filename: str, pals_dict: dict):
file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions(
filename
Expand All @@ -63,14 +82,58 @@ def store_dict_to_file(filename: str, pals_dict: dict):
if extension == ".json":
import json

json_data = json.dumps(pals_dict, sort_keys=False, indent=2)
def _json_default(obj):
native = _numpy_to_native(obj)
if native is not None:
return native
raise TypeError(
f"Object of type {type(obj).__name__} is not JSON serializable"
)

json_data = json.dumps(
pals_dict, sort_keys=False, indent=2, default=_json_default
)
with open(filename, "w") as file:
file.write(json_data)

elif extension == ".yaml":
import yaml

yaml_data = yaml.dump(pals_dict, default_flow_style=False, sort_keys=False)
# Subclass the safe dumper so numpy representers are scoped to PALS
# serialization and do not leak into the global pyyaml state used by
# other code in the same process.
class _PALSDumper(yaml.SafeDumper):
pass

try:
import numpy as np
except ImportError:
np = None

if np is not None:

def _represent_numpy_scalar(dumper, value):
native = value.item()
if isinstance(native, bool):
return dumper.represent_bool(native)
if isinstance(native, int):
return dumper.represent_int(native)
if isinstance(native, float):
return dumper.represent_float(native)
return dumper.represent_data(native)

def _represent_numpy_array(dumper, value):
return dumper.represent_list(value.tolist())

_PALSDumper.add_multi_representer(np.generic, _represent_numpy_scalar)
_PALSDumper.add_representer(np.ndarray, _represent_numpy_array)

yaml_data = yaml.dump(
pals_dict,
Dumper=_PALSDumper,
default_flow_style=False,
sort_keys=False,
)
with open(filename, "w") as file:
file.write(yaml_data)

Expand Down
74 changes: 74 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

import pytest

import pals


Expand Down Expand Up @@ -332,3 +334,75 @@ def test_comprehensive_lattice():
# Clean up temporary files
os.remove(yaml_file)
os.remove(json_file)


def _build_numpy_lattice(np):
"""Build a small lattice using numpy-typed scalar values throughout."""
quad = pals.Quadrupole(
name="q_np",
length=np.float64(0.061),
MagneticMultipoleP=pals.MagneticMultipoleParameters(
Bn1=np.float64(-26.0), Bs1=np.float32(0.5), Kn0=np.int64(-1)
),
)
oct_ = pals.Octupole(
name="o_np",
length=np.float64(0.25),
ElectricMultipoleP=pals.ElectricMultipoleParameters(
En3=np.float64(0.75), Es3=np.float32(0.125)
),
)
return pals.BeamLine(name="line_np", line=[quad, oct_])


def test_yaml_roundtrip_with_numpy():
"""Regression test for issue #67: writing YAML with numpy-typed values
must not produce !!python/object tags, and round-tripping must yield
Python-native floats with the correct numeric values."""
np = pytest.importorskip("numpy")

line = _build_numpy_lattice(np)
yaml_file = "numpy_roundtrip.pals.yaml"
line.to_file(yaml_file)
try:
with open(yaml_file, "r") as f:
text = f.read()

# The bug symptom: YAML contains opaque numpy object tags.
assert "!!python/object" not in text, (
f"YAML output still contains unsafe numpy object tags:\n{text}"
)
assert "numpy" not in text, f"YAML output still references numpy:\n{text}"

loaded = pals.BeamLine.from_file(yaml_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5
assert loaded_quad.MagneticMultipoleP.Kn0 == -1

loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
assert type(loaded_oct.ElectricMultipoleP.En3) is float
finally:
if os.path.exists(yaml_file):
os.remove(yaml_file)
Comment on lines +358 to +389
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could use pytest's tmp_path fixture to avoid polluting and cleaning the root directory (no try/finally, no os.remove, etc.):

Suggested change
def test_yaml_roundtrip_with_numpy():
"""Regression test for issue #67: writing YAML with numpy-typed values
must not produce !!python/object tags, and round-tripping must yield
Python-native floats with the correct numeric values."""
np = pytest.importorskip("numpy")
line = _build_numpy_lattice(np)
yaml_file = "numpy_roundtrip.pals.yaml"
line.to_file(yaml_file)
try:
with open(yaml_file, "r") as f:
text = f.read()
# The bug symptom: YAML contains opaque numpy object tags.
assert "!!python/object" not in text, (
f"YAML output still contains unsafe numpy object tags:\n{text}"
)
assert "numpy" not in text, f"YAML output still references numpy:\n{text}"
loaded = pals.BeamLine.from_file(yaml_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5
assert loaded_quad.MagneticMultipoleP.Kn0 == -1
loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
assert type(loaded_oct.ElectricMultipoleP.En3) is float
finally:
if os.path.exists(yaml_file):
os.remove(yaml_file)
def test_yaml_roundtrip_with_numpy(tmp_path):
"""Regression test for issue #67: writing YAML with numpy-typed values
must not produce !!python/object tags, and round-tripping must yield
Python-native floats with the correct numeric values."""
np = pytest.importorskip("numpy")
line = _build_numpy_lattice(np)
yaml_file = str(tmp_path / "numpy_roundtrip.pals.yaml")
line.to_file(yaml_file)
with open(yaml_file, "r") as f:
text = f.read()
# The bug symptom: YAML contains opaque numpy object tags.
assert "!!python/object" not in text, (
f"YAML output still contains unsafe numpy object tags:\n{text}"
)
assert "numpy" not in text, f"YAML output still references numpy:\n{text}"
loaded = pals.BeamLine.from_file(yaml_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5
assert loaded_quad.MagneticMultipoleP.Kn0 == -1
loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
assert type(loaded_oct.ElectricMultipoleP.En3) is float

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For convenience, the diff without whitespace would be:

-def test_yaml_roundtrip_with_numpy():
+def test_yaml_roundtrip_with_numpy(tmp_path):
     """Regression test for issue #67: writing YAML with numpy-typed values
     must not produce !!python/object tags, and round-tripping must yield
     Python-native floats with the correct numeric values."""
     np = pytest.importorskip("numpy")
 
     line = _build_numpy_lattice(np)
-    yaml_file = "numpy_roundtrip.pals.yaml"
+    yaml_file = str(tmp_path / "numpy_roundtrip.pals.yaml")
     line.to_file(yaml_file)
-    try:
+
     with open(yaml_file, "r") as f:
         text = f.read()
 
@@ -384,25 +384,19 @@ def test_yaml_roundtrip_with_numpy():
     loaded_oct = loaded.line[1]
     assert loaded_oct.ElectricMultipoleP.En3 == 0.75
     assert type(loaded_oct.ElectricMultipoleP.En3) is float
-    finally:
-        if os.path.exists(yaml_file):
-            os.remove(yaml_file)



def test_json_roundtrip_with_numpy():
"""JSON path also needs to handle numpy values cleanly (defense-in-depth)."""
np = pytest.importorskip("numpy")

line = _build_numpy_lattice(np)
json_file = "numpy_roundtrip.pals.json"
line.to_file(json_file)
try:
loaded = pals.BeamLine.from_file(json_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
finally:
if os.path.exists(json_file):
os.remove(json_file)
Comment on lines +392 to +408
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could use pytest's tmp_path fixture to avoid polluting and cleaning the root directory (no try/finally, no os.remove, etc.):

Suggested change
def test_json_roundtrip_with_numpy():
"""JSON path also needs to handle numpy values cleanly (defense-in-depth)."""
np = pytest.importorskip("numpy")
line = _build_numpy_lattice(np)
json_file = "numpy_roundtrip.pals.json"
line.to_file(json_file)
try:
loaded = pals.BeamLine.from_file(json_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
finally:
if os.path.exists(json_file):
os.remove(json_file)
def test_json_roundtrip_with_numpy(tmp_path):
"""JSON path also needs to handle numpy values cleanly (defense-in-depth)."""
np = pytest.importorskip("numpy")
line = _build_numpy_lattice(np)
json_file = str(tmp_path / "numpy_roundtrip.pals.json")
line.to_file(json_file)
loaded = pals.BeamLine.from_file(json_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For convenience, the diff without whitespace would be:

-def test_json_roundtrip_with_numpy():
+def test_json_roundtrip_with_numpy(tmp_path):
     """JSON path also needs to handle numpy values cleanly (defense-in-depth)."""
     np = pytest.importorskip("numpy")
 
     line = _build_numpy_lattice(np)
-    json_file = "numpy_roundtrip.pals.json"
+    json_file = str(tmp_path / "numpy_roundtrip.pals.json")
     line.to_file(json_file)
-    try:
+
     loaded = pals.BeamLine.from_file(json_file)
     loaded_quad = loaded.line[0]
     assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
     assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
     loaded_oct = loaded.line[1]
     assert loaded_oct.ElectricMultipoleP.En3 == 0.75
-    finally:
-        if os.path.exists(json_file):
-            os.remove(json_file)

Loading