Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit dbad7de

Browse files
authored
recipe metadata updates (#394)
* metadata update path * updates from review
1 parent 69a7bd2 commit dbad7de

File tree

7 files changed

+163
-104
lines changed

7 files changed

+163
-104
lines changed

src/sparseml/keras/optim/manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
from sparseml.keras.optim.modifier import Modifier, ScheduledModifier
2626
from sparseml.keras.utils.compat import keras
2727
from sparseml.keras.utils.logger import KerasLogger
28-
from sparseml.optim import BaseManager
29-
from sparseml.utils import load_recipe_yaml_str
28+
from sparseml.optim import BaseManager, load_recipe_yaml_str
3029
from sparsezoo.objects import Recipe
3130

3231

@@ -42,6 +41,7 @@ class ScheduledModifierManager(BaseManager, Modifier):
4241
def from_yaml(
4342
file_path: Union[str, Recipe],
4443
add_modifiers: List[Modifier] = None,
44+
**recipe_variables,
4545
):
4646
"""
4747
Convenience function used to create the manager of multiple modifiers from a
@@ -55,9 +55,11 @@ def from_yaml(
5555
'zoo:model/stub/path?recipe_type=transfer'
5656
:param add_modifiers: additional modifiers that should be added to the
5757
returned manager alongside the ones loaded from the recipe file
58+
:param recipe_variables: additional variable values to override the recipe
59+
with (i.e. num_epochs, init_lr)
5860
:return: ScheduledModifierManager() created from the recipe file
5961
"""
60-
yaml_str = load_recipe_yaml_str(file_path)
62+
yaml_str = load_recipe_yaml_str(file_path, **recipe_variables)
6163
modifiers = Modifier.load_list(yaml_str)
6264
if add_modifiers:
6365
modifiers.extend(add_modifiers)

src/sparseml/optim/helpers.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,103 @@
1717
"""
1818

1919

20+
import logging
2021
import re
2122
from typing import Any, Dict, Tuple, Union
2223

2324
import yaml
2425

2526
from sparseml.utils import UnknownVariableException, restricted_eval
27+
from sparsezoo import Zoo
28+
from sparsezoo.objects import Recipe
2629

2730

2831
__all__ = [
32+
"load_recipe_yaml_str",
2933
"load_recipe_yaml_str_no_classes",
3034
"rewrite_recipe_yaml_string_with_classes",
35+
"update_recipe_variables",
3136
"evaluate_recipe_yaml_str_equations",
3237
]
3338

39+
_LOGGER = logging.getLogger(__name__)
40+
41+
42+
def load_recipe_yaml_str(
43+
file_path: Union[str, Recipe],
44+
**variable_overrides,
45+
) -> str:
46+
"""
47+
Loads a YAML recipe file to a string or
48+
extracts recipe from YAML front matter in a sparsezoo markdown recipe card.
49+
Recipes can also be provided as SparseZoo model stubs or Recipe
50+
objects.
51+
52+
YAML front matter: https://jekyllrb.com/docs/front-matter/
53+
54+
:param file_path: file path to recipe YAML file or markdown recipe card or
55+
stub to a SparseZoo model whose recipe will be downloaded and loaded.
56+
SparseZoo stubs should be preceded by 'zoo:', and can contain an optional
57+
'?recipe_type=<type>' parameter or include a `/<type>` subpath. Can also
58+
be a SparseZoo Recipe object. i.e. '/path/to/local/recipe.yaml',
59+
'zoo:model/stub/path', 'zoo:model/stub/path?recipe_type=transfer_learn',
60+
'zoo:model/stub/path/transfer_learn'
61+
:param variable_overrides: dict of variable values to replace
62+
in the loaded yaml string. Default is None
63+
:return: the recipe YAML configuration loaded as a string
64+
"""
65+
if isinstance(file_path, Recipe):
66+
# download and unwrap Recipe object
67+
file_path = file_path.downloaded_path()
68+
69+
if not isinstance(file_path, str):
70+
raise ValueError(f"file_path must be a str, given {type(file_path)}")
71+
72+
if file_path.startswith("zoo:"):
73+
# download from zoo stub
74+
recipe = Zoo.download_recipe_from_stub(file_path)
75+
file_path = recipe.downloaded_path()
76+
77+
# load the yaml string
78+
if "\n" in file_path or "\r" in file_path:
79+
# treat as raw yaml passed in
80+
yaml_str = file_path
81+
extension = "unknown"
82+
else:
83+
# load yaml from file_path
84+
extension = file_path.lower().split(".")[-1]
85+
if extension not in ["md", "yaml"]:
86+
raise ValueError(
87+
"Unsupported file extension for recipe. Excepted '.md' or '.yaml'. "
88+
f"Received {file_path}"
89+
)
90+
with open(file_path, "r") as yaml_file:
91+
yaml_str = yaml_file.read()
92+
93+
if extension == "md" or extension == "unknown":
94+
# extract YAML front matter from markdown recipe card
95+
# adapted from
96+
# https://github.com/jonbeebe/frontmatter/blob/master/frontmatter
97+
yaml_delim = r"(?:---|\+\+\+)"
98+
yaml = r"(.*?)"
99+
re_pattern = r"^\s*" + yaml_delim + yaml + yaml_delim
100+
regex = re.compile(re_pattern, re.S | re.M)
101+
result = regex.search(yaml_str)
102+
103+
if result:
104+
yaml_str = result.group(1)
105+
elif extension == "md":
106+
# fail if we know whe should have extracted front matter out
107+
raise RuntimeError(
108+
"Could not extract YAML front matter from recipe card:"
109+
" {}".format(file_path)
110+
)
111+
112+
if variable_overrides:
113+
update_recipe_variables(yaml_str, variable_overrides)
114+
115+
return yaml_str
116+
34117

35118
def load_recipe_yaml_str_no_classes(recipe_yaml_str: str) -> str:
36119
"""
@@ -57,6 +140,30 @@ def rewrite_recipe_yaml_string_with_classes(recipe_contianer: Any) -> str:
57140
return pattern.sub(r"!\g<class_name>", updated_yaml_str)
58141

59142

143+
def update_recipe_variables(recipe_yaml_str: str, variables: Dict[str, Any]) -> str:
144+
"""
145+
:param recipe_yaml_str: YAML string of a SparseML recipe
146+
:param variables: variables dictionary to update recipe top level variables with
147+
:return: given recipe with variables updated
148+
"""
149+
150+
container = load_recipe_yaml_str_no_classes(recipe_yaml_str)
151+
if not isinstance(container, dict):
152+
# yaml string does not create a dict, return original string
153+
return recipe_yaml_str
154+
155+
for key in variables:
156+
if key not in container:
157+
_LOGGER.warning(
158+
f"updating recipe variable {key} but {key} is not currently "
159+
"set in existing recipe. This change may have no impact on the recipe "
160+
"modifiers"
161+
)
162+
163+
container.update(variables)
164+
return rewrite_recipe_yaml_string_with_classes(container)
165+
166+
60167
def evaluate_recipe_yaml_str_equations(recipe_yaml_str: str) -> str:
61168
"""
62169
:param recipe_yaml_str: YAML string of a SparseML recipe

src/sparseml/pytorch/optim/manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@
2424
from torch.nn import Module
2525
from torch.optim.optimizer import Optimizer
2626

27-
from sparseml.optim import BaseManager
27+
from sparseml.optim import BaseManager, load_recipe_yaml_str
2828
from sparseml.pytorch.optim.modifier import Modifier, ScheduledModifier
2929
from sparseml.pytorch.utils import BaseLogger
30-
from sparseml.utils import load_recipe_yaml_str
3130
from sparsezoo.objects import Recipe
3231

3332

@@ -248,6 +247,7 @@ class ScheduledModifierManager(BaseManager, Modifier):
248247
def from_yaml(
249248
file_path: Union[str, Recipe],
250249
add_modifiers: List[Modifier] = None,
250+
**recipe_variables,
251251
):
252252
"""
253253
Convenience function used to create the manager of multiple modifiers from a
@@ -261,9 +261,11 @@ def from_yaml(
261261
'zoo:model/stub/path?recipe_type=transfer'
262262
:param add_modifiers: additional modifiers that should be added to the
263263
returned manager alongside the ones loaded from the recipe file
264+
:param recipe_variables: additional variable values to override the recipe
265+
with (i.e. num_epochs, init_lr)
264266
:return: ScheduledModifierManager() created from the recipe file
265267
"""
266-
yaml_str = load_recipe_yaml_str(file_path)
268+
yaml_str = load_recipe_yaml_str(file_path, **recipe_variables)
267269
modifiers = Modifier.load_list(yaml_str)
268270

269271
if add_modifiers:

src/sparseml/tensorflow_v1/optim/manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@
2121
import itertools
2222
from typing import Any, Callable, Dict, List, Tuple, Union
2323

24-
from sparseml.optim import BaseManager, BaseScheduled
24+
from sparseml.optim import BaseManager, BaseScheduled, load_recipe_yaml_str
2525
from sparseml.tensorflow_v1.optim.modifier import NM_RECAL, Modifier, ScheduledModifier
2626
from sparseml.tensorflow_v1.utils import tf_compat
27-
from sparseml.utils import load_recipe_yaml_str
2827
from sparsezoo.objects import Recipe
2928

3029

@@ -78,6 +77,7 @@ class ScheduledModifierManager(BaseManager, Modifier):
7877
def from_yaml(
7978
file_path: Union[str, Recipe],
8079
add_modifiers: List[Modifier] = None,
80+
**recipe_variables,
8181
):
8282
"""
8383
Convenience function used to create the manager of multiple modifiers from a
@@ -91,9 +91,11 @@ def from_yaml(
9191
'zoo:model/stub/path?recipe_type=transfer'
9292
:param add_modifiers: additional modifiers that should be added to the
9393
returned manager alongside the ones loaded from the recipe file
94+
:param recipe_variables: additional variable values to override the recipe
95+
with (i.e. num_epochs, init_lr)
9496
:return: ScheduledModifierManager() created from the recipe file
9597
"""
96-
yaml_str = load_recipe_yaml_str(file_path)
98+
yaml_str = load_recipe_yaml_str(file_path, **recipe_variables)
9799
modifiers = Modifier.load_list(yaml_str)
98100
if add_modifiers:
99101
modifiers.extend(add_modifiers)

src/sparseml/utils/helpers.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,13 @@
2121
import fnmatch
2222
import logging
2323
import os
24-
import re
2524
import sys
2625
from collections import OrderedDict
2726
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
2827
from urllib.parse import urlparse
2928

3029
import numpy
3130

32-
from sparsezoo import Zoo
33-
from sparsezoo.objects import Recipe
3431
from sparsezoo.utils import load_numpy_list
3532

3633

@@ -59,7 +56,6 @@
5956
"NumpyArrayBatcher",
6057
"tensor_export",
6158
"tensors_export",
62-
"load_recipe_yaml_str",
6359
"parse_optimization_str",
6460
]
6561

@@ -768,74 +764,6 @@ def _tensors_export_batch(
768764
)
769765

770766

771-
def load_recipe_yaml_str(file_path: Union[str, Recipe]) -> str:
772-
"""
773-
Loads a YAML recipe file to a string or
774-
extracts recipe from YAML front matter in a sparsezoo markdown recipe card.
775-
Recipes can also be provided as SparseZoo model stubs or Recipe
776-
objects.
777-
778-
YAML front matter: https://jekyllrb.com/docs/front-matter/
779-
780-
:param file_path: file path to recipe YAML file or markdown recipe card or
781-
stub to a SparseZoo model whose recipe will be downloaded and loaded.
782-
SparseZoo stubs should be preceded by 'zoo:', and can contain an optional
783-
'?recipe_type=<type>' parameter or include a `/<type>` subpath. Can also
784-
be a SparseZoo Recipe object. i.e. '/path/to/local/recipe.yaml',
785-
'zoo:model/stub/path', 'zoo:model/stub/path?recipe_type=transfer_learn',
786-
'zoo:model/stub/path/transfer_learn'
787-
:return: the recipe YAML configuration loaded as a string
788-
"""
789-
if isinstance(file_path, Recipe):
790-
# download and unwrap Recipe object
791-
file_path = file_path.downloaded_path()
792-
793-
if not isinstance(file_path, str):
794-
raise ValueError(f"file_path must be a str, given {type(file_path)}")
795-
796-
if file_path.startswith("zoo:"):
797-
# download from zoo stub
798-
recipe = Zoo.download_recipe_from_stub(file_path)
799-
file_path = recipe.downloaded_path()
800-
801-
# load the yaml string
802-
if "\n" in file_path or "\r" in file_path:
803-
# treat as raw yaml passed in
804-
yaml_str = file_path
805-
extension = "unknown"
806-
else:
807-
# load yaml from file_path
808-
extension = file_path.lower().split(".")[-1]
809-
if extension not in ["md", "yaml"]:
810-
raise ValueError(
811-
"Unsupported file extension for recipe. Excepted '.md' or '.yaml'. "
812-
"Received {}".format(file_path)
813-
)
814-
with open(file_path, "r") as yaml_file:
815-
yaml_str = yaml_file.read()
816-
817-
if extension == "md" or extension == "unknown":
818-
# extract YAML front matter from markdown recipe card
819-
# adapted from
820-
# https://github.com/jonbeebe/frontmatter/blob/master/frontmatter
821-
yaml_delim = r"(?:---|\+\+\+)"
822-
yaml = r"(.*?)"
823-
re_pattern = r"^\s*" + yaml_delim + yaml + yaml_delim
824-
regex = re.compile(re_pattern, re.S | re.M)
825-
result = regex.search(yaml_str)
826-
827-
if result:
828-
yaml_str = result.group(1)
829-
elif extension == "md":
830-
# fail if we know whe should have extracted front matter out
831-
raise RuntimeError(
832-
"Could not extract YAML front matter from recipe card:"
833-
" {}".format(file_path)
834-
)
835-
836-
return yaml_str
837-
838-
839767
def parse_optimization_str(optim_full_name: str) -> Tuple[str, str, Any]:
840768
"""
841769
:param optim_full_name: A name of a pretrained model optimization. i.e.

0 commit comments

Comments
 (0)