Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 3190d6d

Browse files
authored
PyTorch framework and sparsification implementation for phase 2 (#172)
* add in base methods for framework and sparsification interfaces * add unit tests and minor fixes/functionality * fix missing imports, add console scripts, fix quality checks * PyTorch functionality for phase2 to enable loading the information available for training, sparsifying, and exporting for within the PyTorch framework
1 parent 833ec73 commit 3190d6d

File tree

17 files changed

+672
-11
lines changed

17 files changed

+672
-11
lines changed

src/sparseml/pytorch/__init__.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,11 @@
1313
# limitations under the License.
1414

1515
"""
16-
Code for working with the pytorch framework for creating /
17-
editing models for performance in the Neural Magic System
16+
Functionality for working with and sparsifying Models in the PyTorch framework
1817
"""
1918

2019
# flake8: noqa
2120

22-
try:
23-
import torch
24-
25-
if torch.__version__[0] != "1":
26-
raise Exception
27-
except:
28-
raise RuntimeError(
29-
"Unable to import torch. torch>=1.0.0 is required to use sparseml.pytorch"
30-
)
21+
from .base import *
22+
from .framework import detect_framework, framework_info, is_supported
23+
from .sparsification import sparsification_info

src/sparseml/pytorch/base.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import functools
17+
from typing import Optional
18+
19+
from sparseml.base import check_version
20+
21+
22+
try:
23+
import torch
24+
25+
torch_err = None
26+
except Exception as err:
27+
torch = object() # TODO: populate with fake object for necessary imports
28+
torch_err = err
29+
30+
try:
31+
import torchvision
32+
33+
torchvision_err = None
34+
except Exception as err:
35+
torchvision = object() # TODO: populate with fake object for necessary imports
36+
torchvision_err = err
37+
38+
39+
__all__ = [
40+
"torch",
41+
"torch_err",
42+
"torchvision",
43+
"torchvision_err",
44+
"check_torch_install",
45+
"check_torchvision_install",
46+
"require_torch",
47+
"require_torchvision",
48+
]
49+
50+
51+
_TORCH_MIN_VERSION = "1.0.0"
52+
_TORCH_MAX_VERSION = "1.8.100" # set bug to 100 to support all future 1.8.X versions
53+
54+
55+
def check_torch_install(
56+
min_version: Optional[str] = _TORCH_MIN_VERSION,
57+
max_version: Optional[str] = _TORCH_MAX_VERSION,
58+
raise_on_error: bool = True,
59+
) -> bool:
60+
"""
61+
Check that the torch package is installed.
62+
If raise_on_error, will raise an ImportError if it is not installed or
63+
the required version range, if set, is not installed.
64+
If not raise_on_error, will return True if installed with required version
65+
and False otherwise.
66+
67+
:param min_version: The minimum version for torch that it must be greater than
68+
or equal to, if unset will require no minimum version
69+
:type min_version: str
70+
:param max_version: The maximum version for torch that it must be less than
71+
or equal to, if unset will require no maximum version.
72+
:type max_version: str
73+
:param raise_on_error: True to raise any issues such as not installed,
74+
minimum version, or maximum version as ImportError. False to return the result.
75+
:type raise_on_error: bool
76+
:return: If raise_on_error, will return False if torch is not installed
77+
or the version is outside the accepted bounds and True if everything is correct.
78+
:rtype: bool
79+
"""
80+
if torch_err is not None:
81+
if raise_on_error:
82+
raise torch_err
83+
return False
84+
85+
return check_version("torch", min_version, max_version, raise_on_error)
86+
87+
88+
def check_torchvision_install(
89+
min_version: Optional[str] = None,
90+
max_version: Optional[str] = None,
91+
raise_on_error: bool = True,
92+
) -> bool:
93+
"""
94+
Check that the torchvision package is installed.
95+
If raise_on_error, will raise an ImportError if it is not installed or
96+
the required version range, if set, is not installed.
97+
If not raise_on_error, will return True if installed with required version
98+
and False otherwise.
99+
100+
:param min_version: The minimum version for torchvision that it must be greater than
101+
or equal to, if unset will require no minimum version
102+
:type min_version: str
103+
:param max_version: The maximum version for torchvision that it must be less than
104+
or equal to, if unset will require no maximum version.
105+
:type max_version: str
106+
:param raise_on_error: True to raise any issues such as not installed,
107+
minimum version, or maximum version as ImportError. False to return the result.
108+
:type raise_on_error: bool
109+
:return: If raise_on_error, will return False if torchvision is not installed
110+
or the version is outside the accepted bounds and True if everything is correct.
111+
:rtype: bool
112+
"""
113+
if torchvision_err is not None:
114+
if raise_on_error:
115+
raise torchvision_err
116+
return False
117+
118+
return check_version("torchvision", min_version, max_version, raise_on_error)
119+
120+
121+
def require_torch(
122+
min_version: Optional[str] = _TORCH_MIN_VERSION,
123+
max_version: Optional[str] = _TORCH_MAX_VERSION,
124+
):
125+
"""
126+
Decorator function to require use of torch.
127+
Will check that torch package is installed and within the bounding
128+
ranges of min_version and max_version if they are set before calling
129+
the wrapped function.
130+
See :func:`check_torch_install` for more info.
131+
132+
:param min_version: The minimum version for torch that it must be greater than
133+
or equal to, if unset will require no minimum version
134+
:type min_version: str
135+
:param max_version: The maximum version for torch that it must be less than
136+
or equal to, if unset will require no maximum version.
137+
:type max_version: str
138+
"""
139+
140+
def _decorator(func):
141+
@functools.wraps(func)
142+
def _wrapper(*args, **kwargs):
143+
check_torch_install(min_version, max_version)
144+
145+
return func(*args, **kwargs)
146+
147+
return _wrapper
148+
149+
return _decorator
150+
151+
152+
def require_torchvision(
153+
min_version: Optional[str] = None, max_version: Optional[str] = None
154+
):
155+
"""
156+
Decorator function to require use of torchvision.
157+
Will check that torchvision package is installed and within the bounding
158+
ranges of min_version and max_version if they are set before calling
159+
the wrapped function.
160+
See :func:`check_torchvision_install` for more info.
161+
162+
:param min_version: The minimum version for torchvision that it must be greater than
163+
or equal to, if unset will require no minimum version
164+
:type min_version: str
165+
:param max_version: The maximum version for torchvision that it must be less than
166+
or equal to, if unset will require no maximum version.
167+
:type max_version: str
168+
"""
169+
170+
def _decorator(func):
171+
@functools.wraps(func)
172+
def _wrapper(*args, **kwargs):
173+
check_torchvision_install(min_version, max_version)
174+
175+
return func(*args, **kwargs)
176+
177+
return _wrapper
178+
179+
return _decorator

src/sparseml/pytorch/datasets/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@
1818

1919
# flake8: noqa
2020

21+
from ..base import check_torch_install as _check_torch_install
2122
from .classification import *
2223
from .detection import *
2324
from .generic import *
2425
from .recommendation import *
2526
from .registry import *
2627
from .video import *
28+
29+
30+
_check_torch_install() # TODO: remove once files within package load without installs
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# flake8: noqa
16+
17+
"""
18+
Functionality related to integrating with, detecting, and getting information for
19+
support and sparsification in the PyTorch framework.
20+
"""
21+
22+
# flake8: noqa
23+
24+
from .info import *
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Functionality related to detecting and getting information for
17+
support and sparsification in the PyTorch framework.
18+
"""
19+
20+
import logging
21+
from typing import Any
22+
23+
from sparseml.base import Framework, get_version
24+
from sparseml.framework import FrameworkInferenceProviderInfo, FrameworkInfo
25+
from sparseml.pytorch.base import check_torch_install, torch
26+
from sparseml.pytorch.sparsification import sparsification_info
27+
from sparseml.sparsification import SparsificationInfo
28+
29+
30+
__all__ = ["is_supported", "detect_framework", "framework_info"]
31+
32+
33+
_LOGGER = logging.getLogger(__name__)
34+
35+
36+
def is_supported(item: Any) -> bool:
37+
"""
38+
:param item: The item to detect the support for
39+
:type item: Any
40+
:return: True if the item is supported by pytorch, False otherwise
41+
:rtype: bool
42+
"""
43+
framework = detect_framework(item)
44+
45+
return framework == Framework.pytorch
46+
47+
48+
def detect_framework(item: Any) -> Framework:
49+
"""
50+
Detect the supported ML framework for a given item specifically for the
51+
pytorch package.
52+
Supported input types are the following:
53+
- A Framework enum
54+
- A string of any case representing the name of the framework
55+
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
56+
- A supported file type within the framework such as model files:
57+
(onnx, pth, h5, pb)
58+
- An object from a supported ML framework such as a model instance
59+
If the framework cannot be determined, will return Framework.unknown
60+
61+
:param item: The item to detect the ML framework for
62+
:type item: Any
63+
:return: The detected framework from the given item
64+
:rtype: Framework
65+
"""
66+
framework = Framework.unknown
67+
68+
if isinstance(item, Framework):
69+
_LOGGER.debug("framework detected from Framework instance")
70+
framework = item
71+
elif isinstance(item, str) and item.lower().strip() in Framework.__members__:
72+
_LOGGER.debug("framework detected from Framework string instance")
73+
framework = Framework[item.lower().strip()]
74+
elif isinstance(item, str) and "torch" in item.lower().strip():
75+
_LOGGER.debug("framework detected from torch text")
76+
# string, check if it's a string saying onnx first
77+
framework = Framework.pytorch
78+
elif isinstance(item, str) and (
79+
".pt" in item.lower().strip() or ".mar" in item.lower().strip()
80+
):
81+
_LOGGER.debug("framework detected from .pt or .mar")
82+
# string, check if it's a file url or path that ends with onnx extension
83+
framework = Framework.pytorch
84+
elif check_torch_install(raise_on_error=False):
85+
from torch.nn import Module
86+
87+
if isinstance(item, Module):
88+
_LOGGER.debug("framework detected from pytorch instance")
89+
# pytorch native support
90+
framework = Framework.pytorch
91+
92+
return framework
93+
94+
95+
def framework_info() -> FrameworkInfo:
96+
"""
97+
Detect the information for the onnx/onnxruntime framework such as package versions,
98+
availability for core actions such as training and inference,
99+
sparsification support, and inference provider support.
100+
101+
:return: The framework info for onnx/onnxruntime
102+
:rtype: FrameworkInfo
103+
"""
104+
cpu_provider = FrameworkInferenceProviderInfo(
105+
name="cpu",
106+
description="Base CPU provider within PyTorch",
107+
device="cpu",
108+
supported_sparsification=SparsificationInfo(), # TODO: fill in when available
109+
available=check_torch_install(raise_on_error=False),
110+
properties={},
111+
warnings=[],
112+
)
113+
gpu_provider = FrameworkInferenceProviderInfo(
114+
name="cuda",
115+
description="Base GPU CUDA provider within PyTorch",
116+
device="gpu",
117+
supported_sparsification=SparsificationInfo(), # TODO: fill in when available
118+
available=(
119+
check_torch_install(raise_on_error=False) and torch.cuda.is_available()
120+
),
121+
properties={},
122+
warnings=[],
123+
)
124+
125+
return FrameworkInfo(
126+
framework=Framework.pytorch,
127+
package_versions={
128+
"torch": get_version(package_name="torch", raise_on_error=False),
129+
"torchvision": (
130+
get_version(package_name="torchvision", raise_on_error=False)
131+
),
132+
"onnx": get_version(package_name="onnx", raise_on_error=False),
133+
"sparsezoo": get_version(package_name="sparsezoo", raise_on_error=False),
134+
"sparseml": get_version(package_name="sparseml", raise_on_error=False),
135+
},
136+
sparsification=sparsification_info(),
137+
inference_providers=[cpu_provider, gpu_provider],
138+
properties={},
139+
training_available=True,
140+
sparsification_available=True,
141+
exporting_onnx_available=True,
142+
inference_available=True,
143+
)

src/sparseml/pytorch/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@
1818

1919
# flake8: noqa
2020

21+
from ..base import check_torch_install as _check_torch_install
2122
from .classification import *
2223
from .detection import *
2324
from .external import *
2425
from .recommendation import *
2526
from .registry import *
27+
28+
29+
_check_torch_install() # TODO: remove once files within package load without installs

0 commit comments

Comments
 (0)