1515"""
1616Functionality related to integrating with, detecting, and getting information for
1717support and sparsification in ML frameworks.
18+
19+ The file is executable and will get the framework info for a given framework:
20+
21+ ##########
22+ Command help:
23+ usage: info.py [-h] [--path PATH] framework
24+
25+ Compile the available setup and information for a given framework.
26+
27+ positional arguments:
28+ framework the ML framework or path to a framework file to load the
29+ framework info for
30+
31+ optional arguments:
32+ -h, --help show this help message and exit
33+ --path PATH A full file path to save the framework info to. If not
34+ supplied, will print out the framework info to the
35+ console.
36+
37+ #########
38+ EXAMPLES
39+ #########
40+
41+ ##########
42+ Example command for getting the framework info for pytorch.
43+ python src/sparseml/framework/info.py pytorch
1844"""
1945
46+ import argparse
2047import logging
48+ import os
2149from collections import OrderedDict
2250from typing import Any , Dict , List , Optional
2351
2452from pydantic import BaseModel , Field
2553
2654from sparseml .base import Framework , execute_in_sparseml_framework
2755from sparseml .sparsification .info import SparsificationInfo
56+ from sparseml .utils import clean_path , create_parent_dirs
2857
2958
3059__all__ = [
3160 "FrameworkInferenceProviderInfo" ,
3261 "FrameworkInfo" ,
3362 "framework_info" ,
63+ "save_framework_info" ,
64+ "load_framework_info" ,
3465]
3566
3667
@@ -53,14 +84,14 @@ class FrameworkInferenceProviderInfo(BaseModel):
5384 device : str = Field (
5485 title = "device" , description = "The device the provider is for such as cpu or gpu."
5586 )
56- supported_sparsification : SparsificationInfo = Field (
87+ supported_sparsification : Optional [SparsificationInfo ] = Field (
88+ default = None ,
5789 title = "supported_sparsification" ,
5890 description = (
5991 "The supported sparsification information for support "
6092 "for inference speedup in the provider."
6193 ),
6294 )
63-
6495 available : bool = Field (
6596 default = False ,
6697 title = "available" ,
@@ -97,21 +128,22 @@ class FrameworkInfo(BaseModel):
97128 "If the package is not detected, will be set to None."
98129 ),
99130 )
100- sparsification : SparsificationInfo = Field (
131+ sparsification : Optional [SparsificationInfo ] = Field (
132+ default = None ,
101133 title = "sparsification" ,
102134 description = (
103135 "True if inference for a model is available on the system "
104136 "for the given framework, False otherwise."
105137 ),
106138 )
107139 inference_providers : List [FrameworkInferenceProviderInfo ] = Field (
140+ default = [],
108141 title = "inference_providers" ,
109142 description = (
110143 "True if inference for a model is available on the system "
111144 "for the given framework, False otherwise."
112145 ),
113146 )
114-
115147 properties : Dict [str , Any ] = Field (
116148 default = {},
117149 title = "properties" ,
@@ -156,6 +188,7 @@ def framework_info(framework: Any) -> FrameworkInfo:
156188 Detect the information for the given ML framework such as package versions,
157189 availability for core actions such as training and inference,
158190 sparsification support, and inference provider support.
191+
159192 :param framework: The item to detect the ML framework for.
160193 See :func:`detect_framework` for more information.
161194 :type framework: Any
@@ -167,3 +200,100 @@ def framework_info(framework: Any) -> FrameworkInfo:
167200 _LOGGER .info ("retrieved system info for framework %s: %s" , framework , info )
168201
169202 return info
203+
204+
205+ def save_framework_info (framework : Any , path : Optional [str ] = None ):
206+ """
207+ Save the framework info for a given framework.
208+ If path is provided, will save to a json file at that path.
209+ If path is not provided, will print out the info.
210+
211+ :param framework: The item to detect the ML framework for.
212+ See :func:`detect_framework` for more information.
213+ :type framework: Any
214+ :param path: The path, if any, to save the info to in json format.
215+ If not provided will print out the info.
216+ :type path: Optional[str]
217+ """
218+ _LOGGER .debug (
219+ "saving framework info for framework %s to %s" ,
220+ framework ,
221+ path if path else "sys.out" ,
222+ )
223+ info = (
224+ framework_info (framework )
225+ if not isinstance (framework , FrameworkInfo )
226+ else framework
227+ )
228+
229+ if path :
230+ path = clean_path (path )
231+ create_parent_dirs (path )
232+
233+ with open (path , "w" ) as file :
234+ file .write (info .json ())
235+
236+ _LOGGER .info (
237+ "saved framework info for framework %s in file at %s" , framework , path
238+ ),
239+ else :
240+ print (info .json ())
241+ _LOGGER .info ("printed out framework info for framework %s" , framework )
242+
243+
244+ def load_framework_info (load : str ) -> FrameworkInfo :
245+ """
246+ Load the framework info from a file or raw json.
247+ If load exists as a path, will read from the file and use that.
248+ Otherwise will try to parse the input as a raw json str.
249+
250+ :param load: Either a file path to a json file or a raw json string.
251+ :type load: str
252+ :return: The loaded framework info.
253+ :rtype: FrameworkInfo
254+ """
255+ loaded_path = clean_path (load )
256+
257+ if os .path .exists (loaded_path ):
258+ with open (loaded_path , "r" ) as file :
259+ load = file .read ()
260+
261+ info = FrameworkInfo .parse_raw (load )
262+
263+ return info
264+
265+
266+ def _parse_args ():
267+ parser = argparse .ArgumentParser (
268+ description = (
269+ "Compile the available setup and information for a given framework."
270+ )
271+ )
272+ parser .add_argument (
273+ "framework" ,
274+ type = str ,
275+ help = (
276+ "the ML framework or path to a framework file to load the "
277+ "framework info for"
278+ ),
279+ )
280+ parser .add_argument (
281+ "--path" ,
282+ type = str ,
283+ default = None ,
284+ help = (
285+ "A full file path to save the framework info to. "
286+ "If not supplied, will print out the framework info to the console."
287+ ),
288+ )
289+
290+ return parser .parse_args ()
291+
292+
293+ def _main ():
294+ args = _parse_args ()
295+ save_framework_info (args .framework , args .path )
296+
297+
298+ if __name__ == "__main__" :
299+ _main ()
0 commit comments