Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit a88e3fc

Browse files
committed
add unit tests and minor fixes/functionality
1 parent 08c101a commit a88e3fc

File tree

11 files changed

+692
-17
lines changed

11 files changed

+692
-17
lines changed

src/sparseml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@
3131
execute_in_sparseml_framework,
3232
)
3333
from .framework import FrameworkInferenceProviderInfo, FrameworkInfo, framework_info
34-
from .sparsification import sparsification_info
34+
from .sparsification import SparsificationInfo, sparsification_info

src/sparseml/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
from enum import Enum
1919
from typing import Any, Optional
20+
2021
from packaging import version
2122

2223
import pkg_resources

src/sparseml/framework/info.py

Lines changed: 134 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,53 @@
1515
"""
1616
Functionality related to integrating with, detecting, and getting information for
1717
support and sparsification in ML frameworks.
18+
19+
The file is executable and will get the framework info for a given framework:
20+
21+
##########
22+
Command help:
23+
usage: info.py [-h] [--path PATH] framework
24+
25+
Compile the available setup and information for a given framework.
26+
27+
positional arguments:
28+
framework the ML framework or path to a framework file to load the
29+
framework info for
30+
31+
optional arguments:
32+
-h, --help show this help message and exit
33+
--path PATH A full file path to save the framework info to. If not
34+
supplied, will print out the framework info to the
35+
console.
36+
37+
#########
38+
EXAMPLES
39+
#########
40+
41+
##########
42+
Example command for getting the framework info for pytorch.
43+
python src/sparseml/framework/info.py pytorch
1844
"""
1945

46+
import argparse
2047
import logging
48+
import os
2149
from collections import OrderedDict
2250
from typing import Any, Dict, List, Optional
2351

2452
from pydantic import BaseModel, Field
2553

2654
from sparseml.base import Framework, execute_in_sparseml_framework
2755
from sparseml.sparsification.info import SparsificationInfo
56+
from sparseml.utils import clean_path, create_parent_dirs
2857

2958

3059
__all__ = [
3160
"FrameworkInferenceProviderInfo",
3261
"FrameworkInfo",
3362
"framework_info",
63+
"save_framework_info",
64+
"load_framework_info",
3465
]
3566

3667

