Skip to content

Commit f7506d6

Browse files
committed
Add ignore and custom modules for aten backend
1 parent cc05220 commit f7506d6

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

ptflops/aten_engine.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sys
1111
import traceback
1212
from collections import defaultdict
13+
from copy import deepcopy
1314
from functools import partial
1415
from typing import Optional, Tuple, Union
1516

@@ -23,12 +24,15 @@
2324

2425
class FlopCounterMode(TorchDispatchMode):
2526
def __init__(self, module=None, verbose=False, print_per_layer_stat=False,
26-
output_params=None):
27+
output_params=None, custom_hooks={}, ignored_ops=[]):
2728
self.verbose = verbose
2829
if output_params is None:
2930
output_params = defaultdict(dict)
3031
self.output_params = output_params
3132
self.print_fn = partial(print, **self.output_params['print_params'])
33+
self.all_ops = deepcopy(ATEN_OPS_MAPPING)
34+
self.all_ops.update(custom_hooks)
35+
self.ignored_ops = ignored_ops
3236

3337
self.print_per_layer_stat = print_per_layer_stat
3438
self.flop_counts = defaultdict(lambda: defaultdict(int))
@@ -82,8 +86,11 @@ def normalize_tuple(x):
8286

8387
out = func(*args, **kwargs)
8488
func_packet = func._overloadpacket
85-
if func_packet in ATEN_OPS_MAPPING:
86-
flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))
89+
90+
if func_packet in self.ignored_ops:
91+
self.print_fn(f'Warning: {func_packet} operation is ignored')
92+
elif func_packet in self.all_ops:
93+
flop_count = self.all_ops[func_packet](args, normalize_tuple(out))
8794
for par in self.parents:
8895
self.flop_counts[par][func_packet] += flop_count
8996
elif self.verbose:
@@ -119,7 +126,8 @@ def get_flops_aten(model, input_res,
119126
batch = torch.ones(()).new_empty((1, *input_res))
120127

121128
try:
122-
counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params)
129+
counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params,
130+
custom_modules_hooks, ignore_modules)
123131
with counter:
124132
if isinstance(batch, dict):
125133
_ = model(**batch)

ptflops/flops_counter.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def get_model_complexity_info(model: nn.Module,
2929
input_constructor: Optional[Callable[[Tuple], Dict]] = None,
3030
ost: TextIO = sys.stdout,
3131
verbose: bool = False,
32-
ignore_modules: List[nn.Module] = [],
33-
custom_modules_hooks: Dict[nn.Module, Any] = {},
32+
ignore_modules: List[Union[nn.Module, Any]] = [],
33+
custom_modules_hooks: Dict[Union[nn.Module, Any], Any] = {},
3434
backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.ATEN,
3535
flops_units: Optional[str] = None,
3636
param_units: Optional[str] = None,
@@ -61,10 +61,11 @@ def get_model_complexity_info(model: nn.Module,
6161
:type ost: TextIO
6262
:param verbose: Parameter to control printing of extra information and warnings.
6363
:type verbose: bool
64-
:param ignore_modules: A list of torch.nn.Module modules to ignore.
65-
:type ignore_modules: nn.Module
66-
:param custom_modules_hooks: A dict that contains custom hooks on torch modules.
67-
:type custom_modules_hooks: Dict[nn.Module, Any]
64+
:param ignore_modules: A list of torch.nn.Module or torch.ops.aten modules to ignore.
65+
:type ignore_modules: List[Union[nn.Module, Any]]
66+
:param custom_modules_hooks: A dict that contains custom hooks for torch.nn.Module or
67+
torch.ops.aten modules.
68+
:type custom_modules_hooks: Dict[Union[nn.Module, Any], Any]
6869
:param backend: Backend that used for evaluating model complexity.
6970
:type backend: FLOPS_BACKEND
7071
:param flops_units: Units for string representation of MACs (GMac, MMac or KMac).

0 commit comments

Comments
 (0)