|
10 | 10 | import sys |
11 | 11 | import traceback |
12 | 12 | from collections import defaultdict |
| 13 | +from copy import deepcopy |
13 | 14 | from functools import partial |
14 | 15 | from typing import Optional, Tuple, Union |
15 | 16 |
|
|
23 | 24 |
|
24 | 25 | class FlopCounterMode(TorchDispatchMode): |
25 | 26 | def __init__(self, module=None, verbose=False, print_per_layer_stat=False, |
26 | | - output_params=None): |
| 27 | + output_params=None, custom_hooks={}, ignored_ops=[]): |
27 | 28 | self.verbose = verbose |
28 | 29 | if output_params is None: |
29 | 30 | output_params = defaultdict(dict) |
30 | 31 | self.output_params = output_params |
31 | 32 | self.print_fn = partial(print, **self.output_params['print_params']) |
| 33 | + self.all_ops = deepcopy(ATEN_OPS_MAPPING) |
| 34 | + self.all_ops.update(custom_hooks) |
| 35 | + self.ignored_ops = ignored_ops |
32 | 36 |
|
33 | 37 | self.print_per_layer_stat = print_per_layer_stat |
34 | 38 | self.flop_counts = defaultdict(lambda: defaultdict(int)) |
@@ -82,8 +86,11 @@ def normalize_tuple(x): |
82 | 86 |
|
83 | 87 | out = func(*args, **kwargs) |
84 | 88 | func_packet = func._overloadpacket |
85 | | - if func_packet in ATEN_OPS_MAPPING: |
86 | | - flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out)) |
| 89 | + |
| 90 | + if func_packet in self.ignored_ops: |
| 91 | + self.print_fn(f'Warning: {func_packet} operation is ignored') |
| 92 | + elif func_packet in self.all_ops: |
| 93 | + flop_count = self.all_ops[func_packet](args, normalize_tuple(out)) |
87 | 94 | for par in self.parents: |
88 | 95 | self.flop_counts[par][func_packet] += flop_count |
89 | 96 | elif self.verbose: |
@@ -119,7 +126,8 @@ def get_flops_aten(model, input_res, |
119 | 126 | batch = torch.ones(()).new_empty((1, *input_res)) |
120 | 127 |
|
121 | 128 | try: |
122 | | - counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params) |
| 129 | + counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params, |
| 130 | + custom_modules_hooks, ignore_modules) |
123 | 131 | with counter: |
124 | 132 | if isinstance(batch, dict): |
125 | 133 | _ = model(**batch) |
|
0 commit comments