Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 2b24f6f

Browse files
authored
Base framework and sparsification implementation for phase 2 (#168)
* add in base methods for framework and sparsification interfaces * add unit tests and minor fixes/functionality * fix missing imports, add console scripts, fix quality checks * add indent to pretty print the json output
1 parent f8201c6 commit 2b24f6f

File tree

16 files changed

+1378
-12
lines changed

16 files changed

+1378
-12
lines changed

docs/conf.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"sphinx_copybutton",
5353
"sphinx_markdown_tables",
5454
"sphinx_multiversion",
55+
"sphinx-pydantic",
5556
"sphinx_rtd_theme",
5657
"recommonmark",
5758
]
@@ -60,19 +61,19 @@
6061
templates_path = ["_templates"]
6162

6263
# Whitelist pattern for tags (set to None to ignore all tags)
63-
smv_tag_whitelist = r'^v.*$'
64+
smv_tag_whitelist = r"^v.*$"
6465

6566
# Whitelist pattern for branches (set to None to ignore all branches)
66-
smv_branch_whitelist = r'^main$'
67+
smv_branch_whitelist = r"^main$"
6768

6869
# Whitelist pattern for remotes (set to None to use local branches only)
69-
smv_remote_whitelist = r'^.*$'
70+
smv_remote_whitelist = r"^.*$"
7071

7172
# Pattern for released versions
72-
smv_released_pattern = r'^tags/v.*$'
73+
smv_released_pattern = r"^tags/v.*$"
7374

7475
# Format for versioned output directories inside the build directory
75-
smv_outputdir_format = '{ref.name}'
76+
smv_outputdir_format = "{ref.name}"
7677

7778
# Determines whether remote or local git branches/tags are preferred if their output dirs conflict
7879
smv_prefer_remote_refs = False
@@ -111,8 +112,8 @@
111112
html_logo = "source/icon-sparseml.png"
112113

113114
html_theme_options = {
114-
'analytics_id': 'UA-128364174-1', # Provided by Google in your dashboard
115-
'analytics_anonymize_ip': False,
115+
"analytics_id": "UA-128364174-1", # Provided by Google in your dashboard
116+
"analytics_anonymize_ip": False,
116117
}
117118

118119
# Add any paths that contain custom static files (such as style sheets) here,
@@ -153,7 +154,13 @@
153154
# (source start file, target name, title,
154155
# author, documentclass [howto, manual, or own class]).
155156
latex_documents = [
156-
(master_doc, "sparseml.tex", "SparseML Documentation", [author], "manual",),
157+
(
158+
master_doc,
159+
"sparseml.tex",
160+
"SparseML Documentation",
161+
[author],
162+
"manual",
163+
),
157164
]
158165

159166
# -- Options for manual page output ------------------------------------------

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ensure_newline_before_comments = True
55
force_grid_wrap = 0
66
include_trailing_comma = True
77
known_first_party = sparseml,sparsezoo,tests
8-
known_third_party = bs4,requests,packaging,yaml,tqdm,numpy,onnx,onnxruntime,pandas,PIL,psutil,scipy,toposort,pytest,torch,torchvision,keras,tensorflow,merge-args,cv2
8+
known_third_party = bs4,requests,packaging,yaml,pydantic,tqdm,numpy,onnx,onnxruntime,pandas,PIL,psutil,scipy,toposort,pytest,torch,torchvision,keras,tensorflow,merge-args,cv2
99
sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
1010

1111
line_length = 88

setup.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@
4646
"onnx>=1.5.0,<1.8.0",
4747
"onnxruntime>=1.0.0",
4848
"pandas<1.0.0",
49+
"packaging>=20.0",
4950
"psutil>=5.0.0",
51+
"pydantic>=1.0.0",
5052
"requests>=2.0.0",
5153
"scikit-image>=0.15.0",
5254
"scipy>=1.0.0",
@@ -80,6 +82,8 @@
8082
"sphinx-copybutton>=0.3.0",
8183
"sphinx-markdown-tables>=0.0.15",
8284
"sphinx-multiversion==0.2.4",
85+
"sphinx-pydantic>=0.1.0",
86+
"sphinx-rtd-theme>=0.5.0",
8387
"wheel>=0.36.2",
8488
"pytest>=6.0.0",
8589
"flaky>=3.0.0",
@@ -114,7 +118,12 @@ def _setup_extras() -> Dict:
114118

