Skip to content

Commit 6d7445c

Browse files
committed
Fix interpolate hook
1 parent 307e6c3 commit 6d7445c

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

ptflops/pytorch_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,15 +347,17 @@ def _interpolate_functional_flops_hook(*args, **kwargs):
347347
if input is None and len(args) > 0:
348348
input = args[0]
349349

350+
assert input.dim() - 2 > 0, "Input of interpolate should have NC... layout"
351+
350352
size = kwargs.get('size', None)
351353
if size is None and len(args) > 1:
352354
size = args[1]
353355

354356
if size is not None:
355357
if isinstance(size, tuple) or isinstance(size, list):
356-
return int(np.prod(size, dtype=np.int64))
358+
return int(np.prod(size, dtype=np.int64))*np.prod(input.shape[:2], dtype=np.int64)
357359
else:
358-
return int(size)
360+
return int(size) ** (input.dim() - 2) * np.prod(input.shape[:2], dtype=np.int64)
359361

360362
scale_factor = kwargs.get('scale_factor', None)
361363
if scale_factor is None and len(args) > 2:
@@ -364,10 +366,10 @@ def _interpolate_functional_flops_hook(*args, **kwargs):
364366
"should be passes to interpolate"
365367

366368
flops = input.numel()
367-
if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
369+
if isinstance(scale_factor, tuple) and len(scale_factor) == len(input.shape) - 2:
368370
flops *= int(np.prod(scale_factor, dtype=np.int64))
369371
else:
370-
flops *= scale_factor**len(input)
372+
flops *= scale_factor ** (input.dim() - 2) # NC... layout is assumed, see interpolate docs
371373

372374
return flops
373375

0 commit comments

Comments
 (0)