Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 35eab64

Browse files
dbogunowiczbogunowicz@arrival.comBenjamin
authored andcommitted
Ability to load recipe stages from YAML into dictionary of name -> List[modifiers] (#556)
* first proposal of the solution * add docstring * stages support for Manager lifecycle * load staged recipe yaml string into list of dicts * modifiers properly read sample staged recipe * modifiers properly read sample staged recipe * edit after Ben's comments * tiny docstring fix * Removed stage_recipy.md from repository * Add fixes according to Ben's commments * Nit: remove print functions * Give priority to local variables over global when conflicting * Allow information from main_container flow to the evaluation of stage_containers Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com> Co-authored-by: Benjamin <ben@neuralmagic.com>
1 parent a87f0b7 commit 35eab64

File tree

2 files changed

+316
-22
lines changed

2 files changed

+316
-22
lines changed

src/sparseml/optim/helpers.py

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""
1616
Helper functions for base Modifier and Manger utilities
1717
"""
18-
1918
import json
2019
import re
2120
from contextlib import suppress
@@ -35,6 +34,7 @@
3534
"update_recipe_variables",
3635
"evaluate_recipe_yaml_str_equations",
3736
"parse_recipe_variables",
37+
"check_if_staged_recipe",
3838
]
3939

4040

@@ -230,18 +230,100 @@ def evaluate_recipe_yaml_str_equations(recipe_yaml_str: str) -> str:
230230
# yaml string does not create a dict, return original string
231231
return recipe_yaml_str
232232

233-
# validate and load remaining variables
234-
container, variables, non_val_variables = _evaluate_recipe_variables(container)
233+
# check whether the recipe is a stage recipe of not
234+
if check_if_staged_recipe(container):
235+
container = _evaluate_staged_recipe_yaml_str_equations(container)
236+
237+
else:
238+
container, variables, non_val_variables = _evaluate_container_variables(
239+
container
240+
)
235241

236-
# update values nested in modifier lists based on the variables
237-
for key, val in container.items():
238-
if "modifiers" not in key:
239-
continue
240-
container[key] = _maybe_evaluate_yaml_object(val, variables, non_val_variables)
242+
# update values nested in modifier lists based on the variables
243+
for key, val in container.items():
244+
if "modifiers" not in key:
245+
continue
246+
container[key] = _maybe_evaluate_yaml_object(
247+
val, variables, non_val_variables
248+
)
241249

242250
return rewrite_recipe_yaml_string_with_classes(container)
243251

244252

253+
def check_if_staged_recipe(container: dict) -> bool:
254+
"""
255+
Check whether container pertains to a staged recipe.
256+
Such a "staged container" fulfills two conditions:
257+
- no top level key in container contains "modifiers" in its name
258+
- a stage should map to a dict that has at least one key with
259+
"modifiers" in its name
260+
:param container: a container generated from a YAML string of SparseML recipe
261+
:return: True if stage recipe, False if normal recipe
262+
"""
263+
for k, v in container.items():
264+
if isinstance(v, dict):
265+
if any([key for key in v.keys() if "modifiers" in key]):
266+
return True
267+
return False
268+
269+
270+
def _evaluate_staged_recipe_yaml_str_equations(container: dict) -> dict:
271+
"""
272+
Consumes a staged container and transforms it into a valid
273+
container for the manager and modifiers to consume further.
274+
275+
:param container: a staged container generated from a staged recipe.
276+
:return: transformed container containing evaluated
277+
variables, operations and objects.
278+
"""
279+
main_container = {}
280+
for k, v in container.items():
281+
if isinstance(v, dict):
282+
if any([key for key in v.keys() if "modifiers" in key]):
283+
continue
284+
main_container.update({k: v})
285+
286+
stages = {k: container[k] for k in set(container) - set(main_container)}
287+
288+
(
289+
main_container,
290+
global_variables,
291+
global_non_val_variables,
292+
) = _evaluate_container_variables(main_container)
293+
294+
for stage_name, staged_container in stages.items():
295+
stage_container, variables, non_val_variables = _evaluate_container_variables(
296+
staged_container, main_container
297+
)
298+
299+
"""
300+
if same variable is both in global_variables and variables, the
301+
global_variable will get overwritten.
302+
"""
303+
_global_variables = {
304+
k: v for k, v in global_variables.items() if k not in variables.keys()
305+
}
306+
variables = {**variables, **_global_variables}
307+
308+
_global_non_val_variables = {
309+
k: v
310+
for k, v in global_non_val_variables.items()
311+
if k not in non_val_variables.keys()
312+
}
313+
non_val_variables = {**non_val_variables, **_global_non_val_variables}
314+
315+
for key, val in staged_container.items():
316+
if "modifiers" not in key:
317+
continue
318+
stage_container[key] = _maybe_evaluate_yaml_object(
319+
val, variables, non_val_variables
320+
)
321+
322+
container[stage_name] = staged_container
323+
324+
return container
325+
326+
245327
def is_eval_string(val: str) -> bool:
246328
return val.startswith("eval(") and val.endswith(")")
247329

@@ -256,6 +338,7 @@ def _maybe_evaluate_recipe_equation(
256338
val: str,
257339
variables: Dict[str, Union[int, float]],
258340
non_eval_variables: Dict[str, Any],
341+
global_container: Optional[Dict[str, Any]] = {},
259342
) -> Union[str, float, int]:
260343
if is_eval_string(val):
261344
is_eval_str = True
@@ -266,6 +349,9 @@ def _maybe_evaluate_recipe_equation(
266349
if val in non_eval_variables:
267350
return non_eval_variables[val]
268351

352+
if val in global_container:
353+
return global_container[val]
354+
269355
evaluated_val = restricted_eval(val, variables)
270356

271357
if is_eval_str and not isinstance(evaluated_val, (int, float)):
@@ -276,8 +362,8 @@ def _maybe_evaluate_recipe_equation(
276362
return evaluated_val
277363

278364

279-
def _evaluate_recipe_variables(
280-
recipe_dict: Dict[str, Any],
365+
def _evaluate_container_variables(
366+
recipe_container: Dict[str, Any], global_container: Optional[Dict[str, Any]] = {}
281367
) -> Tuple[Dict[str, Any], Dict[str, Union[int, float]]]:
282368
valid_variables = {}
283369
non_evaluatable_variables = {}
@@ -286,7 +372,7 @@ def _evaluate_recipe_variables(
286372
while prev_num_variables != len(valid_variables):
287373
prev_num_variables = len(valid_variables)
288374

289-
for name, val in recipe_dict.items():
375+
for name, val in recipe_container.items():
290376
if name in valid_variables:
291377
continue
292378

@@ -301,26 +387,26 @@ def _evaluate_recipe_variables(
301387

302388
try:
303389
val = _maybe_evaluate_recipe_equation(
304-
val, valid_variables, non_evaluatable_variables
390+
val, valid_variables, non_evaluatable_variables, global_container
305391
)
306392
except UnknownVariableException:
307393
# dependant variables maybe not evaluated yet
308394
continue
309395

310396
if isinstance(val, (int, float)):
311397
# update variable value and add to valid vars
312-
recipe_dict[name] = val
398+
recipe_container[name] = val
313399
valid_variables[name] = val
314400

315401
# check that all eval statements have been evaluated
316-
for name, val in recipe_dict.items():
402+
for name, val in recipe_container.items():
317403
if isinstance(val, str) and is_eval_string(val):
318404
raise RuntimeError(
319405
f"Unable to evaluate expression: {val}. Check if any dependent "
320406
"variables form a cycle or are not defined"
321407
)
322408

323-
return recipe_dict, valid_variables, non_evaluatable_variables
409+
return recipe_container, valid_variables, non_evaluatable_variables
324410

325411

326412
def _maybe_evaluate_yaml_object(

0 commit comments

Comments
 (0)