115119

116120
def _setup_entry_points() -> Dict:
117-
return {}
121+
return {
122+
"console_scripts": [
123+
"sparseml.framework=sparseml.framework.info:_main",
124+
"sparseml.sparsification=sparseml.sparsification.info:_main",
125+
]
126+
}
118127

119128

120129
def _setup_long_description() -> Tuple[str, str]:

src/sparseml/__init__.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,27 @@
1919
# flake8: noqa
2020
# isort: skip_file
2121

22-
from .version import *
23-
2422
# be sure to import all logging first and at the root
2523
# this keeps other loggers in nested files creating from the root logger setups
2624
from .log import *
25+
from .version import *
26+
27+
from .base import (
28+
Framework,
29+
check_version,
30+
detect_framework,
31+
execute_in_sparseml_framework,
32+
)
33+
from .framework import (
34+
FrameworkInferenceProviderInfo,
35+
FrameworkInfo,
36+
framework_info,
37+
save_framework_info,
38+
load_framework_info,
39+
)
40+
from .sparsification import (
41+
SparsificationInfo,
42+
sparsification_info,
43+
save_sparsification_info,
44+
load_sparsification_info,
45+
)

src/sparseml/base.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import importlib
17+
import logging
18+
from enum import Enum
19+
from typing import Any, Optional
20+
21+
from packaging import version
22+
23+
import pkg_resources
24+
25+
26+
__all__ = [
27+
"Framework",
28+
"detect_framework",
29+
"execute_in_sparseml_framework",
30+
"get_version",
31+
"check_version",
32+
]
33+
34+
35+
_LOGGER = logging.getLogger(__name__)
36+
37+
38+
class Framework(Enum):
39+
"""
40+
Framework types known of/supported within the sparseml/deepsparse ecosystem
41+
"""
42+
43+
unknown = "unknown"
44+
deepsparse = "deepsparse"
45+
onnx = "onnx"
46+
keras = "keras"
47+
pytorch = "pytorch"
48+
tensorflow_v1 = "tensorflow_v1"
49+
50+
51+
def detect_framework(item: Any) -> Framework:
52+
"""
53+
Detect the supported ML framework for a given item.
54+
Supported input types are the following:
55+
- A Framework enum
56+
- A string of any case representing the name of the framework
57+
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
58+
- A supported file type within the framework such as model files:
59+
(onnx, pth, h5, pb)
60+
- An object from a supported ML framework such as a model instance
61+
If the framework cannot be determined, will return Framework.unknown
62+
:param item: The item to detect the ML framework for
63+
:type item: Any
64+
:return: The detected framework from the given item
65+
:rtype: Framework
66+
"""
67+
_LOGGER.debug("detecting framework for %s", item)
68+
framework = Framework.unknown
69+
70+
if isinstance(item, Framework):
71+
_LOGGER.debug("framework detected from Framework instance")
72+
framework = item
73+
elif isinstance(item, str) and item.lower().strip() in Framework.__members__:
74+
_LOGGER.debug("framework detected from Framework string instance")
75+
framework = Framework[item.lower().strip()]
76+
else:
77+
_LOGGER.debug("detecting framework by calling into supported frameworks")
78+
79+
for test in Framework:
80+
try:
81+
framework = execute_in_sparseml_framework(
82+
test, "detect_framework", item
83+
)
84+
except Exception as err:
85+
# errors are expected if the framework is not installed, log as debug
86+
_LOGGER.debug(f"error while calling detect_framework for {test}: {err}")
87+
88+
if framework != Framework.unknown:
89+
break
90+
91+
_LOGGER.info("detected framework of %s from %s", framework, item)
92+
93+
return framework
94+
95+
96+
def execute_in_sparseml_framework(
97+
framework: Framework, function_name: str, *args, **kwargs
98+
) -> Any:
99+
"""
100+
Execute a general function that is callable from the root of the frameworks
101+
package under SparseML such as sparseml.pytorch.
102+
Useful for benchmarking, analyzing, etc.
103+
Will pass the args and kwargs to the callable function.
104+
:param framework: The ML framework to run the function under in SparseML.
105+
:type framework: Framework
106+
:param function_name: The name of the function in SparseML that should be run
107+
with the given args and kwargs.
108+
:type function_name: str
109+
:param args: Any positional args to be passed into the function.
110+
:param kwargs: Any key word args to be passed into the function.
111+
:return: The return value from the executed function.
112+
:rtype: Any
113+
"""
114+
_LOGGER.debug(
115+
"executing function with name %s for framework %s, args %s, kwargs %s",
116+
function_name,
117+
framework,
118+
args,
119+
kwargs,
120+
)
121+
122+
if not isinstance(framework, Framework):
123+
framework = detect_framework(framework)
124+
125+
if framework == Framework.unknown:
126+
raise ValueError(
127+
f"unknown or unsupported framework {framework}, "
128+
f"cannot call function {function_name}"
129+
)
130+
131+
try:
132+
module = importlib.import_module(f"sparseml.{framework.value}")
133+
function = getattr(module, function_name)
134+
except Exception as err:
135+
raise ValueError(
136+
f"could not find function_name {function_name} in framework {framework}: "
137+
f"{err}"
138+
)
139+
140+
return function(*args, **kwargs)
141+
142+
143+
def get_version(package_name: str, raise_on_error: bool) -> Optional[str]:
144+
"""
145+
:param package_name: The name of the full package, as it would be imported,
146+
to get the version for
147+
:type package_name: str
148+
:param raise_on_error: True to raise an error if package is not installed
149+
or couldn't be imported, False to return None
150+
:return: the version of the desired package if detected, otherwise raises an error
151+
:rtype: str
152+
"""
153+
154+
try:
155+
current_version: str = pkg_resources.get_distribution(package_name).version
156+
except Exception as err:
157+
if raise_on_error:
158+
raise ImportError(
159+
f"error while getting current version for {package_name}: {err}"
160+
)
161+
162+
return None
163+
164+
return current_version
165+
166+
167+
def check_version(
168+
package_name: str,
169+
min_version: Optional[str] = None,
170+
max_version: Optional[str] = None,
171+
raise_on_error: bool = True,
172+
) -> bool:
173+
"""
174+
:param package_name: the name of the package to check the version of
175+
:type package_name: str
176+
:param min_version: The minimum version for the package that it must be greater than
177+
or equal to, if unset will require no minimum version
178+
:type min_version: str
179+
:param max_version: The maximum version for the package that it must be less than
180+
or equal to, if unset will require no maximum version.
181+
:type max_version: str
182+
:param raise_on_error: True to raise any issues such as not installed,
183+
minimum version, or maximum version as ImportError. False to return the result.
184+
:type raise_on_error: bool
185+
:return: If raise_on_error, will return False if the package is not installed
186+
or the version is outside the accepted bounds and True if everything is correct.
187+
:rtype: bool
188+
"""
189+
current_version = get_version(package_name, raise_on_error)
190+
191+
if not current_version:
192+
return False
193+
194+
current_version = version.parse(current_version)
195+
min_version = version.parse(min_version) if min_version else None
196+
max_version = version.parse(max_version) if max_version else None
197+
198+
if min_version and current_version < min_version:
199+
if raise_on_error:
200+
raise ImportError(
201+
f"required min {package_name} version {min_version}, "
202+
f"found {current_version}"
203+
)
204+
return False
205+
206+
if max_version and current_version > max_version:
207+
if raise_on_error:
208+
raise ImportError(
209+
f"required min {package_name} version {min_version}, "
210+
f"found {current_version}"
211+
)
212+
return False
213+
214+
return True

src/sparseml/framework/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Functionality related to integrating with, detecting, and getting information for
17+
support and sparsification in ML frameworks.
18+
"""
19+
20+
# flake8: noqa
21+
22+
from .info import *

0 commit comments

Comments
 (0)