Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 08c101a

Browse files
committed
add in base methods for framework and sparsification interfaces
1 parent fe5523f commit 08c101a

File tree

9 files changed

+690
-11
lines changed

9 files changed

+690
-11
lines changed

docs/conf.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"sphinx_copybutton",
5353
"sphinx_markdown_tables",
5454
"sphinx_multiversion",
55+
"sphinx-pydantic",
5556
"sphinx_rtd_theme",
5657
"recommonmark",
5758
]
@@ -60,19 +61,19 @@
6061
templates_path = ["_templates"]
6162

6263
# Whitelist pattern for tags (set to None to ignore all tags)
63-
smv_tag_whitelist = r'^v.*$'
64+
smv_tag_whitelist = r"^v.*$"
6465

6566
# Whitelist pattern for branches (set to None to ignore all branches)
66-
smv_branch_whitelist = r'^main$'
67+
smv_branch_whitelist = r"^main$"
6768

6869
# Whitelist pattern for remotes (set to None to use local branches only)
69-
smv_remote_whitelist = r'^.*$'
70+
smv_remote_whitelist = r"^.*$"
7071

7172
# Pattern for released versions
72-
smv_released_pattern = r'^tags/v.*$'
73+
smv_released_pattern = r"^tags/v.*$"
7374

7475
# Format for versioned output directories inside the build directory
75-
smv_outputdir_format = '{ref.name}'
76+
smv_outputdir_format = "{ref.name}"
7677

7778
# Determines whether remote or local git branches/tags are preferred if their output dirs conflict
7879
smv_prefer_remote_refs = False
@@ -111,8 +112,8 @@
111112
html_logo = "source/icon-sparseml.png"
112113

113114
html_theme_options = {
114-
'analytics_id': 'UA-128364174-1', # Provided by Google in your dashboard
115-
'analytics_anonymize_ip': False,
115+
"analytics_id": "UA-128364174-1", # Provided by Google in your dashboard
116+
"analytics_anonymize_ip": False,
116117
}
117118

118119
# Add any paths that contain custom static files (such as style sheets) here,
@@ -153,7 +154,13 @@
153154
# (source start file, target name, title,
154155
# author, documentclass [howto, manual, or own class]).
155156
latex_documents = [
156-
(master_doc, "sparseml.tex", "SparseML Documentation", [author], "manual",),
157+
(
158+
master_doc,
159+
"sparseml.tex",
160+
"SparseML Documentation",
161+
[author],
162+
"manual",
163+
),
157164
]
158165

159166
# -- Options for manual page output ------------------------------------------

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ensure_newline_before_comments = True
55
force_grid_wrap = 0
66
include_trailing_comma = True
77
known_first_party = sparseml,sparsezoo,tests
8-
known_third_party = bs4,requests,packaging,yaml,tqdm,numpy,onnx,onnxruntime,pandas,PIL,psutil,scipy,toposort,pytest,torch,torchvision,keras,tensorflow,merge-args,cv2
8+
known_third_party = bs4,requests,packaging,yaml,pydantic,tqdm,numpy,onnx,onnxruntime,pandas,PIL,psutil,scipy,toposort,pytest,torch,torchvision,keras,tensorflow,merge-args,cv2
99
sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
1010

1111
line_length = 88

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@
4646
"onnx>=1.5.0,<1.8.0",
4747
"onnxruntime>=1.0.0",
4848
"pandas<1.0.0",
49+
"packaging>=20.0",
4950
"psutil>=5.0.0",
51+
"pydantic>=1.0.0",
5052
"requests>=2.0.0",
5153
"scikit-image>=0.15.0",
5254
"scipy>=1.0.0",
@@ -80,6 +82,8 @@
8082
"sphinx-copybutton>=0.3.0",
8183
"sphinx-markdown-tables>=0.0.15",
8284
"sphinx-multiversion==0.2.4",
85+
"sphinx-pydantic>=0.1.0",
86+
"sphinx-rtd-theme>=0.5.0",
8387
"wheel>=0.36.2",
8488
"pytest>=6.0.0",
8589
"flaky>=3.0.0",

src/sparseml/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,16 @@
1919
# flake8: noqa
2020
# isort: skip_file
2121

22-
from .version import *
23-
2422
# be sure to import all logging first and at the root
2523
# this keeps other loggers in nested files creating from the root logger setups
2624
from .log import *
25+
from .version import *
26+
27+
from .base import (
28+
Framework,
29+
check_version,
30+
detect_framework,
31+
execute_in_sparseml_framework,
32+
)
33+
from .framework import FrameworkInferenceProviderInfo, FrameworkInfo, framework_info
34+
from .sparsification import sparsification_info

