Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 0860c20

Browse files
authored
sparsification oracle base classes (#400)
* Base Modifier classes and refactor for framework impls * moving base modifiers to sparsification package * sparsification oracle base classes * moved new classes to sparsification package, rebased to latest base-modifier * add in typing for Type[BaseModifier] + quality * make builder class setters chainable * update recipe builder set variable test * update modifier builder getter return type
1 parent 24db8ee commit 0860c20

File tree

7 files changed

+1233
-0
lines changed

7 files changed

+1233
-0
lines changed

src/sparseml/sparsification/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@
1919

2020
# flake8: noqa
2121

22+
from .analyzer import *
2223
from .info import *
24+
from .model_info import *
2325
from .modifier_epoch import *
2426
from .modifier_lr import *
2527
from .modifier_params import *
2628
from .modifier_pruning import *
29+
from .recipe_builder import *
30+
from .recipe_editor import *
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Code for running analysis on neural networks
17+
"""
18+
19+
from abc import ABC, abstractmethod
20+
21+
from sparseml.sparsification.model_info import ModelInfo, ModelResult
22+
23+
24+
__all__ = [
25+
"Analyzer",
26+
]
27+
28+
29+
class Analyzer(ABC):
30+
"""
31+
Base abstract class for model analyzers. Analyzers should be able to detect
32+
if given a ModelInfo object and other keyword inputs if they should run their
33+
analysis.
34+
35+
:param model_info: ModelInfo object of the model to be analyzed. after
36+
running this analysis, the analysis_results of this ModelInfo object
37+
will be updated
38+
"""
39+
40+
def __init__(self, model_info: ModelInfo):
41+
self._model_info = model_info
42+
self._result = self._initialize_result() # type: ModelResult
43+
44+
@staticmethod
45+
@abstractmethod
46+
def available(model_info: ModelInfo, **kwargs) -> bool:
47+
"""
48+
Abstract method that subclasses must implement to determine if
49+
given the model info and keyword arguments that the Analyzer can
50+
run its analysis
51+
52+
:param model_info: ModelInfo object of the model to be analyzed
53+
:param kwargs: additional keyword arguments that will be passed to the run
54+
function
55+
:return: True if given the inputs, this analyzer can run its analysis. False
56+
otherwise
57+
"""
58+
raise NotImplementedError()
59+
60+
def run(self, **kwargs):
61+
self._run(**kwargs)
62+
self._model_info.add_analysis_result(self._result)
63+
64+
@abstractmethod
65+
def _initialize_result(self) -> ModelResult:
66+
# sets the initial ModelResult object for this analysis
67+
# such as analysis_type, layer selection, and result value initialization
68+
raise NotImplementedError()
69+
70+
@abstractmethod
71+
def _run(self, **kwargs):
72+
# runs the analysis and updates self._result
73+
raise NotImplementedError()

0 commit comments

Comments
 (0)