1212 Any , Set , List , get_origin , get_args , Union , Literal , Optional , Sequence , TypeVar , Iterable ,
1313 Callable , MutableMapping , Mapping
1414)
15+
16+ from with_argparse .types import DataclassInstance
1517from with_argparse .utils import flatten , glob_to_paths
1618
1719SET_TYPES = {set , Set }
@@ -43,6 +45,13 @@ class _Argument:
4345 choices : Optional [Sequence [Any ]] = None
4446 action : Optional [str ] = None
4547
48+
49+ @dataclass
50+ class DataclassConfig :
51+ func : Callable
52+ positional_dataclasses : tuple [type [DataclassInstance ], ...]
53+ keyword_dataclasses : dict [str , type [DataclassInstance ]]
54+
4655class WithArgparse :
4756 ignore_rename_sequences : set [str ]
4857 ignore_arg_keys : set [str ]
@@ -54,11 +63,11 @@ class WithArgparse:
5463 allow_dispatch_custom : bool
5564
5665 func : Callable
57- dataclass : Optional [type ]
66+ dataclass : Optional [DataclassConfig ]
5867
5968 def __init__ (
6069 self ,
61- func_or_dataclass : Union [Callable , tuple [ type , Callable ] ],
70+ func_or_dataclass : Union [Callable , DataclassConfig ],
6271 aliases : Optional [Mapping [str , Sequence [str ]]] = None ,
6372 ignore_rename : Optional [set [str ]] = None ,
6473 ignore_keys : Optional [set [str ]] = None ,
@@ -75,13 +84,9 @@ def __init__(
7584 self .allow_custom = allow_custom or dict ()
7685 self .allow_dispatch_custom = True
7786
78- if isinstance (func_or_dataclass , tuple ):
79- if not inspect .isclass (func_or_dataclass [0 ]):
80- raise ValueError ("First argument must be a type" )
81- if not dataclasses .is_dataclass (func_or_dataclass [0 ]):
82- raise ValueError ("First argument must be a dataclass" )
83- self .dataclass = func_or_dataclass [0 ]
84- self .func = func_or_dataclass [1 ]
87+ if isinstance (func_or_dataclass , DataclassConfig ):
88+ self .dataclass = func_or_dataclass
89+ self .func = func_or_dataclass .func
8590 else :
8691 self .func = func_or_dataclass
8792 self .dataclass = None
@@ -108,29 +113,51 @@ def _call_dataclass(self, args: Sequence[Any], kwargs: Mapping[str, Any]):
108113 if self .dataclass is None :
109114 raise ValueError ("self.dataclass cannot be None" )
110115
111- field_hints = typing .get_type_hints (self .dataclass )
112- for field in dataclasses .fields (self .dataclass ):
113- field_required = field .default is MISSING
114- field_default = field .default if not field_required else None
115- field_type = field .type
116- if isinstance (field_type , str ):
117- field_type = typing .cast (type , field_hints .get (field .name ))
118-
119- # known_types = {type, Literal, GenericAlias, UnionType}
120- # if type(field_type) not in known_types and typing.get_origin(field_type) not in known_types:
121- # raise ValueError(f"Cannot determine type of {field.name}, got {field_type} {type(field_type)}")
122- # raises on typing.Optional[typing.Literal['epsilon', 'v_prediction']]
123- self ._setup_argument (
124- field .name ,
125- field_type ,
126- field_default ,
127- field_required ,
128- )
116+ positional_dataclasses = self .dataclass .positional_dataclasses or tuple ()
117+ keyword_dataclasses = self .dataclass .keyword_dataclasses or dict ()
118+ dataclasses_to_process = [* positional_dataclasses , * keyword_dataclasses .values ()]
119+
120+ for klass in dataclasses_to_process :
121+ field_hints = typing .get_type_hints (klass )
122+ for field in dataclasses .fields (klass ):
123+ field_required = field .default is MISSING
124+ field_default = field .default if not field_required else None
125+ field_type = field .type
126+ if isinstance (field_type , str ):
127+ field_type = typing .cast (type , field_hints .get (field .name ))
128+
129+ # known_types = {type, Literal, GenericAlias, UnionType}
130+ # if type(field_type) not in known_types and typing.get_origin(field_type) not in known_types:
131+ # raise ValueError(f"Cannot determine type of {field.name}, got {field_type} {type(field_type)}")
132+ # raises on typing.Optional[typing.Literal['epsilon', 'v_prediction']]
133+ self ._setup_argument (
134+ field .name ,
135+ field_type ,
136+ field_default ,
137+ field_required ,
138+ )
129139
130140 parsed_args = self .argparse .parse_args ()
131141 args_dict = self ._apply_name_mapping (parsed_args .__dict__ , None )
132142 args_dict = self ._apply_post_parse_conversions (args_dict , dict ())
133- return self .func (self .dataclass (** args_dict ), ** kwargs )
143+
144+ pos = tuple ()
145+ keywords = dict ()
146+ for i , klass in enumerate (positional_dataclasses ):
147+ klass_args = dict ()
148+ for field in dataclasses .fields (klass ):
149+ klass_args [field .name ] = args_dict [field .name ]
150+
151+ pos = pos + (klass (** klass_args ),)
152+
153+ for name , klass in keyword_dataclasses .items ():
154+ klass_args = dict ()
155+ for field in dataclasses .fields (klass ):
156+ klass_args [field .name ] = args_dict [field .name ]
157+
158+ keywords [name ] = klass (** klass_args )
159+
160+ return self .func (* pos , ** keywords , ** kwargs )
134161
135162 def call (self , args : Sequence [Any ], kwargs : Mapping [str , Any ]):
136163 if self .dataclass is not None :
0 commit comments