src/sparseml/base.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import importlib
17+
import logging
18+
from enum import Enum
19+
from typing import Any, Optional
20+
from packaging import version
21+
22+
import pkg_resources
23+
24+
25+
__all__ = [
26+
"Framework",
27+
"detect_framework",
28+
"execute_in_sparseml_framework",
29+
"get_version",
30+
"check_version",
31+
]
32+
33+
34+
_LOGGER = logging.getLogger(__name__)
35+
36+
37+
class Framework(Enum):
38+
"""
39+
Framework types known of/supported within the sparseml/deepsparse ecosystem
40+
"""
41+
42+
unknown = "unknown"
43+
deepsparse = "deepsparse"
44+
onnx = "onnx"
45+
keras = "keras"
46+
pytorch = "pytorch"
47+
tensorflow_v1 = "tensorflow_v1"
48+
49+
50+
def detect_framework(item: Any) -> Framework:
51+
"""
52+
Detect the supported ML framework for a given item.
53+
Supported input types are the following:
54+
- A Framework enum
55+
- A string of any case representing the name of the framework
56+
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
57+
- A supported file type within the framework such as model files:
58+
(onnx, pth, h5, pb)
59+
- An object from a supported ML framework such as a model instance
60+
If the framework cannot be determined, will return Framework.unknown
61+
:param item: The item to detect the ML framework for
62+
:type item: Any
63+
:return: The detected framework from the given item
64+
:rtype: Framework
65+
"""
66+
_LOGGER.debug("detecting framework for %s", item)
67+
framework = Framework.unknown
68+
69+
if isinstance(item, Framework):
70+
_LOGGER.debug("framework detected from Framework instance")
71+
framework = item
72+
elif isinstance(item, str) and item.lower().strip() in Framework.__members__:
73+
_LOGGER.debug("framework detected from Framework string instance")
74+
framework = Framework[item.lower().strip()]
75+
else:
76+
_LOGGER.debug("detecting framework by calling into supported frameworks")
77+
78+
for test in Framework:
79+
try:
80+
framework = execute_in_sparseml_framework(
81+
test, "detect_framework", item
82+
)
83+
except Exception as err:
84+
# errors are expected if the framework is not installed, log as debug
85+
_LOGGER.debug(f"error while calling detect_framework for {test}: {err}")
86+
87+
if framework != Framework.unknown:
88+
break
89+
90+
_LOGGER.info("detected framework of %s from %s", framework, item)
91+
92+
return framework
93+
94+
95+
def execute_in_sparseml_framework(
96+
framework: Framework, function_name: str, *args, **kwargs
97+
) -> Any:
98+
"""
99+
Execute a general function that is callable from the root of the frameworks
100+
package under SparseML such as sparseml.pytorch.
101+
Useful for benchmarking, analyzing, etc.
102+
Will pass the args and kwargs to the callable function.
103+
:param framework: The ML framework to run the function under in SparseML.
104+
:type framework: Framework
105+
:param function_name: The name of the function in SparseML that should be run
106+
with the given args and kwargs.
107+
:type function_name: str
108+
:param args: Any positional args to be passed into the function.
109+
:param kwargs: Any key word args to be passed into the function.
110+
:return: The return value from the executed function.
111+
:rtype: Any
112+
"""
113+
_LOGGER.debug(
114+
"executing function with name %s for framework %s, args %s, kwargs %s",
115+
function_name,
116+
framework,
117+
args,
118+
kwargs,
119+
)
120+
121+
if not isinstance(framework, Framework):
122+
framework = detect_framework(framework)
123+
124+
if framework == Framework.unknown:
125+
raise ValueError(
126+
f"unknown or unsupported framework {framework}, "
127+
f"cannot call function {function_name}"
128+
)
129+
130+
try:
131+
module = importlib.import_module(f"sparseml.{framework.value}")
132+
function = getattr(module, function_name)
133+
except Exception as err:
134+
raise ValueError(
135+
f"could not find function_name {function_name} in framework {framework}: "
136+
f"{err}"
137+
)
138+
139+
return function(*args, **kwargs)
140+
141+
142+
def get_version(package_name: str, raise_on_error: bool) -> Optional[str]:
143+
"""
144+
:param package_name: The name of the full package, as it would be imported,
145+
to get the version for
146+
:type package_name: str
147+
:param raise_on_error: True to raise an error if package is not installed
148+
or couldn't be imported, False to return None
149+
:return: the version of the desired package if detected, otherwise raises an error
150+
:rtype: str
151+
"""
152+
153+
try:
154+
current_version: str = pkg_resources.get_distribution(package_name).version
155+
except Exception as err:
156+
if raise_on_error:
157+
raise ImportError(
158+
f"error while getting current version for {package_name}: {err}"
159+
)
160+
161+
return None
162+
163+
return current_version
164+
165+
166+
def check_version(
167+
package_name: str,
168+
min_version: Optional[str] = None,
169+
max_version: Optional[str] = None,
170+
raise_on_error: bool = True,
171+
) -> bool:
172+
"""
173+
:param package_name: the name of the package to check the version of
174+
:type package_name: str
175+
:param min_version: The minimum version for the package that it must be greater than
176+
or equal to, if unset will require no minimum version
177+
:type min_version: str
178+
:param max_version: The maximum version for the package that it must be less than
179+
or equal to, if unset will require no maximum version.
180+
:type max_version: str
181+
:param raise_on_error: True to raise any issues such as not installed,
182+
minimum version, or maximum version as ImportError. False to return the result.
183+
:type raise_on_error: bool
184+
:return: If raise_on_error, will return False if the package is not installed
185+
or the version is outside the accepted bounds and True if everything is correct.
186+
:rtype: bool
187+
"""
188+
current_version = get_version(package_name, raise_on_error)
189+
190+
if not current_version:
191+
return False
192+
193+
current_version = version.parse(current_version)
194+
min_version = version.parse(min_version) if min_version else None
195+
max_version = version.parse(max_version) if max_version else None
196+
197+
if min_version and current_version < min_version:
198+
if raise_on_error:
199+
raise ImportError(
200+
f"required min {package_name} version {min_version}, "
201+
f"found {current_version}"
202+
)
203+
return False
204+
205+
if max_version and current_version > max_version:
206+
if raise_on_error:
207+
raise ImportError(
208+
f"required min {package_name} version {min_version}, "
209+
f"found {current_version}"
210+
)
211+
return False
212+
213+
return True

src/sparseml/framework/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Functionality related to integrating with, detecting, and getting information for
17+
support and sparsification in ML frameworks.
18+
"""
19+
20+
# flake8: noqa
21+
22+
from .info import *

0 commit comments

Comments
 (0)