Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit a87f0b7

Browse files
committed
re-adding deprecated IC examples (#574)
1 parent 28ec07f commit a87f0b7

File tree

7 files changed

+2863
-0
lines changed

7 files changed

+2863
-0
lines changed
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import dataclasses
16+
import json
17+
import re
18+
import sys
19+
from argparse import ArgumentParser, ArgumentTypeError
20+
from copy import copy
21+
from enum import Enum
22+
from pathlib import Path
23+
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
24+
25+
26+
DataClass = NewType("DataClass", Any)
27+
DataClassType = NewType("DataClassType", Any)
28+
29+
30+
# From https://stackoverflow.com/questions/15008758
31+
# /parsing-boolean-values-with-argparse
32+
def string_to_bool(v):
33+
if isinstance(v, bool):
34+
return v
35+
if v.lower() in ("yes", "true", "t", "y", "1"):
36+
return True
37+
elif v.lower() in ("no", "false", "f", "n", "0"):
38+
return False
39+
else:
40+
raise ArgumentTypeError(
41+
f"Truthy value expected: got {v} but expected one of yes/no,"
42+
f"true/false, t/f, y/n, 1/0 (case insensitive)."
43+
)
44+
45+
46+
# Inspired from https://huggingface.co/transformers/_modules
47+
# /transformers/hf_argparser.html
48+
class NmArgumentParser(ArgumentParser):
49+
"""
50+
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
51+
to generate arguments.
52+
53+
The class is designed to play well with the native argparse. In particular,
54+
you can add more (non-dataclass backed) arguments to the parser after
55+
initialization and you'll get the output back after parsing as an additional
56+
namespace. Optional: To create sub argument groups use the
57+
`_argument_group_name` attribute in the dataclass.
58+
59+
Note: __post_init__(...) for specific dataclasses passed is executed only
60+
when parse_args_into_dataclasses(...) function is called because it needs
61+
actual instantiation of the dataclass.
62+
"""
63+
64+
dataclass_types: Iterable[DataClassType]
65+
66+
def __init__(
67+
self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs
68+
):
69+
"""
70+
:param dataclass_types: Dataclass type, or list of dataclass types
71+
for which we will "fill" instances with the parsed args.
72+
:param kwargs: (Optional) Passed to `argparse.ArgumentParser()` in
73+
the regular way.
74+
"""
75+
super().__init__(**kwargs)
76+
if dataclasses.is_dataclass(dataclass_types):
77+
dataclass_types = [dataclass_types]
78+
self.dataclass_types = dataclass_types
79+
for dataclass_ in self.dataclass_types:
80+
self._add_dataclass_arguments(dataclass_)
81+
82+
def _add_dataclass_arguments(self, dataclass_: DataClassType):
83+
if hasattr(dataclass_, "_argument_group_name"):
84+
parser = self.add_argument_group(dataclass_._argument_group_name)
85+
else:
86+
parser = self
87+
for field in dataclasses.fields(dataclass_):
88+
if not field.init:
89+
continue
90+
91+
name, kwargs = field.name, field.metadata.copy()
92+
93+
keep_underscores_key = "keep_underscores"
94+
keep_underscores = kwargs.get(keep_underscores_key)
95+
_field_name = name if keep_underscores else name.replace("_", "-")
96+
97+
# cleanup
98+
if keep_underscores_key in kwargs:
99+
del kwargs[keep_underscores_key]
100+
101+
# field.metadata is not used at all by Data Classes,
102+
# it is provided as a third-party extension mechanism.
103+
if isinstance(field.type, str):
104+
raise ImportError(
105+
"This implementation is not compatible with Postponed "
106+
"Evaluation of Annotations (PEP 563),"
107+
"which can be opted in from Python 3.7 with "
108+
"`from __future__ import annotations`."
109+
)
110+
typestring = str(field.type)
111+
for prim_type in (int, float, str):
112+
for collection in (List,):
113+
if (
114+
typestring == f"typing.Union["
115+
f"{collection[prim_type]}, NoneType]"
116+
or typestring == f"typing.Optional" f"[{collection[prim_type]}]"
117+
):
118+
field.type = collection[prim_type]
119+
if (
120+
typestring == f"typing.Union[" f"{prim_type.__name__}, NoneType]"
121+
or typestring == f"typing.Optional[" f"{prim_type.__name__}]"
122+
):
123+
field.type = prim_type
124+
125+
if isinstance(field.type, type) and issubclass(field.type, Enum):
126+
kwargs["choices"] = [x.value for x in field.type]
127+
kwargs["type"] = type(kwargs["choices"][0])
128+
if field.default is not dataclasses.MISSING:
129+
kwargs["default"] = field.default
130+
else:
131+
kwargs["required"] = True
132+
elif field.type is bool or field.type == Optional[bool]:
133+
if field.default is True:
134+
kwargs_copy = copy(kwargs)
135+
if "help" in kwargs_copy:
136+
kwargs_copy["help"] = f"Do not {kwargs_copy['help'].lower()}"
137+
parser.add_argument(
138+
f"--no-{_field_name}",
139+
action="store_false",
140+
dest=_field_name,
141+
**kwargs_copy,
142+
)
143+
144+
# Hack because type=bool in argparse does not behave as we want.
145+
kwargs["type"] = string_to_bool
146+
if field.type is bool or (
147+
field.default is not None
148+
and field.default is not dataclasses.MISSING
149+
):
150+
# Default value is False if we have no default
151+
# when of type bool.
152+
if field.default is dataclasses.MISSING:
153+
default = False
154+
else:
155+
default = field.default
156+
# This is the value that will get picked if
157+
# we don't include --field_name in any way
158+
kwargs["default"] = default
159+
160+
# This tells argparse we accept 0 or 1
161+
# value after --field_name
162+
kwargs["nargs"] = "?"
163+
# This is the value that will get picked
164+
# if we do --field_name (without value)
165+
kwargs["const"] = True
166+
elif (
167+
hasattr(field.type, "__origin__")
168+
and re.search(r"^typing\.List\[(.*)\]$", str(field.type)) is not None
169+
):
170+
kwargs["nargs"] = "+"
171+
kwargs["type"] = field.type.__args__[0]
172+
assert all(
173+
x == kwargs["type"] for x in field.type.__args__
174+
), f"{field.name} cannot be a List of mixed types"
175+
if field.default_factory is not dataclasses.MISSING:
176+
kwargs["default"] = field.default_factory()
177+
elif field.default is dataclasses.MISSING:
178+
kwargs["required"] = True
179+
else:
180+
kwargs["type"] = field.type
181+
if field.default is not dataclasses.MISSING:
182+
kwargs["default"] = field.default
183+
elif field.default_factory is not dataclasses.MISSING:
184+
kwargs["default"] = field.default_factory()
185+
else:
186+
kwargs["required"] = True
187+
parser.add_argument(f"--{_field_name}", **kwargs)
188+
189+
def parse_args_into_dataclasses(
190+
self,
191+
args=None,
192+
return_remaining_strings=False,
193+
look_for_args_file=True,
194+
args_filename=None,
195+
) -> Tuple[DataClass, ...]:
196+
"""
197+
Parse command-line args into instances of the specified dataclass types.
198+
199+
This relies on argparse's `ArgumentParser.parse_known_args`.
200+
See the doc at:
201+
docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.
202+
parse_args
203+
204+
205+
:param args: List of strings to parse. The default is taken from
206+
sys.argv. (same as argparse.ArgumentParser)
207+
:param return_remaining_strings: If true, also return a list of
208+
remaining argument strings.
209+
:param look_for_args_file: If true, will look for a ".args" file with
210+
the same base name as the entry point script for this process,
211+
and will append its potential content to the command line args.
212+
:param args_filename: If not None, will uses this file instead of the
213+
".args" file specified in the previous argument.
214+
:returns: Tuple consisting of:
215+
- the dataclass instances in the same order as they were
216+
passed to the initializer.abspath
217+
- if applicable, an additional namespace for more
218+
(non-dataclass backed) arguments added to the parser
219+
after initialization.
220+
- The potential list of remaining argument strings.
221+
(same as argparse.ArgumentParser.parse_known_args)
222+
"""
223+
if args_filename or (look_for_args_file and len(sys.argv)):
224+
if args_filename:
225+
args_file = Path(args_filename)
226+
else:
227+
args_file = Path(sys.argv[0]).with_suffix(".args")
228+
229+
if args_file.exists():
230+
fargs = args_file.read_text().split()
231+
args = fargs + args if args is not None else fargs + sys.argv[1:]
232+
# in case of duplicate arguments the first one has precedence
233+
# so we append rather than prepend.
234+
namespace, remaining_args = self.parse_known_args(args=args)
235+
outputs = []
236+
for dtype in self.dataclass_types:
237+
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
238+
inputs = {k: v for k, v in vars(namespace).items() if k in keys}
239+
for k in keys:
240+
delattr(namespace, k)
241+
obj = dtype(**inputs)
242+
outputs.append(obj)
243+
if len(namespace.__dict__) > 0:
244+
# additional namespace.
245+
outputs.append(namespace)
246+
if return_remaining_strings:
247+
return (*outputs, remaining_args)
248+
else:
249+
if remaining_args:
250+
raise ValueError(
251+
f"Some specified arguments are not used by the "
252+
f"NmArgumentParser: {remaining_args}"
253+
)
254+
255+
return (*outputs,)
256+
257+
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
258+
"""
259+
Alternative helper method that does not use `argparse` at all,
260+
instead loading a json file and populating the dataclass types.
261+
"""
262+
data = json.loads(Path(json_file).read_text())
263+
outputs = []
264+
for dtype in self.dataclass_types:
265+
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
266+
inputs = {k: v for k, v in data.items() if k in keys}
267+
obj = dtype(**inputs)
268+
outputs.append(obj)
269+
return (*outputs,)
270+
271+
def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
272+
"""
273+
Alternative helper method that does not use `argparse` at all,
274+
instead uses a dict and populating the dataclass types.
275+
"""
276+
outputs = []
277+
for dtype in self.dataclass_types:
278+
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
279+
inputs = {k: v for k, v in args.items() if k in keys}
280+
obj = dtype(**inputs)
281+
outputs.append(obj)
282+
return (*outputs,)

0 commit comments

Comments
 (0)