@@ -53,14 +84,14 @@ class FrameworkInferenceProviderInfo(BaseModel):
5384
device: str = Field(
5485
title="device", description="The device the provider is for such as cpu or gpu."
5586
)
56-
supported_sparsification: SparsificationInfo = Field(
87+
supported_sparsification: Optional[SparsificationInfo] = Field(
88+
default=None,
5789
title="supported_sparsification",
5890
description=(
5991
"The supported sparsification information for support "
6092
"for inference speedup in the provider."
6193
),
6294
)
63-
6495
available: bool = Field(
6596
default=False,
6697
title="available",
@@ -97,21 +128,22 @@ class FrameworkInfo(BaseModel):
97128
"If the package is not detected, will be set to None."
98129
),
99130
)
100-
sparsification: SparsificationInfo = Field(
131+
sparsification: Optional[SparsificationInfo] = Field(
132+
default=None,
101133
title="sparsification",
102134
description=(
103135
"True if inference for a model is available on the system "
104136
"for the given framework, False otherwise."
105137
),
106138
)
107139
inference_providers: List[FrameworkInferenceProviderInfo] = Field(
140+
default=[],
108141
title="inference_providers",
109142
description=(
110143
"True if inference for a model is available on the system "
111144
"for the given framework, False otherwise."
112145
),
113146
)
114-
115147
properties: Dict[str, Any] = Field(
116148
default={},
117149
title="properties",
@@ -156,6 +188,7 @@ def framework_info(framework: Any) -> FrameworkInfo:
156188
Detect the information for the given ML framework such as package versions,
157189
availability for core actions such as training and inference,
158190
sparsification support, and inference provider support.
191+
159192
:param framework: The item to detect the ML framework for.
160193
See :func:`detect_framework` for more information.
161194
:type framework: Any
@@ -167,3 +200,100 @@ def framework_info(framework: Any) -> FrameworkInfo:
167200
_LOGGER.info("retrieved system info for framework %s: %s", framework, info)
168201

169202
return info
203+
204+
205+
def save_framework_info(framework: Any, path: Optional[str] = None):
206+
"""
207+
Save the framework info for a given framework.
208+
If path is provided, will save to a json file at that path.
209+
If path is not provided, will print out the info.
210+
211+
:param framework: The item to detect the ML framework for.
212+
See :func:`detect_framework` for more information.
213+
:type framework: Any
214+
:param path: The path, if any, to save the info to in json format.
215+
If not provided will print out the info.
216+
:type path: Optional[str]
217+
"""
218+
_LOGGER.debug(
219+
"saving framework info for framework %s to %s",
220+
framework,
221+
path if path else "sys.out",
222+
)
223+
info = (
224+
framework_info(framework)
225+
if not isinstance(framework, FrameworkInfo)
226+
else framework
227+
)
228+
229+
if path:
230+
path = clean_path(path)
231+
create_parent_dirs(path)
232+
233+
with open(path, "w") as file:
234+
file.write(info.json())
235+
236+
_LOGGER.info(
237+
"saved framework info for framework %s in file at %s", framework, path
238+
),
239+
else:
240+
print(info.json())
241+
_LOGGER.info("printed out framework info for framework %s", framework)
242+
243+
244+
def load_framework_info(load: str) -> FrameworkInfo:
245+
"""
246+
Load the framework info from a file or raw json.
247+
If load exists as a path, will read from the file and use that.
248+
Otherwise will try to parse the input as a raw json str.
249+
250+
:param load: Either a file path to a json file or a raw json string.
251+
:type load: str
252+
:return: The loaded framework info.
253+
:rtype: FrameworkInfo
254+
"""
255+
loaded_path = clean_path(load)
256+
257+
if os.path.exists(loaded_path):
258+
with open(loaded_path, "r") as file:
259+
load = file.read()
260+
261+
info = FrameworkInfo.parse_raw(load)
262+
263+
return info
264+
265+
266+
def _parse_args():
267+
parser = argparse.ArgumentParser(
268+
description=(
269+
"Compile the available setup and information for a given framework."
270+
)
271+
)
272+
parser.add_argument(
273+
"framework",
274+
type=str,
275+
help=(
276+
"the ML framework or path to a framework file to load the "
277+
"framework info for"
278+
),
279+
)
280+
parser.add_argument(
281+
"--path",
282+
type=str,
283+
default=None,
284+
help=(
285+
"A full file path to save the framework info to. "
286+
"If not supplied, will print out the framework info to the console."
287+
),
288+
)
289+
290+
return parser.parse_args()
291+
292+
293+
def _main():
294+
args = _parse_args()
295+
save_framework_info(args.framework, args.path)
296+
297+
298+
if __name__ == "__main__":
299+
_main()

src/sparseml/sparsification/info.py

Lines changed: 82 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,40 @@
1515
"""
1616
Functionality related to describing availability and information of sparsification
1717
algorithms to models within in the ML frameworks.
18+
19+
The file is executable and will get the sparsification info for a given framework:
20+
21+
##########
22+
Command help:
23+
usage: info.py [-h] [--path PATH] framework
24+
25+
Compile the available setup and information for the sparsification of a model
26+
in a given framework.
27+
28+
positional arguments:
29+
framework the ML framework or path to a framework file to load the
30+
sparsification info for
31+
32+
optional arguments:
33+
-h, --help show this help message and exit
34+
--path PATH A full file path to save the sparsification info to. If not
35+
supplied, will print out the sparsification info to the
36+
console.
37+
38+
#########
39+
EXAMPLES
40+
#########
41+
42+
##########
43+
Example command for getting the sparsification info for pytorch.
44+
python src/sparseml/sparsification/info.py pytorch
1845
"""
1946

20-
import os
47+
import argparse
2148
import logging
49+
import os
2250
from enum import Enum
2351
from typing import Any, List, Optional
24-
import argparse
2552

2653
from pydantic import BaseModel, Field
2754

@@ -35,6 +62,8 @@
3562
"ModifierInfo",
3663
"SparsificationInfo",
3764
"sparsification_info",
65+
"save_sparsification_info",
66+
"load_sparsification_info",
3867
]
3968

4069

@@ -160,7 +189,7 @@ def type_modifiers(self, type_: ModifierType) -> List[ModifierInfo]:
160189

161190
def sparsification_info(framework: Any) -> SparsificationInfo:
162191
"""
163-
Load the available setup for sparsifying model in the given framework.
192+
Get the available setup for sparsifying model in the given framework.
164193
165194
:param framework: The item to detect the ML framework for.
166195
See :func:`detect_framework` for more information.
@@ -178,24 +207,59 @@ def sparsification_info(framework: Any) -> SparsificationInfo:
178207

179208

180209
def save_sparsification_info(framework: Any, path: Optional[str] = None):
181-
_LOGGER.debug("saving sparsification info for framework %s to %s", framework, path if path else "sys.out")
182-
info = sparsification_info(framework)
210+
"""
211+
Save the sparsification info for a given framework.
212+
If path is provided, will save to a json file at that path.
213+
If path is not provided, will print out the info.
214+
215+
:param framework: The item to detect the ML framework for.
216+
See :func:`detect_framework` for more information.
217+
:type framework: Any
218+
:param path: The path, if any, to save the info to in json format.
219+
If not provided will print out the info.
220+
:type path: Optional[str]
221+
"""
222+
_LOGGER.debug(
223+
"saving sparsification info for framework %s to %s",
224+
framework,
225+
path if path else "sys.out",
226+
)
227+
info = (
228+
sparsification_info(framework)
229+
if not isinstance(framework, SparsificationInfo)
230+
else framework
231+
)
183232

184233
if path:
185234
path = clean_path(path)
186235
create_parent_dirs(path)
187236

188237
with open(path, "w") as file:
189238
file.write(info.json())
239+
240+
_LOGGER.info(
241+
"saved sparsification info for framework %s in file at %s", framework, path
242+
),
190243
else:
191244
print(info.json())
245+
_LOGGER.info("printed out sparsification info for framework %s", framework)
192246

193247

194248
def load_sparsification_info(load: str) -> SparsificationInfo:
195-
load = clean_path(load)
249+
"""
250+
Load the sparsification info from a file or raw json.
251+
If load exists as a path, will read from the file and use that.
252+
Otherwise will try to parse the input as a raw json str.
196253
197-
if os.path.exists(load):
198-
with open(load, "r") as file:
254+
:param load: Either a file path to a json file or a raw json string.
255+
:type load: str
256+
:return: The loaded sparsification info.
257+
:rtype: SparsificationInfo
258+
"""
259+
load_path = clean_path(load)
260+
261+
if os.path.exists(load_path):
262+
with open(load_path, "r") as file:
199263
load = file.read()
200264

201265
info = SparsificationInfo.parse_raw(load)
@@ -213,13 +277,20 @@ def _parse_args():
213277
parser.add_argument(
214278
"framework",
215279
type=str,
216-
required=True,
217280
help=(
218281
"the ML framework or path to a framework file to load the "
219282
"sparsification info for"
220-
)
283+
),
284+
)
285+
parser.add_argument(
286+
"--path",
287+
type=str,
288+
default=None,
289+
help=(
290+
"A full file path to save the sparsification info to. "
291+
"If not supplied, will print out the sparsification info to the console."
292+
),
221293
)
222-
parser.add_argument()
223294

224295
return parser.parse_args()
225296

@@ -231,4 +302,3 @@ def _main():
231302

232303
if __name__ == "__main__":
233304
_main()
234-
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments
 (0)