1515"""
1616Helper functions for base Modifier and Manger utilities
1717"""
18-
1918import json
2019import re
2120from contextlib import suppress
3534 "update_recipe_variables" ,
3635 "evaluate_recipe_yaml_str_equations" ,
3736 "parse_recipe_variables" ,
37+ "check_if_staged_recipe" ,
3838]
3939
4040
@@ -230,18 +230,100 @@ def evaluate_recipe_yaml_str_equations(recipe_yaml_str: str) -> str:
230230 # yaml string does not create a dict, return original string
231231 return recipe_yaml_str
232232
233- # validate and load remaining variables
234- container , variables , non_val_variables = _evaluate_recipe_variables (container )
233+ # check whether the recipe is a stage recipe of not
234+ if check_if_staged_recipe (container ):
235+ container = _evaluate_staged_recipe_yaml_str_equations (container )
236+
237+ else :
238+ container , variables , non_val_variables = _evaluate_container_variables (
239+ container
240+ )
235241
236- # update values nested in modifier lists based on the variables
237- for key , val in container .items ():
238- if "modifiers" not in key :
239- continue
240- container [key ] = _maybe_evaluate_yaml_object (val , variables , non_val_variables )
242+ # update values nested in modifier lists based on the variables
243+ for key , val in container .items ():
244+ if "modifiers" not in key :
245+ continue
246+ container [key ] = _maybe_evaluate_yaml_object (
247+ val , variables , non_val_variables
248+ )
241249
242250 return rewrite_recipe_yaml_string_with_classes (container )
243251
244252
253+ def check_if_staged_recipe (container : dict ) -> bool :
254+ """
255+ Check whether container pertains to a staged recipe.
256+ Such a "staged container" fulfills two conditions:
257+ - no top level key in container contains "modifiers" in its name
258+ - a stage should map to a dict that has at least one key with
259+ "modifiers" in its name
260+ :param container: a container generated from a YAML string of SparseML recipe
261+ :return: True if stage recipe, False if normal recipe
262+ """
263+ for k , v in container .items ():
264+ if isinstance (v , dict ):
265+ if any ([key for key in v .keys () if "modifiers" in key ]):
266+ return True
267+ return False
268+
269+
270+ def _evaluate_staged_recipe_yaml_str_equations (container : dict ) -> dict :
271+ """
272+ Consumes a staged container and transforms it into a valid
273+ container for the manager and modifiers to consume further.
274+
275+ :param container: a staged container generated from a staged recipe.
276+ :return: transformed container containing evaluated
277+ variables, operations and objects.
278+ """
279+ main_container = {}
280+ for k , v in container .items ():
281+ if isinstance (v , dict ):
282+ if any ([key for key in v .keys () if "modifiers" in key ]):
283+ continue
284+ main_container .update ({k : v })
285+
286+ stages = {k : container [k ] for k in set (container ) - set (main_container )}
287+
288+ (
289+ main_container ,
290+ global_variables ,
291+ global_non_val_variables ,
292+ ) = _evaluate_container_variables (main_container )
293+
294+ for stage_name , staged_container in stages .items ():
295+ stage_container , variables , non_val_variables = _evaluate_container_variables (
296+ staged_container , main_container
297+ )
298+
299+ """
300+ if same variable is both in global_variables and variables, the
301+ global_variable will get overwritten.
302+ """
303+ _global_variables = {
304+ k : v for k , v in global_variables .items () if k not in variables .keys ()
305+ }
306+ variables = {** variables , ** _global_variables }
307+
308+ _global_non_val_variables = {
309+ k : v
310+ for k , v in global_non_val_variables .items ()
311+ if k not in non_val_variables .keys ()
312+ }
313+ non_val_variables = {** non_val_variables , ** _global_non_val_variables }
314+
315+ for key , val in staged_container .items ():
316+ if "modifiers" not in key :
317+ continue
318+ stage_container [key ] = _maybe_evaluate_yaml_object (
319+ val , variables , non_val_variables
320+ )
321+
322+ container [stage_name ] = staged_container
323+
324+ return container
325+
326+
245327def is_eval_string (val : str ) -> bool :
246328 return val .startswith ("eval(" ) and val .endswith (")" )
247329
@@ -256,6 +338,7 @@ def _maybe_evaluate_recipe_equation(
256338 val : str ,
257339 variables : Dict [str , Union [int , float ]],
258340 non_eval_variables : Dict [str , Any ],
341+ global_container : Optional [Dict [str , Any ]] = {},
259342) -> Union [str , float , int ]:
260343 if is_eval_string (val ):
261344 is_eval_str = True
@@ -266,6 +349,9 @@ def _maybe_evaluate_recipe_equation(
266349 if val in non_eval_variables :
267350 return non_eval_variables [val ]
268351
352+ if val in global_container :
353+ return global_container [val ]
354+
269355 evaluated_val = restricted_eval (val , variables )
270356
271357 if is_eval_str and not isinstance (evaluated_val , (int , float )):
@@ -276,8 +362,8 @@ def _maybe_evaluate_recipe_equation(
276362 return evaluated_val
277363
278364
279- def _evaluate_recipe_variables (
280- recipe_dict : Dict [str , Any ],
365+ def _evaluate_container_variables (
366+ recipe_container : Dict [str , Any ], global_container : Optional [ Dict [ str , Any ]] = {}
281367) -> Tuple [Dict [str , Any ], Dict [str , Union [int , float ]]]:
282368 valid_variables = {}
283369 non_evaluatable_variables = {}
@@ -286,7 +372,7 @@ def _evaluate_recipe_variables(
286372 while prev_num_variables != len (valid_variables ):
287373 prev_num_variables = len (valid_variables )
288374
289- for name , val in recipe_dict .items ():
375+ for name , val in recipe_container .items ():
290376 if name in valid_variables :
291377 continue
292378
@@ -301,26 +387,26 @@ def _evaluate_recipe_variables(
301387
302388 try :
303389 val = _maybe_evaluate_recipe_equation (
304- val , valid_variables , non_evaluatable_variables
390+ val , valid_variables , non_evaluatable_variables , global_container
305391 )
306392 except UnknownVariableException :
307393 # dependant variables maybe not evaluated yet
308394 continue
309395
310396 if isinstance (val , (int , float )):
311397 # update variable value and add to valid vars
312- recipe_dict [name ] = val
398+ recipe_container [name ] = val
313399 valid_variables [name ] = val
314400
315401 # check that all eval statements have been evaluated
316- for name , val in recipe_dict .items ():
402+ for name , val in recipe_container .items ():
317403 if isinstance (val , str ) and is_eval_string (val ):
318404 raise RuntimeError (
319405 f"Unable to evaluate expression: { val } . Check if any dependent "
320406 "variables form a cycle or are not defined"
321407 )
322408
323- return recipe_dict , valid_variables , non_evaluatable_variables
409+ return recipe_container , valid_variables , non_evaluatable_variables
324410
325411
326412def _maybe_evaluate_yaml_object (
0 commit comments