Skip to content

Commit 5df6dfd

Browse files
committed
Add support for multiple dataclasses
1 parent 7594b69 commit 5df6dfd

File tree

4 files changed

+89
-33
lines changed

4 files changed

+89
-33
lines changed

test/test_dataclass.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,24 @@ class Test:
1818
param: Literal['a', 'b']
1919
items: list[str]
2020

21-
@with_dataclass(dataclass=Test)
21+
@with_dataclass(args=Test)
2222
def func(args: Test):
2323
return args.param
2424

2525
with sys_args(param="a", item=["a"]):
2626
self.assertEqual("a", func())
27+
28+
def test_multi_dataclass(self):
29+
@dataclass
30+
class A:
31+
param: Literal['a', 'b']
32+
@dataclass
33+
class B:
34+
number: int
35+
36+
@with_dataclass(A, B)
37+
def func(arg1: A, arg2: B):
38+
return len(arg1.param) + arg2.number
39+
40+
with sys_args(param="a", number=1):
41+
self.assertEqual(2, func())

with_argparse/configure_argparse.py

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
Any, Set, List, get_origin, get_args, Union, Literal, Optional, Sequence, TypeVar, Iterable,
1313
Callable, MutableMapping, Mapping
1414
)
15+
16+
from with_argparse.types import DataclassInstance
1517
from with_argparse.utils import flatten, glob_to_paths
1618

1719
SET_TYPES = {set, Set}
@@ -43,6 +45,13 @@ class _Argument:
4345
choices: Optional[Sequence[Any]] = None
4446
action: Optional[str] = None
4547

48+
49+
@dataclass
50+
class DataclassConfig:
51+
func: Callable
52+
positional_dataclasses: tuple[type[DataclassInstance], ...]
53+
keyword_dataclasses: dict[str, type[DataclassInstance]]
54+
4655
class WithArgparse:
4756
ignore_rename_sequences: set[str]
4857
ignore_arg_keys: set[str]
@@ -54,11 +63,11 @@ class WithArgparse:
5463
allow_dispatch_custom: bool
5564

5665
func: Callable
57-
dataclass: Optional[type]
66+
dataclass: Optional[DataclassConfig]
5867

