Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 2803d36

Browse files
authored
Merge branch 'main' into interface-framework-onnx
2 parents f08b0de + b7f0544 commit 2803d36

File tree

17 files changed

+697
-19
lines changed

17 files changed

+697
-19
lines changed

src/sparseml/tensorflow_v1/__init__.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,11 @@
1313
# limitations under the License.
1414

1515
"""
16-
Code for working with the tensorflow_v1 framework for creating /
17-
editing models for performance in the Neural Magic System
16+
Functionality for working with and sparsifying Models in the TensorFlow 1.x framework
1817
"""
1918

2019
# flake8: noqa
2120

22-
import os as _os
23-
24-
25-
try:
26-
import tensorflow
27-
28-
if not _os.getenv("SPARSEML_IGNORE_TFV1", False):
29-
# special use case so docs can be generated without having
30-
# conflicting TF versions for V1 and Keras
31-
version = [int(v) for v in tensorflow.__version__.split(".")]
32-
if version[0] != 1 or version[1] < 8:
33-
raise Exception
34-
except:
35-
raise RuntimeError(
36-
"Unable to import tensorflow. tensorflow>=1.8,<2.0 is required"
37-
" to use sparseml.tensorflow_v1."
38-
)
21+
from .base import *
22+
from .framework import detect_framework, framework_info, is_supported
23+
from .sparsification import sparsification_info

