Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 004ab81

Browse files
authored
Support for simple arithmetic in recipes (#361)
* Support for metadata and simple arithmatic in recipes * allow variables to be defined anywhere on top layer of recipe, hard check for equations wrapped by * cleanup * restricted eval function, clean up recipe eval, tests * remove formal metadata
1 parent 07379a6 commit 004ab81

File tree

7 files changed

+514
-0
lines changed

7 files changed

+514
-0
lines changed

src/sparseml/optim/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# flake8: noqa
2121

2222
from .analyzer import *
23+
from .helpers import *
2324
from .learning_rate import *
2425
from .manager import *
2526
from .modifier import *

src/sparseml/optim/helpers.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Helper functions for base Modifier and Manger utilities
17+
"""
18+
19+
20+
import re
21+
from typing import Any, Dict, Tuple, Union
22+
23+
import yaml
24+
25+
from sparseml.utils import UnknownVariableException, restricted_eval
26+
27+
28+
__all__ = [
29+
"load_recipe_yaml_str_no_classes",
30+
"rewrite_recipe_yaml_string_with_classes",
31+
"evaluate_recipe_yaml_str_equations",
32+
]
33+
34+
35+
def load_recipe_yaml_str_no_classes(recipe_yaml_str: str) -> str:
36+
"""
37+
:param recipe_yaml_str: YAML string of a SparseML recipe
38+
:return: recipe loaded into YAML with all objects replaced
39+
as a dictionary of their parameters
40+
"""
41+
pattern = re.compile(r"!(?P<class_name>(?!.*\.)[a-zA-Z_][a-zA-Z^._0-9]+)")
42+
classless_yaml_str = pattern.sub(r"OBJECT.\g<class_name>:", recipe_yaml_str)
43+
return yaml.safe_load(classless_yaml_str)
44+
45+
46+
def rewrite_recipe_yaml_string_with_classes(recipe_contianer: Any) -> str:
47+
"""
48+
:param recipe_contianer: recipe loaded as yaml with load_recipe_yaml_str_no_classes
49+
:return: recipe serialized into YAML with original class values re-added
50+
"""
51+
updated_yaml_str = yaml.dump(recipe_contianer)
52+
53+
# convert object dicts back to object declarations and return
54+
pattern = re.compile(r"OBJECT\.(?P<class_name>(?!.*\.)[a-zA-Z_][a-zA-Z^._0-9]+):")
55+
return pattern.sub(r"!\g<class_name>", updated_yaml_str)
56+
57+
58+
def evaluate_recipe_yaml_str_equations(recipe_yaml_str: str) -> str:
59+
"""
60+
:param recipe_yaml_str: YAML string of a SparseML recipe
61+
:return: the YAML string with any expressions based on valid
62+
metadata and recipe variables and operations
63+
"""
64+
container = load_recipe_yaml_str_no_classes(recipe_yaml_str)
65+
if not isinstance(container, dict):
66+
# yaml string does not create a dict, return original string
67+
return recipe_yaml_str
68+
69+
# validate and load remaining variables
70+
container, variables = _evaluate_recipe_variables(container)
71+
72+
# update values nested in modifier lists based on the variables
73+
for key, val in container.items():
74+
if "modifiers" not in key:
75+
continue
76+
container[key] = _maybe_evaluate_yaml_object(val, variables)
77+
78+
return rewrite_recipe_yaml_string_with_classes(container)
79+
80+
81+
def is_eval_string(val: str) -> bool:
82+
return val.startswith("eval(") and val.endswith(")")
83+
84+
85+
def _maybe_evaluate_recipe_equation(
86+
val: str,
87+
variables: Dict[str, Union[int, float]],
88+
) -> Union[str, float, int]:
89+
if is_eval_string(val):
90+
is_eval_str = True
91+
val = val[5:-1]
92+
else:
93+
return val
94+
95+
evaluated_val = restricted_eval(val, variables)
96+
97+
if is_eval_str and not isinstance(evaluated_val, (int, float)):
98+
raise RuntimeError(
99+
"eval expressions in recipes must evaluate to a float or int"
100+
)
101+
102+
return evaluated_val
103+
104+
105+
def _evaluate_recipe_variables(
106+
recipe_dict: Dict[str, Any],
107+
) -> Tuple[Dict[str, Any], Dict[str, Union[int, float]]]:
108+
valid_variables = {}
109+
prev_num_variables = -1
110+
111+
while prev_num_variables != len(valid_variables):
112+
prev_num_variables = len(valid_variables)
113+
114+
for name, val in recipe_dict.items():
115+
if name in valid_variables:
116+
continue
117+
118+
if isinstance(val, (int, float)):
119+
valid_variables[name] = val
120+
121+
if not isinstance(val, str):
122+
# only parse string values
123+
continue
124+
125+
try:
126+
val = _maybe_evaluate_recipe_equation(val, valid_variables)
127+
except UnknownVariableException:
128+
# dependant variables maybe not evaluated yet
129+
continue
130+
131+
if isinstance(val, (int, float)):
132+
# update variable value and add to valid vars
133+
recipe_dict[name] = val
134+
valid_variables[name] = val
135+
136+
# check that all eval statements have been evaluated
137+
for name, val in recipe_dict.items():
138+
if isinstance(val, str) and is_eval_string(val):
139+
raise RuntimeError(
140+
f"Unable to evaluate expression: {val}. Check if any dependent "
141+
"variables form a cycle or are not defined"
142+
)
143+
144+
return recipe_dict, valid_variables
145+
146+
147+
def _maybe_evaluate_yaml_object(
148+
obj: Any, variables: Dict[str, Union[int, float]]
149+
) -> Any:
150+
151+
if isinstance(obj, str):
152+
return _maybe_evaluate_recipe_equation(obj, variables)
153+
elif isinstance(obj, list):
154+
return [_maybe_evaluate_yaml_object(val, variables) for val in obj]
155+
elif isinstance(obj, dict):
156+
return {
157+
key: _maybe_evaluate_yaml_object(val, variables) for key, val in obj.items()
158+
}
159+
else:
160+
return obj
161+
162+
163+
def _maybe_parse_number(val: str) -> Union[str, float, int]:
164+
try:
165+
return int(val)
166+
except Exception:
167+
try:
168+
return float(val)
169+
except Exception:
170+
return val

src/sparseml/optim/modifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import yaml
2626

27+
from sparseml.optim.helpers import evaluate_recipe_yaml_str_equations
2728
from sparseml.utils import validate_str_iterable
2829

2930

@@ -288,6 +289,7 @@ def load_framework_list(yaml_str: str, framework: str):
288289
:param framework: the framework to load the modifiers for
289290
:return: the loaded modifiers list
290291
"""
292+
yaml_str = evaluate_recipe_yaml_str_equations(yaml_str)
291293
yaml_str = BaseModifier._convert_to_framework_modifiers(yaml_str, framework)
292294
container = yaml.safe_load(yaml_str)
293295

src/sparseml/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from .frameworks import *
2222
from .helpers import *
23+
from .restricted_eval import *
2324
from .singleton import *
2425
from .worker import *
2526
from .wrapper import *
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Restricted eval function for safely evaluating equations in recipes
17+
"""
18+
19+
20+
import ast
21+
import operator
22+
from typing import Any, Dict, Optional
23+
24+
25+
__all__ = [
26+
"restricted_eval",
27+
"UnknownVariableException",
28+
]
29+
30+
31+
class UnknownVariableException(Exception):
32+
"""
33+
Exception raised for known variable names in restricted eval
34+
35+
:param var_name: name of unknown variable
36+
"""
37+
38+
def __init__(self, var_name: str):
39+
self.var_name = var_name
40+
super().__init__(f"Unknown variable name in eval: {var_name}")
41+
42+
43+
def restricted_eval(
44+
expression: str,
45+
variables: Optional[Dict[str, float]] = None,
46+
) -> float:
47+
"""
48+
:param expression: expression to evaluate
49+
:param variables: dictionary of string variables to float values that may be
50+
included in the expression
51+
:return: evaluated expression. Only supported operations, numbers, and float
52+
variables named in the variables dict may be included
53+
:raises: RuntimeError if any unsupported operations are included,
54+
UnknownVariableException if any variables not included in the variables dict
55+
are given
56+
"""
57+
variables = variables or {}
58+
return _restricted_eval_node(ast.parse(expression.strip()).body[0], variables)
59+
60+
61+
_VALID_BINOPS_TO_EVAL = {
62+
ast.Add: operator.add,
63+
ast.Sub: operator.sub,
64+
ast.Mult: operator.mul,
65+
ast.Div: operator.truediv,
66+
ast.FloorDiv: operator.floordiv,
67+
ast.Pow: operator.pow,
68+
ast.Mod: operator.mod,
69+
}
70+
71+
_VALID_UOPS_TO_EVAL = {ast.USub: operator.neg}
72+
73+
_VALID_FUNCTIONS_TO_EVAL = {
74+
"abs": abs,
75+
"float": float,
76+
"int": int,
77+
"min": min,
78+
"max": max,
79+
"round": round,
80+
}
81+
82+
83+
def _restricted_eval_node(node: Any, variables: Dict[str, float]) -> float:
84+
if isinstance(node, ast.Expr):
85+
return _restricted_eval_node(node.value, variables)
86+
if isinstance(node, ast.Num):
87+
return node.n
88+
if isinstance(node, ast.Name):
89+
if node.id in variables:
90+
return variables[node.id]
91+
else:
92+
raise UnknownVariableException(node.id)
93+
if isinstance(node, ast.BinOp):
94+
op_type = type(node.op)
95+
if op_type in _VALID_BINOPS_TO_EVAL:
96+
return _VALID_BINOPS_TO_EVAL[op_type](
97+
_restricted_eval_node(node.left, variables),
98+
_restricted_eval_node(node.right, variables),
99+
)
100+
else:
101+
raise RuntimeError(f"Unsupported binary operator type {op_type}")
102+
if isinstance(node, ast.UnaryOp):
103+
op_type = type(node.op)
104+
if op_type in _VALID_UOPS_TO_EVAL:
105+
return _VALID_UOPS_TO_EVAL[op_type](
106+
_restricted_eval_node(node.left, variables),
107+
)
108+
else:
109+
raise RuntimeError(f"Unsupported binary operator type {op_type}")
110+
if isinstance(node, ast.Call):
111+
func_name = node.func.id
112+
if func_name in _VALID_FUNCTIONS_TO_EVAL:
113+
args = [_restricted_eval_node(arg, variables) for arg in node.args]
114+
kwargs = {
115+
kwarg.arg: _restricted_eval_node(kwarg.value, variables)
116+
for kwarg in node.keywords
117+
}
118+
return _VALID_FUNCTIONS_TO_EVAL[func_name](*args, **kwargs)
119+
else:
120+
raise RuntimeError(f"Unsupported function name {func_name}")
121+
122+
raise RuntimeError(f"Unsupported AST node type {type(node)}")

0 commit comments

Comments
 (0)