5968
def __init__(
6069
self,
61-
func_or_dataclass: Union[Callable, tuple[type, Callable]],
70+
func_or_dataclass: Union[Callable, DataclassConfig],
6271
aliases: Optional[Mapping[str, Sequence[str]]] = None,
6372
ignore_rename: Optional[set[str]] = None,
6473
ignore_keys: Optional[set[str]] = None,
@@ -75,13 +84,9 @@ def __init__(
7584
self.allow_custom = allow_custom or dict()
7685
self.allow_dispatch_custom = True
7786

78-
if isinstance(func_or_dataclass, tuple):
79-
if not inspect.isclass(func_or_dataclass[0]):
80-
raise ValueError("First argument must be a type")
81-
if not dataclasses.is_dataclass(func_or_dataclass[0]):
82-
raise ValueError("First argument must be a dataclass")
83-
self.dataclass = func_or_dataclass[0]
84-
self.func = func_or_dataclass[1]
87+
if isinstance(func_or_dataclass, DataclassConfig):
88+
self.dataclass = func_or_dataclass
89+
self.func = func_or_dataclass.func
8590
else:
8691
self.func = func_or_dataclass
8792
self.dataclass = None
@@ -108,29 +113,51 @@ def _call_dataclass(self, args: Sequence[Any], kwargs: Mapping[str, Any]):
108113
if self.dataclass is None:
109114
raise ValueError("self.dataclass cannot be None")
110115

111-
field_hints = typing.get_type_hints(self.dataclass)
112-
for field in dataclasses.fields(self.dataclass):
113-
field_required = field.default is MISSING
114-
field_default = field.default if not field_required else None
115-
field_type = field.type
116-
if isinstance(field_type, str):
117-
field_type = typing.cast(type, field_hints.get(field.name))
118-
119-
# known_types = {type, Literal, GenericAlias, UnionType}
120-
# if type(field_type) not in known_types and typing.get_origin(field_type) not in known_types:
121-
# raise ValueError(f"Cannot determine type of {field.name}, got {field_type} {type(field_type)}")
122-
# raises on typing.Optional[typing.Literal['epsilon', 'v_prediction']]
123-
self._setup_argument(
124-
field.name,
125-
field_type,
126-
field_default,
127-
field_required,
128-
)
116+
positional_dataclasses = self.dataclass.positional_dataclasses or tuple()
117+
keyword_dataclasses = self.dataclass.keyword_dataclasses or dict()
118+
dataclasses_to_process = [*positional_dataclasses, *keyword_dataclasses.values()]
119+
120+
for klass in dataclasses_to_process:
121+
field_hints = typing.get_type_hints(klass)
122+
for field in dataclasses.fields(klass):
123+
field_required = field.default is MISSING
124+
field_default = field.default if not field_required else None
125+
field_type = field.type
126+
if isinstance(field_type, str):
127+
field_type = typing.cast(type, field_hints.get(field.name))
128+
129+
# known_types = {type, Literal, GenericAlias, UnionType}
130+
# if type(field_type) not in known_types and typing.get_origin(field_type) not in known_types:
131+
# raise ValueError(f"Cannot determine type of {field.name}, got {field_type} {type(field_type)}")
132+
# raises on typing.Optional[typing.Literal['epsilon', 'v_prediction']]
133+
self._setup_argument(
134+
field.name,
135+
field_type,
136+
field_default,
137+
field_required,
138+
)
129139

130140
parsed_args = self.argparse.parse_args()
131141
args_dict = self._apply_name_mapping(parsed_args.__dict__, None)
132142
args_dict = self._apply_post_parse_conversions(args_dict, dict())
133-
return self.func(self.dataclass(**args_dict), **kwargs)
143+
144+
pos = tuple()
145+
keywords = dict()
146+
for i, klass in enumerate(positional_dataclasses):
147+
klass_args = dict()
148+
for field in dataclasses.fields(klass):
149+
klass_args[field.name] = args_dict[field.name]
150+
151+
pos = pos + (klass(**klass_args),)
152+
153+
for name, klass in keyword_dataclasses.items():
154+
klass_args = dict()
155+
for field in dataclasses.fields(klass):
156+
klass_args[field.name] = args_dict[field.name]
157+
158+
keywords[name] = klass(**klass_args)
159+
160+
return self.func(*pos, **keywords, **kwargs)
134161

135162
def call(self, args: Sequence[Any], kwargs: Mapping[str, Any]):
136163
if self.dataclass is not None:

with_argparse/impl.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from typing import Callable, Union, ParamSpec, TypeVar, overload, Optional, Mapping, _SpecialForm, Any
55
import warnings
66

7-
from with_argparse.configure_argparse import WithArgparse
7+
from with_argparse.configure_argparse import WithArgparse, DataclassConfig
8+
from with_argparse.types import DataclassInstance
89

910
try:
1011
from pyrootutils import setup_root as setup_root_fn
@@ -50,18 +51,25 @@ def set_config(key: str, state: bool):
5051

5152

5253
def with_dataclass(
53-
*,
54-
dataclass=None,
54+
*pos: type[DataclassInstance],
5555
allow_glob: Optional[set[str]] = None,
56+
**kwargs: type[DataclassInstance],
5657
):
58+
if not pos:
59+
pos = tuple()
60+
5761
def wrapper(fn):
5862
@functools.wraps(fn)
5963
def inner(*inner_args, **inner_kwargs):
6064
if not is_enabled():
6165
return fn(*inner_args, **inner_kwargs)
6266

6367
parser = WithArgparse(
64-
(dataclass, fn),
68+
DataclassConfig(
69+
fn,
70+
pos,
71+
kwargs,
72+
),
6573
allow_glob=allow_glob,
6674
)
6775
return parser.call(inner_args, inner_kwargs)

with_argparse/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from dataclasses import Field
2+
from typing import Protocol, ClassVar, Any
3+
4+
5+
class DataclassInstance(Protocol):
6+
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]

0 commit comments

Comments
 (0)