diff --git a/calflops/flops_counter.py b/calflops/flops_counter.py index ba216f3..1683797 100644 --- a/calflops/flops_counter.py +++ b/calflops/flops_counter.py @@ -36,7 +36,10 @@ def calculate_flops(model, output_precision=2, output_unit=None, ignore_modules=None, - is_sparse=False): + is_sparse=False, + return_output=False, + assume_model_on_device=False, + ): """Returns the total floating-point operations, MACs, and parameters of a model. Args: @@ -55,6 +58,8 @@ def calculate_flops(model, output_unit (str, optional): The unit used to output the result value, such as T, G, M, and K. Default is None, that is the unit of the output decide on value. ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None. is_sparse (bool, optional): Whether to exclude sparse matrix flops. Defaults to False. + return_output (bool, optional): Whether to return the output of the model, mutually exclusive with output_as_string. Defaults to False. + assume_model_on_device (bool, optional): Whether to assume the model is on the device; if False, the model will be moved to the device. Defaults to False. Example: .. code-block:: python @@ -108,6 +113,7 @@ def calculate_flops(model, assert isinstance(model, nn.Module), "model must be a PyTorch module" # assert transformers_tokenizer and auto_generate_transformers_input and "transformers" in str(type(model)), "The model must be a transformers model if args of auto_generate_transformers_input is True and transformers_tokenizer is not None" + assert not (output_as_string and return_output), "output_as_string and return_output are mutually exclusive" model.eval() is_transformer = True if "transformers" in str(type(model)) else False @@ -119,7 +125,8 @@ def calculate_flops(model, calculate_flops_pipline.start_flops_calculate(ignore_list=ignore_modules) device = next(model.parameters()).device - model = model.to(device) + if not assume_model_on_device: + model = model.to(device) if input_shape is not None: assert len(args) == 0 and len( @@ -162,9 +169,9 @@ def calculate_flops(model, args[index] = args[index].to(device) if forward_mode == 'forward': - _ = model(*args, **kwargs) + model_output = model(*args, **kwargs) elif forward_mode == 'generate': - _ = model.generate(*args, **kwargs) + model_output = model.generate(*args, **kwargs) else: raise NotImplementedError("forward_mode should be either forward or generate") @@ -187,5 +194,8 @@ def calculate_flops(model, return flops_to_string(flops, units=output_unit, precision=output_precision), \ macs_to_string(macs, units=output_unit, precision=output_precision), \ params_to_string(params, units=output_unit, precision=output_precision) - - return flops, macs, params + + if return_output: + return flops, macs, params, model_output + else: + return flops, macs, params diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..bfe3da1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "calflops" +version = "0.3.3" +description = "A tool to compute FLOPs, MACs, and parameters in various neural networks." +readme = "README.md" +requires-python = ">=3.6" +license = {text = "MIT"} +authors = [ + {name = "MrYxJ"} +] +dependencies = [ + "torch>=1.0.0", +] + +[tool.setuptools] +packages = ["calflops"]