Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 3f165e6

Browse files
authored
migrate LayerInfo, Result, and ModelResult to pydantic BaseModel (#405) (#411)
* migrate LayerInfo and analysis result classes to pydantic
* bump pydantic min version to 1.5.0 (default_factory support)
* removing boilerplate JSON serialization fns for results and layer info in favor of pydantic built-ins
1 parent b0b4696 commit 3f165e6

File tree

3 files changed

+102
-150
lines changed

3 files changed

+102
-150
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"pandas>=0.25.0",
4444
"packaging>=20.0",
4545
"psutil>=5.0.0",
46-
"pydantic>=1.0.0",
46+
"pydantic>=1.5.0",
4747
"requests>=2.0.0",
4848
"scikit-image>=0.15.0",
4949
"scipy>=1.0.0",

src/sparseml/sparsification/model_info.py

Lines changed: 78 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Any, Dict, List, Optional, Union
2525

2626
import numpy
27+
from pydantic import BaseModel, Field, root_validator
2728

2829
from sparseml.utils import clean_path, create_parent_dirs
2930

@@ -36,49 +37,63 @@
3637
]
3738

3839

39-
class LayerInfo(object):
40+
class LayerInfo(BaseModel):
4041
"""
4142
Class for storing properties about a layer in a model
42-
43-
:param name: unique name of the layer within its model
44-
:param op_type: type of layer, i.e. "conv", "linear"
45-
:param params: number of non-bias parameters in the layer. must be
46-
included for prunable layers
47-
:param bias_params: number of bias parameters in the layer
48-
:param prunable: True if the layers non-bias parameters can be pruned.
49-
Default is False
50-
:param flops: optional number of float operations within the layer
51-
:param execution_order: optional execution order of the layer within the
52-
model. Default is -1
53-
:param attributes: optional dictionary of string attribute names to their
54-
values
5543
"""
5644

57-
def __init__(
58-
self,
59-
name: str,
60-
op_type: str,
61-
params: Optional[int] = None,
62-
bias_params: Optional[int] = None,
63-
prunable: bool = False,
64-
flops: Optional[int] = None,
65-
execution_order: int = -1,
66-
attributes: Optional[Dict[str, Any]] = None,
67-
):
45+
name: str = Field(
46+
title="name",
47+
description="unique name of the layer within its model",
48+
)
49+
op_type: str = Field(
50+
title="op_type",
51+
description="type of layer, i.e. 'conv', 'linear'",
52+
)
53+
params: Optional[int] = Field(
54+
title="params",
55+
default=None,
56+
description=(
57+
"number of non-bias parameters in the layer. must be included "
58+
"for prunable layers"
59+
),
60+
)
61+
bias_params: Optional[int] = Field(
62+
title="bias_params",
63+
default=None,
64+
description="number of bias parameters in the layer",
65+
)
66+
prunable: bool = Field(
67+
title="prunable",
68+
default=False,
69+
description="True if the layers non-bias parameters can be pruned",
70+
)
71+
flops: Optional[int] = Field(
72+
title="flops",
73+
default=None,
74+
description="number of float operations within the layer",
75+
)
76+
execution_order: int = Field(
77+
title="execution_order",
78+
default=-1,
79+
description="execution order of the layer within the model",
80+
)
81+
attributes: Optional[Dict[str, Any]] = Field(
82+
title="attributes",
83+
default=None,
84+
description="dictionary of string attribute names to their values",
85+
)
86+
87+
@root_validator(pre=True)
88+
def check_params_if_prunable(_, values):
89+
prunable = values.get("prunable")
90+
params = values.get("params")
6891
if prunable and not params:
6992
raise ValueError(
7093
f"Prunable layers must have non 0 number of params given {params} "
71-
f"for layer {name} with prunable set to {prunable}"
94+
f"for layer {values.get('name')} with prunable set to {prunable}"
7295
)
73-
74-
self.name = name
75-
self.op_type = op_type
76-
self.params = params
77-
self.bias_params = bias_params
78-
self.prunable = prunable
79-
self.flops = flops
80-
self.execution_order = execution_order
81-
self.attributes = attributes or {}
96+
return values
8297

