Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit cc2ba1d

Browse files
authored
Merge pull request #171 from neuralmagic/interface-framework-onnx
Onnx framework and sparsification implementation for phase 2
2 parents 1ae34c5 + ba797ac commit cc2ba1d

File tree

14 files changed

+661
-61
lines changed

14 files changed

+661
-61
lines changed

src/sparseml/onnx/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
# limitations under the License.
1414

1515
"""
16-
Code for working with the ONNX framework for creating /
17-
editing models for performance in the Neural Magic System
16+
Functionality for working with and sparsifying Models in the ONNX/ONNXRuntime framework
1817
"""
1918

2019
# flake8: noqa
20+
21+
from .base import *
22+
from .framework import detect_framework, framework_info, is_supported
23+
from .sparsification import sparsification_info

src/sparseml/onnx/base.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import functools
17+
from typing import Optional
18+
19+
from sparseml.base import check_version
20+
21+
22+
try:
23+
import onnx
24+
25+
onnx_err = None
26+
except Exception as err:
27+
onnx = object() # TODO: populate with fake object for necessary imports
28+
onnx_err = err
29+
30+
try:
31+
import onnxruntime
32+
33+
onnxruntime_err = None
34+
except Exception as err:
35+
onnxruntime = object() # TODO: populate with fake object for necessary imports
36+
onnxruntime_err = err
37+
38+
39+
__all__ = [
40+
"onnx",
41+
"onnx_err",
42+
"onnxruntime",
43+
"onnxruntime_err",
44+
"check_onnx_install",
45+
"check_onnxruntime_install",
46+
"require_onnx",
47+
"require_onnxruntime",
48+
]
49+
50+
51+
_ONNX_MIN_VERSION = "1.5.0"
52+
_ORT_MIN_VERSION = "1.0.0"
53+
54+
55+
def check_onnx_install(
56+
min_version: Optional[str] = _ONNX_MIN_VERSION,
57+
max_version: Optional[str] = None,
58+
raise_on_error: bool = True,
59+
) -> bool:
60+
"""
61+
Check that the onnx package is installed.
62+
If raise_on_error, will raise an ImportError if it is not installed or
63+
the required version range, if set, is not installed.
64+
If not raise_on_error, will return True if installed with required version
65+
and False otherwise.
66+
67+
:param min_version: The minimum version for onnx that it must be greater than
68+
or equal to, if unset will require no minimum version
69+
:type min_version: str
70+
:param max_version: The maximum version for onnx that it must be less than
71+
or equal to, if unset will require no maximum version.
72+
:type max_version: str
73+
:param raise_on_error: True to raise any issues such as not installed,
74+
minimum version, or maximum version as ImportError. False to return the result.
75+
:type raise_on_error: bool
76+
:return: If raise_on_error, will return False if onnx is not installed
77+
or the version is outside the accepted bounds and True if everything is correct.
78+
:rtype: bool
79+
"""
80+
if onnx_err is not None:
81+
if raise_on_error:
82+
raise onnx_err
83+
return False
84+
85+
return check_version("onnx", min_version, max_version, raise_on_error)
86+
87+
88+
def check_onnxruntime_install(
89+
min_version: Optional[str] = _ORT_MIN_VERSION,
90+
max_version: Optional[str] = None,
91+
raise_on_error: bool = True,
92+
) -> bool:
93+
"""
94+
Check that the onnxruntime package is installed.
95+
If raise_on_error, will raise an ImportError if it is not installed or
96+
the required version range, if set, is not installed.
97+
If not raise_on_error, will return True if installed with required version
98+
and False otherwise.
99+
100+
:param min_version: The minimum version for onnxruntime that it must be greater than
101+
or equal to, if unset will require no minimum version
102+
:type min_version: str
103+
:param max_version: The maximum version for onnxruntime that it must be less than
104+
or equal to, if unset will require no maximum version.
105+
:type max_version: str
106+
:param raise_on_error: True to raise any issues such as not installed,
107+
minimum version, or maximum version as ImportError. False to return the result.
108+
:type raise_on_error: bool
109+
:return: If raise_on_error, will return False if onnxruntime is not installed
110+
or the version is outside the accepted bounds and True if everything is correct.
111+
:rtype: bool
112+
"""
113+
if onnxruntime_err is not None:
114+
if raise_on_error:
115+
raise onnxruntime_err
116+
return False
117+
118+
return check_version("onnxruntime", min_version, max_version, raise_on_error)
119+
120+
121+
def require_onnx(
122+
min_version: Optional[str] = _ONNX_MIN_VERSION, max_version: Optional[str] = None
123+
):
124+
"""
125+
Decorator function to require use of onnx.
126+
Will check that onnx package is installed and within the bounding
127+
ranges of min_version and max_version if they are set before calling
128+
the wrapped function.
129+
See :func:`check_onnx_install` for more info.
130+
131+
param min_version: The minimum version for onnx that it must be greater than
132+
or equal to, if unset will require no minimum version
133+
:type min_version: str
134+
:param max_version: The maximum version for onnx that it must be less than
135+
or equal to, if unset will require no maximum version.
136+
:type max_version: str
137+
"""
138+
139+
def _decorator(func):
140+
@functools.wraps(func)
141+
def _wrapper(*args, **kwargs):
142+
check_onnx_install(min_version, max_version)
143+
144+
return func(*args, **kwargs)
145+
146+
return _wrapper
147+
148+
return _decorator
149+
150+
151+
def require_onnxruntime(
152+
min_version: Optional[str] = _ORT_MIN_VERSION, max_version: Optional[str] = None
153+
):
154+
"""
155+
Decorator function to require use of onnxruntime.
156+
Will check that onnxruntime package is installed and within the bounding
157+
ranges of min_version and max_version if they are set before calling
158+
the wrapped function.
159+
See :func:`check_onnxruntime_install` for more info.
160+
161+
param min_version: The minimum version for onnxruntime that it must be greater than
162+
or equal to, if unset will require no minimum version
163+
:type min_version: str
164+
:param max_version: The maximum version for onnxruntime that it must be less than
165+
or equal to, if unset will require no maximum version.
166+
:type max_version: str
167+
"""
168+
169+
def _decorator(func):
170+
@functools.wraps(func)
171+
def _wrapper(*args, **kwargs):
172+
check_onnxruntime_install(min_version, max_version)
173+
174+
return func(*args, **kwargs)
175+
176+
return _wrapper
177+
178+
return _decorator
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# flake8: noqa
16+
17+
"""
18+
Functionality related to integrating with, detecting, and getting information for
19+
support and sparsification in the ONNX/ONNXRuntime framework.
20+
"""
21+
22+
# flake8: noqa
23+
24+
from .info import *
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Functionality related to detecting and getting information for
17+
support and sparsification in the ONNX/ONNXRuntime framework.
18+
"""
19+
20+
import logging
21+
from typing import Any
22+
23+
from sparseml.base import Framework, get_version
24+
from sparseml.framework import FrameworkInferenceProviderInfo, FrameworkInfo
25+
from sparseml.onnx.base import check_onnx_install, check_onnxruntime_install
26+
from sparseml.onnx.sparsification import sparsification_info
27+
from sparseml.sparsification import SparsificationInfo
28+
29+
30+
__all__ = ["is_supported", "detect_framework", "framework_info"]
31+
32+
33+
_LOGGER = logging.getLogger(__name__)
34+
35+
36+
def is_supported(item: Any) -> bool:
37+
"""
38+
:param item: The item to detect the support for
39+
:type item: Any
40+
:return: True if the item is supported by onnx/onnxruntime, False otherwise
41+
:rtype: bool
42+
"""
43+
framework = detect_framework(item)
44+
45+
return framework == Framework.onnx
46+
47+
48+
def detect_framework(item: Any) -> Framework:
49+
"""
50+
Detect the supported ML framework for a given item specifically for the
51+
onnx/onnxruntime package.
52+
Supported input types are the following:
53+
- A Framework enum
54+
- A string of any case representing the name of the framework
55+
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
56+
- A supported file type within the framework such as model files:
57+
(onnx, pth, h5, pb)
58+
- An object from a supported ML framework such as a model instance
59+
If the framework cannot be determined, will return Framework.unknown
60+
61+
:param item: The item to detect the ML framework for
62+
:type item: Any
63+
:return: The detected framework from the given item
64+
:rtype: Framework
65+
"""
66+
framework = Framework.unknown
67+
68+
if isinstance(item, Framework):
69+
_LOGGER.debug("framework detected from Framework instance")
70+
framework = item
71+
elif isinstance(item, str) and item.lower().strip() in Framework.__members__:
72+
_LOGGER.debug("framework detected from Framework string instance")
73+
framework = Framework[item.lower().strip()]
74+
elif isinstance(item, str) and "onnx" in item.lower().strip():
75+
_LOGGER.debug("framework detected from onnx text")
76+
# string, check if it's a string saying onnx first
77+
framework = Framework.onnx
78+
elif isinstance(item, str) and ".onnx" in item.lower().strip():
79+
_LOGGER.debug("framework detected from .onnx")
80+
# string, check if it's a file url or path that ends with onnx extension
81+
framework = Framework.onnx
82+
elif check_onnx_install(raise_on_error=False):
83+
from onnx import ModelProto
84+
85+
if isinstance(item, ModelProto):
86+
_LOGGER.debug("framework detected from ONNX instance")
87+
# onnx native support
88+
framework = Framework.onnx
89+
90+
return framework
91+
92+
93+
def framework_info() -> FrameworkInfo:
94+
"""
95+
Detect the information for the onnx/onnxruntime framework such as package versions,
96+
availability for core actions such as training and inference,
97+
sparsification support, and inference provider support.
98+
99+
:return: The framework info for onnx/onnxruntime
100+
:rtype: FrameworkInfo
101+
"""
102+
all_providers = []
103+
available_providers = []
104+
if check_onnxruntime_install(raise_on_error=False):
105+
from onnxruntime import get_all_providers, get_available_providers
106+
107+
available_providers = get_available_providers()
108+
all_providers = get_all_providers()
109+
110+
cpu_provider = FrameworkInferenceProviderInfo(
111+
name="cpu",
112+
description="Base CPU provider within ONNXRuntime",
113+
device="cpu",
114+
supported_sparsification=SparsificationInfo(), # TODO: fill in when available
115+
available=(
116+
check_onnx_install(raise_on_error=False)
117+
and check_onnxruntime_install(raise_on_error=False)
118+
and "CPUExecutionProvider" in available_providers
119+
),
120+
properties={},
121+
warnings=[],
122+
)
123+
gpu_provider = FrameworkInferenceProviderInfo(
124+
name="cuda",
125+
description="Base GPU CUDA provider within ONNXRuntime",
126+
device="gpu",
127+
supported_sparsification=SparsificationInfo(), # TODO: fill in when available
128+
available=(
129+
check_onnx_install(raise_on_error=False)
130+
and check_onnxruntime_install(raise_on_error=False)
131+
and "CUDAExecutionProvider" in available_providers
132+
),
133+
properties={},
134+
warnings=[],
135+
)
136+
137+
return FrameworkInfo(
138+
framework=Framework.onnx,
139+
package_versions={
140+
"onnx": get_version(package_name="onnx", raise_on_error=False),
141+
"onnxruntime": (
142+
get_version(package_name="onnxruntime", raise_on_error=False)
143+
),
144+
"sparsezoo": get_version(package_name="sparsezoo", raise_on_error=False),
145+
"sparseml": get_version(package_name="sparseml", raise_on_error=False),
146+
},
147+
sparsification=sparsification_info(),
148+
inference_providers=[cpu_provider, gpu_provider],
149+
properties={
150+
"available_providers": available_providers,
151+
"all_providers": all_providers,
152+
},
153+
training_available=False,
154+
sparsification_available=True,
155+
exporting_onnx_available=True,
156+
inference_available=True,
157+
)

0 commit comments

Comments
 (0)