src/sparseml/tensorflow_v1/base.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import functools
17+
import os
18+
from typing import Optional
19+
20+
from sparseml.base import check_version
21+
22+
23+
try:
24+
import tensorflow
25+
26+
tf_compat = (
27+
tensorflow
28+
if not hasattr(tensorflow, "compat")
29+
or not hasattr(getattr(tensorflow, "compat"), "v1")
30+
else tensorflow.compat.v1
31+
)
32+
tensorflow_err = None
33+
except Exception as err:
34+
tensorflow = object() # TODO: populate with fake object for necessary imports
35+
tf_compat = object() # TODO: populate with fake object for necessary imports
36+
tensorflow_err = err
37+
38+
39+
try:
40+
import tf2onnx
41+
42+
tf2onnx_err = None
43+
except Exception as err:
44+
tf2onnx = object() # TODO: populate with fake object for necessary imports
45+
tf2onnx_err = err
46+
47+
48+
__all__ = [
49+
"tensorflow",
50+
"tf_compat",
51+
"tensorflow_err",
52+
"tf2onnx",
53+
"tf2onnx_err",
54+
"check_tensorflow_install",
55+
"check_tf2onnx_install",
56+
"require_tensorflow",
57+
"require_tf2onnx",
58+
]
59+
60+
61+
_TENSORFLOW_MIN_VERSION = "1.8.0"
62+
_TENSORFLOW_MAX_VERSION = "1.16.0"
63+
64+
_TF2ONNX_MIN_VERSION = "1.0.0"
65+
66+
67+
def check_tensorflow_install(
68+
min_version: Optional[str] = _TENSORFLOW_MIN_VERSION,
69+
max_version: Optional[str] = _TENSORFLOW_MAX_VERSION,
70+
raise_on_error: bool = True,
71+
allow_env_ignore_flag: bool = True,
72+
) -> bool:
73+
"""
74+
Check that the tensorflow package is installed.
75+
If raise_on_error, will raise an ImportError if it is not installed or
76+
the required version range, if set, is not installed.
77+
If not raise_on_error, will return True if installed with required version
78+
and False otherwise.
79+
80+
:param min_version: The minimum version for tensorflow that it must be greater than
81+
or equal to, if unset will require no minimum version
82+
:type min_version: str
83+
:param max_version: The maximum version for tensorflow that it must be less than
84+
or equal to, if unset will require no maximum version.
85+
:type max_version: str
86+
:param raise_on_error: True to raise any issues such as not installed,
87+
minimum version, or maximum version as ImportError. False to return the result.
88+
:type raise_on_error: bool
89+
:param allow_env_ignore_flag: True to allow the env variable SPARSEML_IGNORE_TFV1
90+
to ignore the tensorflow install and version checks.
91+
False to ignore the ignore flag.
92+
:type allow_env_ignore_flag: bool
93+
:return: If raise_on_error, will return False if tensorflow is not installed
94+
or the version is outside the accepted bounds and True if everything is correct.
95+
:rtype: bool
96+
"""
97+
if allow_env_ignore_flag and os.getenv("SPARSEML_IGNORE_TFV1", False):
98+
return True
99+
100+
if tensorflow_err is not None:
101+
if raise_on_error:
102+
raise tensorflow_err
103+
return False
104+
105+
return check_version("tensorflow", min_version, max_version, raise_on_error)
106+
107+
108+
def check_tf2onnx_install(
109+
min_version: Optional[str] = _TF2ONNX_MIN_VERSION,
110+
max_version: Optional[str] = None,
111+
raise_on_error: bool = True,
112+
) -> bool:
113+
"""
114+
Check that the tf2onnx package is installed.
115+
If raise_on_error, will raise an ImportError if it is not installed or
116+
the required version range, if set, is not installed.
117+
If not raise_on_error, will return True if installed with required version
118+
and False otherwise.
119+
120+
:param min_version: The minimum version for tf2onnx that it must be greater than
121+
or equal to, if unset will require no minimum version
122+
:type min_version: str
123+
:param max_version: The maximum version for tf2onnx that it must be less than
124+
or equal to, if unset will require no maximum version.
125+
:type max_version: str
126+
:param raise_on_error: True to raise any issues such as not installed,
127+
minimum version, or maximum version as ImportError. False to return the result.
128+
:type raise_on_error: bool
129+
:return: If raise_on_error, will return False if tf2onnx is not installed
130+
or the version is outside the accepted bounds and True if everything is correct.
131+
:rtype: bool
132+
"""
133+
if tf2onnx_err is not None:
134+
if raise_on_error:
135+
raise tf2onnx_err
136+
return False
137+
138+
return check_version("tf2onnx", min_version, max_version, raise_on_error)
139+
140+
141+
def require_tensorflow(
142+
min_version: Optional[str] = _TENSORFLOW_MIN_VERSION,
143+
max_version: Optional[str] = _TENSORFLOW_MAX_VERSION,
144+
allow_env_ignore_flag: bool = True,
145+
):
146+
"""
147+
Decorator function to require use of tensorflow.
148+
Will check that tensorflow package is installed and within the bounding
149+
ranges of min_version and max_version if they are set before calling
150+
the wrapped function.
151+
See :func:`check_tensorflow_install` for more info.
152+
153+
:param min_version: The minimum version for tensorflow that it must be greater than
154+
or equal to, if unset will require no minimum version
155+
:type min_version: str
156+
:param max_version: The maximum version for tensorflow that it must be less than
157+
or equal to, if unset will require no maximum version.
158+
:type max_version: str
159+
:param allow_env_ignore_flag: True to allow the env variable SPARSEML_IGNORE_TFV1
160+
to ignore the tensorflow install and version checks.
161+
False to ignore the ignore flag.
162+
:type allow_env_ignore_flag: bool
163+
"""
164+
165+
def _decorator(func):
166+
@functools.wraps(func)
167+
def _wrapper(*args, **kwargs):
168+
check_tensorflow_install(min_version, max_version, allow_env_ignore_flag)
169+
170+
return func(*args, **kwargs)
171+
172+
return _wrapper
173+
174+
return _decorator
175+
176+
177+
def require_tf2onnx(
178+
min_version: Optional[str] = _TF2ONNX_MIN_VERSION,
179+
max_version: Optional[str] = None,
180+
):
181+
"""
182+
Decorator function to require use of tf2onnx.
183+
Will check that tf2onnx package is installed and within the bounding
184+
ranges of min_version and max_version if they are set before calling
185+
the wrapped function.
186+
See :func:`check_tf2onnx_install` for more info.
187+
188+
:param min_version: The minimum version for tf2onnx that it must be greater than
189+
or equal to, if unset will require no minimum version
190+
:type min_version: str
191+
:param max_version: The maximum version for tf2onnx that it must be less than
192+
or equal to, if unset will require no maximum version.
193+
:type max_version: str
194+
"""
195+
196+
def _decorator(func):
197+
@functools.wraps(func)
198+
def _wrapper(*args, **kwargs):
199+
check_tf2onnx_install(min_version, max_version)
200+
201+
return func(*args, **kwargs)
202+
203+
return _wrapper
204+
205+
return _decorator

src/sparseml/tensorflow_v1/datasets/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818

1919
# flake8: noqa
2020

21+
from ..base import check_tensorflow_install as _check_tensorflow_install
2122
from .classification import *
2223
from .dataset import *
2324
from .registry import *
25+
26+
27+
_check_tensorflow_install() # TODO: remove once files load without installs
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Functionality related to integrating with, detecting, and getting information for
17+
support and sparsification in the TensorFLow 1.x framework.
18+
"""
19+
20+
# flake8: noqa
21+
22+
from .info import *

0 commit comments

Comments
 (0)