8398
@classmethod
8499
def linear_layer(
@@ -159,112 +174,41 @@ def conv_layer(
159174
**kwargs, # TODO: add FLOPS calculation
160175
)
161176

162-
@classmethod
163-
def from_dict(cls, dictionary: Dict[str, Any]):
164-
"""
165-
:param dictionary: dict serialized by LyaerInfo.from_dict
166-
:return: LayerInfo object created from the given dict
167-
"""
168-
dictionary = deepcopy(dictionary)
169-
return cls(**dictionary)
170177

171-
def to_dict(self) -> Dict[str, Any]:
172-
"""
173-
:return: dict representation of this LayerInfo parameters
174-
"""
175-
props = {
176-
"name": self.name,
177-
"op_type": self.op_type,
178-
"prunable": self.prunable,
179-
"execution_order": self.execution_order,
180-
"attributes": self.attributes,
181-
}
182-
if self.params is not None:
183-
props["params"] = self.params
184-
if self.bias_params is not None:
185-
props["bias_params"] = self.bias_params
186-
if self.flops is not None:
187-
props["flops"] = self.flops
188-
return props
189-
190-
191-
class Result(object):
178+
class Result(BaseModel):
192179
"""
193180
Base class for storing the results of an analysis
194-
195-
:param value: initial value of the result. Defaults to None
196-
:param attributes: dict of attributes of this result. Defaults to empty
197181
"""
198182

199-
def __init__(self, value: Any = None, attributes: Optional[Dict[str, Any]] = None):
200-
self.value = value
201-
self.attributes = attributes or {}
202-
203-
@classmethod
204-
def from_dict(cls, dictionary: Dict[str, Any]):
205-
"""
206-
:param dictionary: dict serialized by Result.from_dict
207-
:return: Result object created from the given dict
208-
"""
209-
dictionary = deepcopy(dictionary)
210-
return cls(**dictionary)
211-
212-
def to_dict(self) -> Dict[str, Any]:
213-
"""
214-
:return: dict representation of this Result
215-
"""
216-
return {"value": self.value, "attributes": self.attributes}
183+
value: Any = Field(
184+
title="value",
185+
default=None,
186+
description="initial value of the result",
187+
)
188+
attributes: Optional[Dict[str, Any]] = Field(
189+
title="attributes",
190+
default=None,
191+
description="dict of attributes of this result",
192+
)
217193

218194

219195
class ModelResult(Result):
220196
"""
221197
Class for storing the results of an analysis for an entire model
222-
223-
:param analysis_type: name of the type of analysis that was performed
224-
:param value: initial value of the result. Defaults to None
225-
:param layer_results: dict of layer results to initialize for this model.
226-
Defaults to empty dict
227-
:param attributes: dict of attributes of this result. Defaults to empty
228198
"""
229199

230-
def __init__(
231-
self,
232-
analysis_type: str,
233-
value: Any = None,
234-
layer_results: Dict[str, Result] = None,
235-
attributes: Optional[Dict[str, Any]] = None,
236-
):
237-
super().__init__(value=value, attributes=attributes)
238-
239-
self.analysis_type = analysis_type
240-
self.layer_results = layer_results or {}
241-
242-
@classmethod
243-
def from_dict(cls, dictionary: Dict[str, Any]):
244-
"""
245-
:param dictionary: dict serialized by ModelResult.from_dict
246-
:return: ModelResult object created from the given dict
247-
"""
248-
dictionary = deepcopy(dictionary)
249-
dictionary["layer_results"] = dictionary.get("layer_results", {})
250-
dictionary["layer_results"] = {
251-
layer_name: Result.from_dict(layer_result)
252-
for layer_name, layer_result in dictionary["layer_results"].items()
253-
}
254-
return cls(**dictionary)
255-
256-
def to_dict(self) -> Dict[str, Any]:
257-
"""
258-
:return: dict representation of this ModelResult
259-
"""
260-
dictionary = super().to_dict()
261-
dictionary["analysis_type"] = self.analysis_type
262-
dictionary["layer_results"] = {
263-
layer_name: layer_result.to_dict()
264-
for layer_name, layer_result in self.layer_results.items()
265-
}
266-
267-
return dictionary
200+
analysis_type: str = Field(
201+
title="analysis_type",
202+
description="name of the type of analysis that was performed",
203+
)
204+
layer_results: Dict[str, Result] = Field(
205+
title="layer_results",
206+
default_factory=dict,
207+
description=(
208+
"dict of layer results to initialize for this analysis. should map "
209+
"layer name to Result object"
210+
),
211+
)
268212

269213

270214
class ModelInfo(ABC):
@@ -289,7 +233,7 @@ def __init__(self, model: Any, metadata: Dict[str, Any]):
289233
@classmethod
290234
def from_dict(cls, dictionary: Dict[str, Any]):
291235
"""
292-
:param dictionary: dict serialized by ModelInfo.from_dict
236+
:param dictionary: dict serialized by `dict(ModelInfo(...))`
293237
:return: ModelInfo object created from the given dict
294238
"""
295239
dictionary = deepcopy(dictionary)
@@ -298,15 +242,15 @@ def from_dict(cls, dictionary: Dict[str, Any]):
298242
"ModelInfo objects serialized as a dict must include a 'layer_info' key"
299243
)
300244
layer_info = {
301-
name: LayerInfo.from_dict(info)
245+
name: LayerInfo.parse_obj(info)
302246
for name, info in dictionary["layer_info"].items()
303247
}
304248

305249
model_info = cls(layer_info, metadata=dictionary.get("metadata", {}))
306250

307251
results = dictionary.get("analysis_results", [])
308252
for result in results:
309-
model_result = ModelResult.from_dict(result)
253+
model_result = ModelResult.parse_obj(result)
310254
model_info.add_analysis_result(model_result)
311255

