@@ -347,15 +347,17 @@ def _interpolate_functional_flops_hook(*args, **kwargs):
347347 if input is None and len (args ) > 0 :
348348 input = args [0 ]
349349
350+ assert input .dim () - 2 > 0 , "Input of interpolate should have NC... layout"
351+
350352 size = kwargs .get ('size' , None )
351353 if size is None and len (args ) > 1 :
352354 size = args [1 ]
353355
354356 if size is not None :
355357 if isinstance (size , tuple ) or isinstance (size , list ):
356- return int (np .prod (size , dtype = np .int64 ))
358+ return int (np .prod (size , dtype = np .int64 ))* np . prod ( input . shape [: 2 ], dtype = np . int64 )
357359 else :
358- return int (size )
360+ return int (size ) ** ( input . dim () - 2 ) * np . prod ( input . shape [: 2 ], dtype = np . int64 )
359361
360362 scale_factor = kwargs .get ('scale_factor' , None )
361363 if scale_factor is None and len (args ) > 2 :
@@ -364,10 +366,10 @@ def _interpolate_functional_flops_hook(*args, **kwargs):
364366 "should be passes to interpolate"
365367
366368 flops = input .numel ()
367- if isinstance (scale_factor , tuple ) and len (scale_factor ) == len (input ) :
369+ if isinstance (scale_factor , tuple ) and len (scale_factor ) == len (input . shape ) - 2 :
368370 flops *= int (np .prod (scale_factor , dtype = np .int64 ))
369371 else :
370- flops *= scale_factor ** len (input )
372+ flops *= scale_factor ** (input . dim () - 2 ) # NC... layout is assumed, see interpolate docs
371373
372374 return flops
373375
0 commit comments