diff --git a/ptflops/pytorch_ops.py b/ptflops/pytorch_ops.py index d6e535a..7fe061a 100644 --- a/ptflops/pytorch_ops.py +++ b/ptflops/pytorch_ops.py @@ -6,6 +6,8 @@ * this file. If not visit https://opensource.org/licenses/MIT ''' +from functools import partial + import numpy as np import torch import torch.nn as nn @@ -55,12 +57,11 @@ def bn_flops_counter_hook(module, input, output): module.__flops__ += int(batch_flops) -def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0): +def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0, transpose=False): # Can have multiple inputs, getting the first one input = input[0] batch_size = input.shape[0] - output_dims = list(output.shape[2:]) kernel_dims = list(conv_module.kernel_size) in_channels = conv_module.in_channels @@ -71,7 +72,12 @@ def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops conv_per_position_flops = int(np.prod(kernel_dims, dtype=np.int64)) * \ (in_channels * filters_per_channel + extra_per_position_flops) - active_elements_count = batch_size * int(np.prod(output_dims, dtype=np.int64)) + if transpose: + input_dims = list(input.shape[2:]) + active_elements_count = batch_size * int(np.prod(input_dims, dtype=np.int64)) + else: + output_dims = list(output.shape[2:]) + active_elements_count = batch_size * int(np.prod(output_dims, dtype=np.int64)) overall_conv_flops = conv_per_position_flops * active_elements_count @@ -301,9 +307,9 @@ def timm_attention_counter_hook(attention_module, input, output): # Upscale nn.Upsample: upsample_flops_counter_hook, # Deconvolution - nn.ConvTranspose1d: conv_flops_counter_hook, - nn.ConvTranspose2d: conv_flops_counter_hook, - nn.ConvTranspose3d: conv_flops_counter_hook, + nn.ConvTranspose1d: partial(conv_flops_counter_hook, transpose=True), + nn.ConvTranspose2d: partial(conv_flops_counter_hook, transpose=True), + nn.ConvTranspose3d: partial(conv_flops_counter_hook, transpose=True), # RNN nn.RNN: rnn_flops_counter_hook, nn.GRU: rnn_flops_counter_hook,