Skip to content

Commit b446f8d

Browse files
authored
Merge pull request #87 from labthings/pydantic-2-10
Work with pydantic 2.10
2 parents 7ac72fa + 307845b commit b446f8d

File tree

4 files changed

+29
-10
lines changed

4 files changed

+29
-10
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "labthings-fastapi"
3-
version = "0.0.6"
3+
version = "0.0.7"
44
authors = [
55
{ name="Richard Bowman", email="richard.bowman@cantab.net" },
66
]

src/labthings_fastapi/thing_description/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515

1616
from __future__ import annotations
1717
from collections.abc import Mapping, Sequence
18-
from typing import Any, Optional, Union
18+
from typing import Any, Optional
1919
import json
2020

21-
from pydantic import TypeAdapter, ValidationError, BaseModel
21+
from pydantic import TypeAdapter, ValidationError
2222
from .model import DataSchema
2323

2424

@@ -192,7 +192,7 @@ def jsonschema_to_dataschema(
192192
return output
193193

194194

195-
def type_to_dataschema(t: Union[type, BaseModel], **kwargs) -> DataSchema:
195+
def type_to_dataschema(t: type, **kwargs) -> DataSchema:
196196
"""Convert a Python type to a Thing Description DataSchema
197197
198198
This makes use of pydantic's `schema_of` function to create a
@@ -205,9 +205,14 @@ def type_to_dataschema(t: Union[type, BaseModel], **kwargs) -> DataSchema:
205205
is passed in. Typically you'll want to use this for the
206206
`title` field.
207207
"""
208-
if isinstance(t, BaseModel):
208+
if hasattr(t, "model_json_schema"):
209+
# The input should be a `BaseModel` subclass, in which case this works:
209210
json_schema = t.model_json_schema()
210211
else:
212+
# In principle, the below should work for any type, though some
213+
# deferred annotations can go wrong.
214+
# Some attempt at looking up the environment of functions might help
215+
# here.
211216
json_schema = TypeAdapter(t).json_schema()
212217
schema_dict = jsonschema_to_dataschema(json_schema)
213218
# Definitions of referenced ($ref) schemas are put in a

src/labthings_fastapi/types/numpy.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@ def double(arr: NDArray) -> NDArray:
3030
WrapSerializer,
3131
)
3232
from typing import Annotated, Any, List, Union
33+
from typing_extensions import TypeAlias
3334
from collections.abc import Mapping, Sequence
3435

3536

3637
# Define a nested list of floats with 0-6 dimensions
3738
# This would be most elegantly defined as a recursive type
3839
# but the below gets the job done for now.
39-
Number = Union[int, float]
40-
NestedListOfNumbers = Union[
40+
Number: TypeAlias = Union[int, float]
41+
NestedListOfNumbers: TypeAlias = Union[
4142
Number,
4243
List[Number],
4344
List[List[Number]],
@@ -68,10 +69,12 @@ def listoflists_to_np(lol: Union[NestedListOfNumbers, np.ndarray]) -> np.ndarray
6869

6970

7071
# Define an annotated type so Pydantic can cope with numpy
71-
NDArray = Annotated[
72+
NDArray: TypeAlias = Annotated[
7273
np.ndarray,
7374
PlainValidator(listoflists_to_np),
74-
PlainSerializer(np_to_listoflists, when_used="json-unless-none"),
75+
PlainSerializer(
76+
np_to_listoflists, when_used="json-unless-none", return_type=NestedListOfNumbers
77+
),
7578
WithJsonSchema(NestedListOfNumbersModel.model_json_schema(), mode="validation"),
7679
]
7780

tests/test_numpy_type.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
from __future__ import annotations
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, RootModel
44
import numpy as np
55

66
from labthings_fastapi.types.numpy import NDArray, DenumpifyingDict
77
from labthings_fastapi.thing import Thing
88
from labthings_fastapi.decorators import thing_action
99

1010

11+
class ArrayModel(RootModel):
12+
root: NDArray
13+
14+
1115
def check_field_works_with_list(data):
1216
class Model(BaseModel):
1317
a: NDArray
@@ -86,3 +90,10 @@ def test_denumpifying_dict():
8690
assert dump["e"] is None
8791
assert dump["f"] == 1
8892
d.model_dump_json()
93+
94+
95+
def test_rootmodel():
96+
for input in [[0, 1, 2], np.arange(3)]:
97+
m = ArrayModel(root=input)
98+
assert isinstance(m.root, np.ndarray)
99+
assert (m.model_dump() == [0, 1, 2]).all()

0 commit comments

Comments
 (0)