2424from typing import Any , Dict , List , Optional , Union
2525
2626import numpy
27+ from pydantic import BaseModel , Field , root_validator
2728
2829from sparseml .utils import clean_path , create_parent_dirs
2930
3637]
3738
3839
39- class LayerInfo (object ):
40+ class LayerInfo (BaseModel ):
4041 """
4142 Class for storing properties about a layer in a model
42-
43- :param name: unique name of the layer within its model
44- :param op_type: type of layer, i.e. "conv", "linear"
45- :param params: number of non-bias parameters in the layer. must be
46- included for prunable layers
47- :param bias_params: number of bias parameters in the layer
48- :param prunable: True if the layers non-bias parameters can be pruned.
49- Default is False
50- :param flops: optional number of float operations within the layer
51- :param execution_order: optional execution order of the layer within the
52- model. Default is -1
53- :param attributes: optional dictionary of string attribute names to their
54- values
5543 """
5644
57- def __init__ (
58- self ,
59- name : str ,
60- op_type : str ,
61- params : Optional [int ] = None ,
62- bias_params : Optional [int ] = None ,
63- prunable : bool = False ,
64- flops : Optional [int ] = None ,
65- execution_order : int = - 1 ,
66- attributes : Optional [Dict [str , Any ]] = None ,
67- ):
45+ name : str = Field (
46+ title = "name" ,
47+ description = "unique name of the layer within its model" ,
48+ )
49+ op_type : str = Field (
50+ title = "op_type" ,
51+ description = "type of layer, i.e. 'conv', 'linear'" ,
52+ )
53+ params : Optional [int ] = Field (
54+ title = "params" ,
55+ default = None ,
56+ description = (
57+ "number of non-bias parameters in the layer. must be included "
58+ "for prunable layers"
59+ ),
60+ )
61+ bias_params : Optional [int ] = Field (
62+ title = "bias_params" ,
63+ default = None ,
64+ description = "number of bias parameters in the layer" ,
65+ )
66+ prunable : bool = Field (
67+ title = "prunable" ,
68+ default = False ,
69+ description = "True if the layers non-bias parameters can be pruned" ,
70+ )
71+ flops : Optional [int ] = Field (
72+ title = "flops" ,
73+ default = None ,
74+ description = "number of float operations within the layer" ,
75+ )
76+ execution_order : int = Field (
77+ title = "execution_order" ,
78+ default = - 1 ,
79+ description = "execution order of the layer within the model" ,
80+ )
81+ attributes : Optional [Dict [str , Any ]] = Field (
82+ title = "attributes" ,
83+ default = None ,
84+ description = "dictionary of string attribute names to their values" ,
85+ )
86+
87+ @root_validator (pre = True )
88+ def check_params_if_prunable (_ , values ):
89+ prunable = values .get ("prunable" )
90+ params = values .get ("params" )
6891 if prunable and not params :
6992 raise ValueError (
7093 f"Prunable layers must have non 0 number of params given { params } "
71- f"for layer { name } with prunable set to { prunable } "
94+ f"for layer { values . get ( ' name' ) } with prunable set to { prunable } "
7295 )
73-
74- self .name = name
75- self .op_type = op_type
76- self .params = params
77- self .bias_params = bias_params
78- self .prunable = prunable
79- self .flops = flops
80- self .execution_order = execution_order
81- self .attributes = attributes or {}
96+ return values
8297
8398 @classmethod
8499 def linear_layer (
@@ -159,112 +174,41 @@ def conv_layer(
159174 ** kwargs , # TODO: add FLOPS calculation
160175 )
161176
162- @classmethod
163- def from_dict (cls , dictionary : Dict [str , Any ]):
164- """
165- :param dictionary: dict serialized by LyaerInfo.from_dict
166- :return: LayerInfo object created from the given dict
167- """
168- dictionary = deepcopy (dictionary )
169- return cls (** dictionary )
170177
171- def to_dict (self ) -> Dict [str , Any ]:
172- """
173- :return: dict representation of this LayerInfo parameters
174- """
175- props = {
176- "name" : self .name ,
177- "op_type" : self .op_type ,
178- "prunable" : self .prunable ,
179- "execution_order" : self .execution_order ,
180- "attributes" : self .attributes ,
181- }
182- if self .params is not None :
183- props ["params" ] = self .params
184- if self .bias_params is not None :
185- props ["bias_params" ] = self .bias_params
186- if self .flops is not None :
187- props ["flops" ] = self .flops
188- return props
189-
190-
191- class Result (object ):
178+ class Result (BaseModel ):
192179 """
193180 Base class for storing the results of an analysis
194-
195- :param value: initial value of the result. Defaults to None
196- :param attributes: dict of attributes of this result. Defaults to empty
197181 """
198182
199- def __init__ (self , value : Any = None , attributes : Optional [Dict [str , Any ]] = None ):
200- self .value = value
201- self .attributes = attributes or {}
202-
203- @classmethod
204- def from_dict (cls , dictionary : Dict [str , Any ]):
205- """
206- :param dictionary: dict serialized by Result.from_dict
207- :return: Result object created from the given dict
208- """
209- dictionary = deepcopy (dictionary )
210- return cls (** dictionary )
211-
212- def to_dict (self ) -> Dict [str , Any ]:
213- """
214- :return: dict representation of this Result
215- """
216- return {"value" : self .value , "attributes" : self .attributes }
183+ value : Any = Field (
184+ title = "value" ,
185+ default = None ,
186+ description = "initial value of the result" ,
187+ )
188+ attributes : Optional [Dict [str , Any ]] = Field (
189+ title = "attributes" ,
190+ default = None ,
191+ description = "dict of attributes of this result" ,
192+ )
217193
218194
219195class ModelResult (Result ):
220196 """
221197 Class for storing the results of an analysis for an entire model
222-
223- :param analysis_type: name of the type of analysis that was performed
224- :param value: initial value of the result. Defaults to None
225- :param layer_results: dict of layer results to initialize for this model.
226- Defaults to empty dict
227- :param attributes: dict of attributes of this result. Defaults to empty
228198 """
229199
230- def __init__ (
231- self ,
232- analysis_type : str ,
233- value : Any = None ,
234- layer_results : Dict [str , Result ] = None ,
235- attributes : Optional [Dict [str , Any ]] = None ,
236- ):
237- super ().__init__ (value = value , attributes = attributes )
238-
239- self .analysis_type = analysis_type
240- self .layer_results = layer_results or {}
241-
242- @classmethod
243- def from_dict (cls , dictionary : Dict [str , Any ]):
244- """
245- :param dictionary: dict serialized by ModelResult.from_dict
246- :return: ModelResult object created from the given dict
247- """
248- dictionary = deepcopy (dictionary )
249- dictionary ["layer_results" ] = dictionary .get ("layer_results" , {})
250- dictionary ["layer_results" ] = {
251- layer_name : Result .from_dict (layer_result )
252- for layer_name , layer_result in dictionary ["layer_results" ].items ()
253- }
254- return cls (** dictionary )
255-
256- def to_dict (self ) -> Dict [str , Any ]:
257- """
258- :return: dict representation of this ModelResult
259- """
260- dictionary = super ().to_dict ()
261- dictionary ["analysis_type" ] = self .analysis_type
262- dictionary ["layer_results" ] = {
263- layer_name : layer_result .to_dict ()
264- for layer_name , layer_result in self .layer_results .items ()
265- }
266-
267- return dictionary
200+ analysis_type : str = Field (
201+ title = "analysis_type" ,
202+ description = "name of the type of analysis that was performed" ,
203+ )
204+ layer_results : Dict [str , Result ] = Field (
205+ title = "layer_results" ,
206+ default_factory = dict ,
207+ description = (
208+ "dict of layer results to initialize for this analysis. should map "
209+ "layer name to Result object"
210+ ),
211+ )
268212
269213
270214class ModelInfo (ABC ):
@@ -289,7 +233,7 @@ def __init__(self, model: Any, metadata: Dict[str, Any]):
289233 @classmethod
290234 def from_dict (cls , dictionary : Dict [str , Any ]):
291235 """
292- :param dictionary: dict serialized by ModelInfo.from_dict
236+ :param dictionary: dict serialized by `dict( ModelInfo(...))`
293237 :return: ModelInfo object created from the given dict
294238 """
295239 dictionary = deepcopy (dictionary )
@@ -298,15 +242,15 @@ def from_dict(cls, dictionary: Dict[str, Any]):
298242 "ModelInfo objects serialized as a dict must include a 'layer_info' key"
299243 )
300244 layer_info = {
301- name : LayerInfo .from_dict (info )
245+ name : LayerInfo .parse_obj (info )
302246 for name , info in dictionary ["layer_info" ].items ()
303247 }
304248
305249 model_info = cls (layer_info , metadata = dictionary .get ("metadata" , {}))
306250
307251 results = dictionary .get ("analysis_results" , [])
308252 for result in results :
309- model_result = ModelResult .from_dict (result )
253+ model_result = ModelResult .parse_obj (result )
310254 model_info .add_analysis_result (model_result )
311255
312256 return model_info
@@ -368,8 +312,8 @@ def to_dict(self) -> Dict[str, Any]:
368312 """
369313 :return: dict representation of this ModelResult
370314 """
371- layer_info = {name : info . to_dict ( ) for name , info in self ._layer_info .items ()}
372- analysis_results = [result . to_dict ( ) for result in self ._analysis_results ]
315+ layer_info = {name : dict ( info ) for name , info in self ._layer_info .items ()}
316+ analysis_results = [dict ( result ) for result in self ._analysis_results ]
373317 return {
374318 "metadata" : self .metadata ,
375319 "layer_info" : layer_info ,
0 commit comments