312256
return model_info
@@ -368,8 +312,8 @@ def to_dict(self) -> Dict[str, Any]:
368312
"""
369313
:return: dict representation of this ModelResult
370314
"""
371-
layer_info = {name: info.to_dict() for name, info in self._layer_info.items()}
372-
analysis_results = [result.to_dict() for result in self._analysis_results]
315+
layer_info = {name: dict(info) for name, info in self._layer_info.items()}
316+
analysis_results = [dict(result) for result in self._analysis_results]
373317
return {
374318
"metadata": self.metadata,
375319
"layer_info": layer_info,

tests/sparseml/sparsification/test_model_info.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@ def _test_layer_info_eq(layer_one, layer_two):
3232
"layer_info,expected_dict",
3333
[
3434
(
35-
LayerInfo("layers.1", "TestLayer", attributes={"val": 1}),
35+
LayerInfo(name="layers.1", op_type="TestLayer", attributes={"val": 1}),
3636
{
3737
"name": "layers.1",
3838
"op_type": "TestLayer",
3939
"prunable": False,
4040
"execution_order": -1,
4141
"attributes": {"val": 1},
42+
"flops": None,
43+
"bias_params": None,
44+
"params": None,
4245
},
4346
),
4447
(
@@ -51,6 +54,7 @@ def _test_layer_info_eq(layer_one, layer_two):
5154
"prunable": True,
5255
"execution_order": -1,
5356
"attributes": {"in_channels": 64, "out_channels": 128},
57+
"flops": None,
5458
},
5559
),
5660
(
@@ -69,6 +73,8 @@ def _test_layer_info_eq(layer_one, layer_two):
6973
"stride": 1,
7074
"padding": [0, 0, 0, 0],
7175
},
76+
"bias_params": None,
77+
"flops": None,
7278
},
7379
),
7480
(
@@ -89,14 +95,16 @@ def _test_layer_info_eq(layer_one, layer_two):
8995
"stride": 1,
9096
"padding": [0, 0, 0, 0],
9197
},
98+
"bias_params": None,
99+
"flops": None,
92100
},
93101
),
94102
],
95103
)
96104
def test_layer_info_serialization(layer_info, expected_dict):
97-
layer_info_dict = layer_info.to_dict()
98-
expected_dict_loaded = LayerInfo.from_dict(expected_dict)
99-
layer_info_dict_reloaded = LayerInfo.from_dict(layer_info_dict)
105+
layer_info_dict = dict(layer_info)
106+
expected_dict_loaded = LayerInfo.parse_obj(expected_dict)
107+
layer_info_dict_reloaded = LayerInfo.parse_obj(layer_info_dict)
100108

101109
assert type(expected_dict_loaded) is LayerInfo
102110
assert type(layer_info_dict_reloaded) is LayerInfo
@@ -128,20 +136,20 @@ def _test_model_result_eq(result_one, result_two):
128136
"model_result,expected_dict",
129137
[
130138
(
131-
ModelResult("lr_sensitivity", value={0.1: 100, 0.2: 50}),
139+
ModelResult(analysis_type="lr_sensitivity", value={0.1: 100, 0.2: 50}),
132140
{
133141
"analysis_type": "lr_sensitivity",
134142
"value": {0.1: 100, 0.2: 50},
135143
"layer_results": {},
136-
"attributes": {},
144+
"attributes": None,
137145
},
138146
),
139147
(
140148
ModelResult(
141-
"pruning_sensitivity",
149+
analysis_type="pruning_sensitivity",
142150
layer_results={
143-
"net.1": Result({0.0: 0.25, 0.6: 0.2, 0.8: 0.1}),
144-
"net.2": Result({0.0: 0.2, 0.6: 0.2, 0.8: 0.2}),
151+
"net.1": Result(value={0.0: 0.25, 0.6: 0.2, 0.8: 0.1}),
152+
"net.2": Result(value={0.0: 0.2, 0.6: 0.2, 0.8: 0.2}),
145153
},
146154
),
147155
{
@@ -150,22 +158,22 @@ def _test_model_result_eq(result_one, result_two):
150158
"layer_results": {
151159
"net.1": {
152160
"value": {0.0: 0.25, 0.6: 0.2, 0.8: 0.1},
153-
"attributes": {},
161+
"attributes": None,
154162
},
155163
"net.2": {
156164
"value": {0.0: 0.2, 0.6: 0.2, 0.8: 0.2},
157-
"attributes": {},
165+
"attributes": None,
158166
},
159167
},
160-
"attributes": {},
168+
"attributes": None,
161169
},
162170
),
163171
],
164172
)
165173
def test_model_result_serialization(model_result, expected_dict):
166-
model_result_dict = model_result.to_dict()
167-
expected_dict_loaded = ModelResult.from_dict(expected_dict)
168-
model_result_dict_reloaded = ModelResult.from_dict(model_result_dict)
174+
model_result_dict = dict(model_result)
175+
expected_dict_loaded = ModelResult.parse_obj(expected_dict)
176+
model_result_dict_reloaded = ModelResult.parse_obj(model_result_dict)
169177

170178
assert type(expected_dict_loaded) is ModelResult
171179
assert type(model_result_dict_reloaded) is ModelResult

0 commit comments

Comments
 (0)