From f514cc29a460e31c02b439860bba5dadef19def0 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Tue, 12 May 2026 10:59:47 -0700 Subject: [PATCH 1/5] Test: NumPy Multipole Serialization --- pyproject.toml | 2 +- tests/test_parameters.py | 50 +++++++++++++++++++++++++ tests/test_serialization.py | 74 +++++++++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 269a045..79bcd5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ classifiers = [ ] [project.optional-dependencies] -test = ["pytest"] +test = ["pytest", "numpy"] [project.urls] Documentation = "https://pals-project.readthedocs.io" diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 09416bc..a78eacf 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -134,3 +134,53 @@ def test_ParameterClasses(): # Test BeamBeamParameters beambeam = BeamBeamParameters() assert beambeam is not None + + +def test_multipole_numpy_coercion(): + """Regression test for issue #67: numpy scalars passed to multipole parameter + classes must be coerced to Python-native numeric types at construction time, + so YAML/JSON serialization produces clean output regardless of input type.""" + np = pytest.importorskip("numpy") + + # MagneticMultipoleParameters: cover all prefixes and several numpy dtypes + mmp = MagneticMultipoleParameters( + tilt1=np.float64(0.1), + Bn1=np.float64(1.5), + Bn2=np.float32(2.5), + Bs1=np.int64(3), + Kn0=np.int32(-1), + Ks1=np.float64(0.25), + ) + assert type(mmp.tilt1) is float and mmp.tilt1 == 0.1 + assert type(mmp.Bn1) is float and mmp.Bn1 == 1.5 + assert type(mmp.Bn2) is float and mmp.Bn2 == 2.5 + assert type(mmp.Bs1) is int and mmp.Bs1 == 3 + assert type(mmp.Kn0) is int and mmp.Kn0 == -1 + assert type(mmp.Ks1) is float and mmp.Ks1 == 0.25 + + # 0-d numpy array also works + mmp_arr = MagneticMultipoleParameters(Bn1=np.array(4.2)) + assert type(mmp_arr.Bn1) is float and mmp_arr.Bn1 == 4.2 + + # Length-integrated variants + mmp_L = MagneticMultipoleParameters(Bn1L=np.float64(7.0), Ks1L=np.float64(8.0)) + assert type(mmp_L.Bn1L) is float and mmp_L.Bn1L == 7.0 + assert type(mmp_L.Ks1L) is float and mmp_L.Ks1L == 8.0 + + # ElectricMultipoleParameters: cover all prefixes + emp = ElectricMultipoleParameters( + tilt1=np.float64(0.2), + En1=np.float64(0.5), + Es1=np.int64(2), + ) + assert type(emp.tilt1) is float and emp.tilt1 == 0.2 + assert type(emp.En1) is float and emp.En1 == 0.5 + assert type(emp.Es1) is int and emp.Es1 == 2 + + emp_L = ElectricMultipoleParameters(En1L=np.float64(1.0), Es1L=np.float64(0.5)) + assert type(emp_L.En1L) is float and emp_L.En1L == 1.0 + assert type(emp_L.Es1L) is float and emp_L.Es1L == 0.5 + + # Plain Python values must still pass through unchanged + mmp_plain = MagneticMultipoleParameters(Bn1=1.5) + assert type(mmp_plain.Bn1) is float and mmp_plain.Bn1 == 1.5 diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 7e05b67..75627ac 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,5 +1,7 @@ import os +import pytest + import pals @@ -332,3 +334,75 @@ def test_comprehensive_lattice(): # Clean up temporary files os.remove(yaml_file) os.remove(json_file) + + +def _build_numpy_lattice(np): + """Build a small lattice using numpy-typed scalar values throughout.""" + quad = pals.Quadrupole( + name="q_np", + length=np.float64(0.061), + MagneticMultipoleP=pals.MagneticMultipoleParameters( + Bn1=np.float64(-26.0), Bs1=np.float32(0.5), Kn0=np.int64(-1) + ), + ) + oct_ = pals.Octupole( + name="o_np", + length=np.float64(0.25), + ElectricMultipoleP=pals.ElectricMultipoleParameters( + En3=np.float64(0.75), Es3=np.float32(0.125) + ), + ) + return pals.BeamLine(name="line_np", line=[quad, oct_]) + + +def test_yaml_roundtrip_with_numpy(): + """Regression test for issue #67: writing YAML with numpy-typed values + must not produce !!python/object tags, and round-tripping must yield + Python-native floats with the correct numeric values.""" + np = pytest.importorskip("numpy") + + line = _build_numpy_lattice(np) + yaml_file = "numpy_roundtrip.pals.yaml" + line.to_file(yaml_file) + try: + with open(yaml_file, "r") as f: + text = f.read() + + # The bug symptom: YAML contains opaque numpy object tags. + assert "!!python/object" not in text, ( + f"YAML output still contains unsafe numpy object tags:\n{text}" + ) + assert "numpy" not in text, f"YAML output still references numpy:\n{text}" + + loaded = pals.BeamLine.from_file(yaml_file) + loaded_quad = loaded.line[0] + assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0 + assert type(loaded_quad.MagneticMultipoleP.Bn1) is float + assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5 + assert loaded_quad.MagneticMultipoleP.Kn0 == -1 + + loaded_oct = loaded.line[1] + assert loaded_oct.ElectricMultipoleP.En3 == 0.75 + assert type(loaded_oct.ElectricMultipoleP.En3) is float + finally: + if os.path.exists(yaml_file): + os.remove(yaml_file) + + +def test_json_roundtrip_with_numpy(): + """JSON path also needs to handle numpy values cleanly (defense-in-depth).""" + np = pytest.importorskip("numpy") + + line = _build_numpy_lattice(np) + json_file = "numpy_roundtrip.pals.json" + line.to_file(json_file) + try: + loaded = pals.BeamLine.from_file(json_file) + loaded_quad = loaded.line[0] + assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0 + assert type(loaded_quad.MagneticMultipoleP.Bn1) is float + loaded_oct = loaded.line[1] + assert loaded_oct.ElectricMultipoleP.En3 == 0.75 + finally: + if os.path.exists(json_file): + os.remove(json_file) From 970614486d4b85262b86109338708a317b12c836 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Tue, 12 May 2026 12:41:34 -0700 Subject: [PATCH 2/5] Fix: NumPy Serialization --- src/pals/functions.py | 67 ++++++++++- .../parameters/ElectricMultipoleParameters.py | 58 ++-------- .../parameters/MagneticMultipoleParameters.py | 62 ++-------- src/pals/parameters/_multipole_base.py | 108 ++++++++++++++++++ 4 files changed, 193 insertions(+), 102 deletions(-) create mode 100644 src/pals/parameters/_multipole_base.py diff --git a/src/pals/functions.py b/src/pals/functions.py index cc6ad9a..531e9e1 100644 --- a/src/pals/functions.py +++ b/src/pals/functions.py @@ -54,6 +54,25 @@ def load_file_to_dict(filename: str) -> dict: return pals_data +def _numpy_to_native(obj): + """Convert a numpy scalar/array to its Python-native equivalent. + + Returns ``None`` when the object is not a numpy type or when numpy is not + installed; callers use that to decide whether to fall back to the default + serializer behavior. numpy is an optional dependency. + """ + try: + import numpy as np + except ImportError: + return None + + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, np.generic): + return obj.item() + return None + + def store_dict_to_file(filename: str, pals_dict: dict): file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions( filename @@ -63,14 +82,58 @@ def store_dict_to_file(filename: str, pals_dict: dict): if extension == ".json": import json - json_data = json.dumps(pals_dict, sort_keys=False, indent=2) + def _json_default(obj): + native = _numpy_to_native(obj) + if native is not None: + return native + raise TypeError( + f"Object of type {type(obj).__name__} is not JSON serializable" + ) + + json_data = json.dumps( + pals_dict, sort_keys=False, indent=2, default=_json_default + ) with open(filename, "w") as file: file.write(json_data) elif extension == ".yaml": import yaml - yaml_data = yaml.dump(pals_dict, default_flow_style=False, sort_keys=False) + # Subclass the safe dumper so numpy representers are scoped to PALS + # serialization and do not leak into the global pyyaml state used by + # other code in the same process. + class _PALSDumper(yaml.SafeDumper): + pass + + try: + import numpy as np + except ImportError: + np = None + + if np is not None: + + def _represent_numpy_scalar(dumper, value): + native = value.item() + if isinstance(native, bool): + return dumper.represent_bool(native) + if isinstance(native, int): + return dumper.represent_int(native) + if isinstance(native, float): + return dumper.represent_float(native) + return dumper.represent_data(native) + + def _represent_numpy_array(dumper, value): + return dumper.represent_list(value.tolist()) + + _PALSDumper.add_multi_representer(np.generic, _represent_numpy_scalar) + _PALSDumper.add_representer(np.ndarray, _represent_numpy_array) + + yaml_data = yaml.dump( + pals_dict, + Dumper=_PALSDumper, + default_flow_style=False, + sort_keys=False, + ) with open(filename, "w") as file: file.write(yaml_data) diff --git a/src/pals/parameters/ElectricMultipoleParameters.py b/src/pals/parameters/ElectricMultipoleParameters.py index af16624..c0f8d7d 100644 --- a/src/pals/parameters/ElectricMultipoleParameters.py +++ b/src/pals/parameters/ElectricMultipoleParameters.py @@ -1,27 +1,9 @@ -from pydantic import BaseModel, ConfigDict, model_validator -from typing import Any +from typing import ClassVar -# Valid parameter prefixes, their expected format and description -_PARAMETER_PREFIXES = { - "tilt": ("tiltN", "Tilt"), - "En": ("EnN", "Normal component"), - "Es": ("EsN", "Skew component"), -} +from pals.parameters._multipole_base import _MultipoleBase -def _validate_order( - key_num: str, parameter_name: str, prefix: str, expected_format: str -) -> None: - """Validate that the order number is a non-negative integer without leading zeros.""" - error_msg = ( - f"Invalid {parameter_name}: '{prefix}{key_num}'. " - f"Parameter must be of the form '{expected_format}', where 'N' is a non-negative integer without leading zeros." - ) - if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"): - raise ValueError(error_msg) - - -class ElectricMultipoleParameters(BaseModel): +class ElectricMultipoleParameters(_MultipoleBase): """Electric multipole parameters Valid parameter formats: @@ -33,31 +15,9 @@ class ElectricMultipoleParameters(BaseModel): Where N is a positive integer without leading zeros (except "0" itself). """ - model_config = ConfigDict(extra="allow") - - @model_validator(mode="before") - @classmethod - def validate(cls, values: dict[str, Any]) -> dict[str, Any]: - """Validate all parameter names match the expected multipole format.""" - for key in values: - # Check if key ends with 'L' for length-integrated values - is_length_integrated = key.endswith("L") - base_key = key[:-1] if is_length_integrated else key - - # No length-integrated values allowed for tilt parameter - if is_length_integrated and base_key.startswith("tilt"): - raise ValueError(f"Invalid electric multipole parameter: '{key}'. ") - - # Find matching prefix - for prefix, (expected_format, description) in _PARAMETER_PREFIXES.items(): - if base_key.startswith(prefix): - key_num = base_key[len(prefix) :] - _validate_order(key_num, description, prefix, expected_format) - break - else: - raise ValueError( - f"Invalid electric multipole parameter: '{key}'. " - f"Parameters must be of the form 'tiltN', 'EnN', or 'EsN' " - f"(with optional 'L' suffix for length-integrated), where 'N' is a non-negative integer." - ) - return values + _KIND_NAME: ClassVar[str] = "electric" + _PARAMETER_PREFIXES: ClassVar[dict[str, tuple[str, str]]] = { + "tilt": ("tiltN", "Tilt"), + "En": ("EnN", "Normal component"), + "Es": ("EsN", "Skew component"), + } diff --git a/src/pals/parameters/MagneticMultipoleParameters.py b/src/pals/parameters/MagneticMultipoleParameters.py index 7c4b176..47d38ad 100644 --- a/src/pals/parameters/MagneticMultipoleParameters.py +++ b/src/pals/parameters/MagneticMultipoleParameters.py @@ -1,29 +1,9 @@ -from pydantic import BaseModel, ConfigDict, model_validator -from typing import Any +from typing import ClassVar -# Valid parameter prefixes, their expected format and description -_PARAMETER_PREFIXES = { - "tilt": ("tiltN", "Tilt"), - "Bn": ("BnN", "Normal component"), - "Bs": ("BsN", "Skew component"), - "Kn": ("KnN", "Normalized normal component"), - "Ks": ("KsN", "Normalized skew component"), -} +from pals.parameters._multipole_base import _MultipoleBase -def _validate_order( - key_num: str, parameter_name: str, prefix: str, expected_format: str -) -> None: - """Validate that the order number is a non-negative integer without leading zeros.""" - error_msg = ( - f"Invalid {parameter_name}: '{prefix}{key_num}'. " - f"Parameter must be of the form '{expected_format}', where 'N' is a non-negative integer without leading zeros." - ) - if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"): - raise ValueError(error_msg) - - -class MagneticMultipoleParameters(BaseModel): +class MagneticMultipoleParameters(_MultipoleBase): """Magnetic multipole parameters Valid parameter formats: @@ -37,31 +17,11 @@ class MagneticMultipoleParameters(BaseModel): Where N is a positive integer without leading zeros (except "0" itself). """ - model_config = ConfigDict(extra="allow") - - @model_validator(mode="before") - @classmethod - def validate(cls, values: dict[str, Any]) -> dict[str, Any]: - """Validate all parameter names match the expected multipole format.""" - for key in values: - # Check if key ends with 'L' for length-integrated values - is_length_integrated = key.endswith("L") - base_key = key[:-1] if is_length_integrated else key - - # No length-integrated values allowed for tilt parameter - if is_length_integrated and base_key.startswith("tilt"): - raise ValueError(f"Invalid magnetic multipole parameter: '{key}'. ") - - # Find matching prefix - for prefix, (expected_format, description) in _PARAMETER_PREFIXES.items(): - if base_key.startswith(prefix): - key_num = base_key[len(prefix) :] - _validate_order(key_num, description, prefix, expected_format) - break - else: - raise ValueError( - f"Invalid magnetic multipole parameter: '{key}'. " - f"Parameters must be of the form 'tiltN', 'BnN', 'BsN', 'KnN', or 'KsN' " - f"(with optional 'L' suffix for length-integrated), where 'N' is a non-negative integer." - ) - return values + _KIND_NAME: ClassVar[str] = "magnetic" + _PARAMETER_PREFIXES: ClassVar[dict[str, tuple[str, str]]] = { + "tilt": ("tiltN", "Tilt"), + "Bn": ("BnN", "Normal component"), + "Bs": ("BsN", "Skew component"), + "Kn": ("KnN", "Normalized normal component"), + "Ks": ("KsN", "Normalized skew component"), + } diff --git a/src/pals/parameters/_multipole_base.py b/src/pals/parameters/_multipole_base.py new file mode 100644 index 0000000..3f1170b --- /dev/null +++ b/src/pals/parameters/_multipole_base.py @@ -0,0 +1,108 @@ +"""Private shared base class for multipole parameter groups. + +Both :class:`MagneticMultipoleParameters` and :class:`ElectricMultipoleParameters` +allow arbitrary order-indexed extra fields (e.g. ``Bn1``, ``Es3``, ``Kn0L``). +Because these fields are not declared with a type, Pydantic would otherwise +store them as-is, preserving non-native numeric inputs like ``numpy.float64``. +That breaks downstream YAML serialization (PyYAML emits unsafe Python-object +tags for numpy scalars). See pals-project/pals-python#67. + +This module centralizes the name-validation logic and adds numpy-to-native +coercion at construction time. +""" + +from typing import Any, ClassVar + +from pydantic import BaseModel, ConfigDict, model_validator + + +def _coerce_numpy_value(value: Any) -> Any: + """Convert numpy scalars/arrays to Python-native equivalents. + + Recurses through ``list``/``tuple``/``dict`` containers so nested + structures are also cleaned. Returns ``value`` unchanged when numpy is + not installed or the value is not a numpy type. numpy remains an optional + dependency of this project. + """ + try: + import numpy as np + except ImportError: + return value + + if isinstance(value, np.ndarray): + if value.ndim == 0: + return value.item() + return _coerce_numpy_value(value.tolist()) + if isinstance(value, np.generic): + return value.item() + if isinstance(value, list): + return [_coerce_numpy_value(v) for v in value] + if isinstance(value, tuple): + return tuple(_coerce_numpy_value(v) for v in value) + if isinstance(value, dict): + return {k: _coerce_numpy_value(v) for k, v in value.items()} + return value + + +def _validate_order( + key_num: str, parameter_name: str, prefix: str, expected_format: str +) -> None: + """Validate that the order number is a non-negative integer without leading zeros.""" + error_msg = ( + f"Invalid {parameter_name}: '{prefix}{key_num}'. " + f"Parameter must be of the form '{expected_format}', " + f"where 'N' is a non-negative integer without leading zeros." + ) + if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"): + raise ValueError(error_msg) + + +class _MultipoleBase(BaseModel): + """Private shared base for multipole parameter groups. + + Subclasses must set :attr:`_PARAMETER_PREFIXES` and :attr:`_KIND_NAME`. + Both are ``ClassVar`` and are not exposed as Pydantic fields. + """ + + # Subclasses override these: + _PARAMETER_PREFIXES: ClassVar[dict[str, tuple[str, str]]] = {} + _KIND_NAME: ClassVar[str] = "multipole" + + model_config = ConfigDict(extra="allow") + + @model_validator(mode="before") + @classmethod + def _validate_and_coerce(cls, values: dict[str, Any]) -> dict[str, Any]: + """Validate parameter names and coerce numpy values to Python natives.""" + coerced: dict[str, Any] = {} + for key, value in values.items(): + is_length_integrated = key.endswith("L") + base_key = key[:-1] if is_length_integrated else key + + if is_length_integrated and base_key.startswith("tilt"): + raise ValueError( + f"Invalid {cls._KIND_NAME} multipole parameter: '{key}'. " + ) + + for prefix, ( + expected_format, + description, + ) in cls._PARAMETER_PREFIXES.items(): + if base_key.startswith(prefix): + key_num = base_key[len(prefix) :] + _validate_order(key_num, description, prefix, expected_format) + break + else: + prefix_list = ", ".join( + f"'{p}N'" for p in cls._PARAMETER_PREFIXES if p != "tilt" + ) + raise ValueError( + f"Invalid {cls._KIND_NAME} multipole parameter: '{key}'. " + f"Parameters must be of the form 'tiltN', {prefix_list} " + f"(with optional 'L' suffix for length-integrated), " + f"where 'N' is a non-negative integer." + ) + + coerced[key] = _coerce_numpy_value(value) + + return coerced From 3792a3410a4a4f8b9be0b14177e02e34238296c5 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Tue, 12 May 2026 12:55:53 -0700 Subject: [PATCH 3/5] Generalize NumPy Serialization --- src/pals/parameters/_multipole_base.py | 52 +++++--------------------- 1 file changed, 10 insertions(+), 42 deletions(-) diff --git a/src/pals/parameters/_multipole_base.py b/src/pals/parameters/_multipole_base.py index 3f1170b..c37faed 100644 --- a/src/pals/parameters/_multipole_base.py +++ b/src/pals/parameters/_multipole_base.py @@ -1,14 +1,13 @@ """Private shared base class for multipole parameter groups. Both :class:`MagneticMultipoleParameters` and :class:`ElectricMultipoleParameters` -allow arbitrary order-indexed extra fields (e.g. ``Bn1``, ``Es3``, ``Kn0L``). -Because these fields are not declared with a type, Pydantic would otherwise -store them as-is, preserving non-native numeric inputs like ``numpy.float64``. -That breaks downstream YAML serialization (PyYAML emits unsafe Python-object -tags for numpy scalars). See pals-project/pals-python#67. +allow arbitrary order-indexed extra fields (e.g. ``Bn1``, ``Es3``, ``Kn0L``) and +share the same name-validation logic. This module centralizes that logic. -This module centralizes the name-validation logic and adds numpy-to-native -coercion at construction time. +numpy interoperability (see pals-project/pals-python#67) is handled at the +serialization boundary in :mod:`pals.functions`, which keeps the fix general: +any numpy scalar reaching ``yaml.dump`` or ``json.dumps`` is converted to a +Python-native equivalent regardless of which model produced it. """ from typing import Any, ClassVar @@ -16,34 +15,6 @@ from pydantic import BaseModel, ConfigDict, model_validator -def _coerce_numpy_value(value: Any) -> Any: - """Convert numpy scalars/arrays to Python-native equivalents. - - Recurses through ``list``/``tuple``/``dict`` containers so nested - structures are also cleaned. Returns ``value`` unchanged when numpy is - not installed or the value is not a numpy type. numpy remains an optional - dependency of this project. - """ - try: - import numpy as np - except ImportError: - return value - - if isinstance(value, np.ndarray): - if value.ndim == 0: - return value.item() - return _coerce_numpy_value(value.tolist()) - if isinstance(value, np.generic): - return value.item() - if isinstance(value, list): - return [_coerce_numpy_value(v) for v in value] - if isinstance(value, tuple): - return tuple(_coerce_numpy_value(v) for v in value) - if isinstance(value, dict): - return {k: _coerce_numpy_value(v) for k, v in value.items()} - return value - - def _validate_order( key_num: str, parameter_name: str, prefix: str, expected_format: str ) -> None: @@ -72,10 +43,9 @@ class _MultipoleBase(BaseModel): @model_validator(mode="before") @classmethod - def _validate_and_coerce(cls, values: dict[str, Any]) -> dict[str, Any]: - """Validate parameter names and coerce numpy values to Python natives.""" - coerced: dict[str, Any] = {} - for key, value in values.items(): + def _validate(cls, values: dict[str, Any]) -> dict[str, Any]: + """Validate that all parameter names match the expected multipole format.""" + for key in values: is_length_integrated = key.endswith("L") base_key = key[:-1] if is_length_integrated else key @@ -103,6 +73,4 @@ def _validate_and_coerce(cls, values: dict[str, Any]) -> dict[str, Any]: f"where 'N' is a non-negative integer." ) - coerced[key] = _coerce_numpy_value(value) - - return coerced + return values From 23f93a9476bb26a491a8dcce5d643a0ec06840b1 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Tue, 12 May 2026 12:56:30 -0700 Subject: [PATCH 4/5] Simplify Test --- tests/test_parameters.py | 50 ---------------------------------------- 1 file changed, 50 deletions(-) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index a78eacf..09416bc 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -134,53 +134,3 @@ def test_ParameterClasses(): # Test BeamBeamParameters beambeam = BeamBeamParameters() assert beambeam is not None - - -def test_multipole_numpy_coercion(): - """Regression test for issue #67: numpy scalars passed to multipole parameter - classes must be coerced to Python-native numeric types at construction time, - so YAML/JSON serialization produces clean output regardless of input type.""" - np = pytest.importorskip("numpy") - - # MagneticMultipoleParameters: cover all prefixes and several numpy dtypes - mmp = MagneticMultipoleParameters( - tilt1=np.float64(0.1), - Bn1=np.float64(1.5), - Bn2=np.float32(2.5), - Bs1=np.int64(3), - Kn0=np.int32(-1), - Ks1=np.float64(0.25), - ) - assert type(mmp.tilt1) is float and mmp.tilt1 == 0.1 - assert type(mmp.Bn1) is float and mmp.Bn1 == 1.5 - assert type(mmp.Bn2) is float and mmp.Bn2 == 2.5 - assert type(mmp.Bs1) is int and mmp.Bs1 == 3 - assert type(mmp.Kn0) is int and mmp.Kn0 == -1 - assert type(mmp.Ks1) is float and mmp.Ks1 == 0.25 - - # 0-d numpy array also works - mmp_arr = MagneticMultipoleParameters(Bn1=np.array(4.2)) - assert type(mmp_arr.Bn1) is float and mmp_arr.Bn1 == 4.2 - - # Length-integrated variants - mmp_L = MagneticMultipoleParameters(Bn1L=np.float64(7.0), Ks1L=np.float64(8.0)) - assert type(mmp_L.Bn1L) is float and mmp_L.Bn1L == 7.0 - assert type(mmp_L.Ks1L) is float and mmp_L.Ks1L == 8.0 - - # ElectricMultipoleParameters: cover all prefixes - emp = ElectricMultipoleParameters( - tilt1=np.float64(0.2), - En1=np.float64(0.5), - Es1=np.int64(2), - ) - assert type(emp.tilt1) is float and emp.tilt1 == 0.2 - assert type(emp.En1) is float and emp.En1 == 0.5 - assert type(emp.Es1) is int and emp.Es1 == 2 - - emp_L = ElectricMultipoleParameters(En1L=np.float64(1.0), Es1L=np.float64(0.5)) - assert type(emp_L.En1L) is float and emp_L.En1L == 1.0 - assert type(emp_L.Es1L) is float and emp_L.Es1L == 0.5 - - # Plain Python values must still pass through unchanged - mmp_plain = MagneticMultipoleParameters(Bn1=1.5) - assert type(mmp_plain.Bn1) is float and mmp_plain.Bn1 == 1.5 From 7e75d61818ab81fd7d4571e16695ec961b80c618 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Tue, 12 May 2026 15:28:22 -0700 Subject: [PATCH 5/5] Simplify Implementation Again --- .../parameters/ElectricMultipoleParameters.py | 58 +++++++++++--- .../parameters/MagneticMultipoleParameters.py | 62 ++++++++++++--- src/pals/parameters/_multipole_base.py | 76 ------------------- 3 files changed, 100 insertions(+), 96 deletions(-) delete mode 100644 src/pals/parameters/_multipole_base.py diff --git a/src/pals/parameters/ElectricMultipoleParameters.py b/src/pals/parameters/ElectricMultipoleParameters.py index c0f8d7d..af16624 100644 --- a/src/pals/parameters/ElectricMultipoleParameters.py +++ b/src/pals/parameters/ElectricMultipoleParameters.py @@ -1,9 +1,27 @@ -from typing import ClassVar +from pydantic import BaseModel, ConfigDict, model_validator +from typing import Any -from pals.parameters._multipole_base import _MultipoleBase +# Valid parameter prefixes, their expected format and description +_PARAMETER_PREFIXES = { + "tilt": ("tiltN", "Tilt"), + "En": ("EnN", "Normal component"), + "Es": ("EsN", "Skew component"), +} -class ElectricMultipoleParameters(_MultipoleBase): +def _validate_order( + key_num: str, parameter_name: str, prefix: str, expected_format: str +) -> None: + """Validate that the order number is a non-negative integer without leading zeros.""" + error_msg = ( + f"Invalid {parameter_name}: '{prefix}{key_num}'. " + f"Parameter must be of the form '{expected_format}', where 'N' is a non-negative integer without leading zeros." + ) + if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"): + raise ValueError(error_msg) + + +class ElectricMultipoleParameters(BaseModel): """Electric multipole parameters Valid parameter formats: @@ -15,9 +33,31 @@ class ElectricMultipoleParameters(_MultipoleBase): Where N is a positive integer without leading zeros (except "0" itself). """ - _KIND_NAME: ClassVar[str] = "electric" - _PARAMETER_PREFIXES: ClassVar[dict[str, tuple[str, str]]] = { - "tilt": ("tiltN", "Tilt"), - "En": ("EnN", "Normal component"), - "Es": ("EsN", "Skew component"), - } + model_config = ConfigDict(extra="allow") + + @model_validator(mode="before") + @classmethod + def validate(cls, values: dict[str, Any]) -> dict[str, Any]: + """Validate all parameter names match the expected multipole format.""" + for key in values: + # Check if key ends with 'L' for length-integrated values + is_length_integrated = key.endswith("L") + base_key = key[:-1] if is_length_integrated else key + + # No length-integrated values allowed for tilt parameter + if is_length_integrated and base_key.startswith("tilt"): + raise ValueError(f"Invalid electric multipole parameter: '{key}'. ") + + # Find matching prefix + for prefix, (expected_format, description) in _PARAMETER_PREFIXES.items(): + if base_key.startswith(prefix): + key_num = base_key[len(prefix) :] + _validate_order(key_num, description, prefix, expected_format) + break + else: + raise ValueError( + f"Invalid electric multipole parameter: '{key}'. " + f"Parameters must be of the form 'tiltN', 'EnN', or 'EsN' " + f"(with optional 'L' suffix for length-integrated), where 'N' is a non-negative integer." + ) + return values diff --git a/src/pals/parameters/MagneticMultipoleParameters.py b/src/pals/parameters/MagneticMultipoleParameters.py index 47d38ad..7c4b176 100644 --- a/src/pals/parameters/MagneticMultipoleParameters.py +++ b/src/pals/parameters/MagneticMultipoleParameters.py @@ -1,9 +1,29 @@ -from typing import ClassVar +from pydantic import BaseModel, ConfigDict, model_validator +from typing import Any -from pals.parameters._multipole_base import _MultipoleBase +# Valid parameter prefixes, their expected format and description +_PARAMETER_PREFIXES = { + "tilt": ("tiltN", "Tilt"), + "Bn": ("BnN", "Normal component"), + "Bs": ("BsN", "Skew component"), + "Kn": ("KnN", "Normalized normal component"), + "Ks": ("KsN", "Normalized skew component"), +} -class MagneticMultipoleParameters(_MultipoleBase): +def _validate_order( + key_num: str, parameter_name: str, prefix: str, expected_format: str +) -> None: + """Validate that the order number is a non-negative integer without leading zeros.""" + error_msg = ( + f"Invalid {parameter_name}: '{prefix}{key_num}'. " + f"Parameter must be of the form '{expected_format}', where 'N' is a non-negative integer without leading zeros." + ) + if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"): + raise ValueError(error_msg) + + +class MagneticMultipoleParameters(BaseModel): """Magnetic multipole parameters Valid parameter formats: @@ -17,11 +37,31 @@ class MagneticMultipoleParameters(_MultipoleBase): Where N is a positive integer without leading zeros (except "0" itself). """ - _KIND_NAME: ClassVar[str] = "magnetic" - _PARAMETER_PREFIXES: ClassVar[dict[str, tuple[str, str]]] = { - "tilt": ("tiltN", "Tilt"), - "Bn": ("BnN", "Normal component"), - "Bs": ("BsN", "Skew component"), - "Kn": ("KnN", "Normalized normal component"), - "Ks": ("KsN", "Normalized skew component"), - } + model_config = ConfigDict(extra="allow") + + @model_validator(mode="before") + @classmethod + def validate(cls, values: dict[str, Any]) -> dict[str, Any]: + """Validate all parameter names match the expected multipole format.""" + for key in values: + # Check if key ends with 'L' for length-integrated values + is_length_integrated = key.endswith("L") + base_key = key[:-1] if is_length_integrated else key + + # No length-integrated values allowed for tilt parameter + if is_length_integrated and base_key.startswith("tilt"): + raise ValueError(f"Invalid magnetic multipole parameter: '{key}'. ") + + # Find matching prefix + for prefix, (expected_format, description) in _PARAMETER_PREFIXES.items(): + if base_key.startswith(prefix): + key_num = base_key[len(prefix) :] + _validate_order(key_num, description, prefix, expected_format) + break + else: + raise ValueError( + f"Invalid magnetic multipole parameter: '{key}'. " + f"Parameters must be of the form 'tiltN', 'BnN', 'BsN', 'KnN', or 'KsN' " + f"(with optional 'L' suffix for length-integrated), where 'N' is a non-negative integer." + ) + return values diff --git a/src/pals/parameters/_multipole_base.py b/src/pals/parameters/_multipole_base.py deleted file mode 100644 index c37faed..0000000 --- a/src/pals/parameters/_multipole_base.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Private shared base class for multipole parameter groups. - -Both :class:`MagneticMultipoleParameters` and :class:`ElectricMultipoleParameters` -allow arbitrary order-indexed extra fields (e.g. ``Bn1``, ``Es3``, ``Kn0L``) and -share the same name-validation logic. This module centralizes that logic. - -numpy interoperability (see pals-project/pals-python#67) is handled at the -serialization boundary in :mod:`pals.functions`, which keeps the fix general: -any numpy scalar reaching ``yaml.dump`` or ``json.dumps`` is converted to a -Python-native equivalent regardless of which model produced it. -""" - -from typing import Any, ClassVar - -from pydantic import BaseModel, ConfigDict, model_validator - - -def _validate_order( - key_num: str, parameter_name: str, prefix: str, expected_format: str -) -> None: - """Validate that the order number is a non-negative integer without leading zeros.""" - error_msg = ( - f"Invalid {parameter_name}: '{prefix}{key_num}'. " - f"Parameter must be of the form '{expected_format}', " - f"where 'N' is a non-negative integer without leading zeros." - ) - if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"): - raise ValueError(error_msg) - - -class _MultipoleBase(BaseModel): - """Private shared base for multipole parameter groups. - - Subclasses must set :attr:`_PARAMETER_PREFIXES` and :attr:`_KIND_NAME`. - Both are ``ClassVar`` and are not exposed as Pydantic fields. - """ - - # Subclasses override these: - _PARAMETER_PREFIXES: ClassVar[dict[str, tuple[str, str]]] = {} - _KIND_NAME: ClassVar[str] = "multipole" - - model_config = ConfigDict(extra="allow") - - @model_validator(mode="before") - @classmethod - def _validate(cls, values: dict[str, Any]) -> dict[str, Any]: - """Validate that all parameter names match the expected multipole format.""" - for key in values: - is_length_integrated = key.endswith("L") - base_key = key[:-1] if is_length_integrated else key - - if is_length_integrated and base_key.startswith("tilt"): - raise ValueError( - f"Invalid {cls._KIND_NAME} multipole parameter: '{key}'. " - ) - - for prefix, ( - expected_format, - description, - ) in cls._PARAMETER_PREFIXES.items(): - if base_key.startswith(prefix): - key_num = base_key[len(prefix) :] - _validate_order(key_num, description, prefix, expected_format) - break - else: - prefix_list = ", ".join( - f"'{p}N'" for p in cls._PARAMETER_PREFIXES if p != "tilt" - ) - raise ValueError( - f"Invalid {cls._KIND_NAME} multipole parameter: '{key}'. " - f"Parameters must be of the form 'tiltN', {prefix_list} " - f"(with optional 'L' suffix for length-integrated), " - f"where 'N' is a non-negative integer." - ) - - return values