-
Notifications
You must be signed in to change notification settings - Fork 38
Open
Description
My script:
import torch
import logging
from thop import profile, clever_format
def torch_model_profile_via_thop(model, input_):
    """Profile ``model`` on a single input tensor with thop.

    Calls ``thop.profile`` with the positional-argument tuple ``(input_,)``
    under ``torch.no_grad()``. It then logs human-readable counts at INFO
    level.

    Note: ``thop.profile`` counts multiply-accumulate operations (MACs), not
    FLOPs; thop's own documentation unpacks its result as ``macs, params``.
    The first return value is therefore a MAC count. It stays in the first
    position so existing callers keep working.

    :param model: the ``torch.nn.Module`` to profile.
    :param input_: the tensor passed as the model's only positional argument.
    :return: ``(macs, params)`` as the raw numbers returned by thop.
    """
    with torch.no_grad():
        macs, params = profile(model, (input_,))
    s_macs, s_params = clever_format([macs, params], "%.1f")
    # Label the count as MACs: thop reports multiply-accumulates, so calling
    # it "FLOPs" would understate the FLOP count (roughly 2x by the usual
    # 1 MAC ~= 2 FLOPs convention).
    logging.info('Params: %s, MACs: %s.', s_params, s_macs)
    return macs, params
def torch_model_profile_via_calflops(model, input_):
    """Profile ``model`` with calflops and return formatted counts.

    ``calflops`` is imported inside this function rather than at module
    level. A missing or broken calflops installation therefore only affects
    callers of this function.

    NOTE(review): at least in the calflops version this was run against,
    ``import calflops`` also imports ``transformers``. Its ``__init__``
    imports ``flops_counter_hf``, which does ``from transformers import
    AutoTokenizer``. So ``transformers`` must be installed even for a plain
    PyTorch model.

    Only the shape of ``input_`` is forwarded to calflops; its values,
    dtype and device are not passed along.

    :param model: the ``torch.nn.Module`` to profile.
    :param input_: a tensor whose shape is used as calflops' ``input_shape``.
    :return: ``(flops, params, macs)`` as strings produced by thop's
        ``clever_format`` with one decimal place.
    """
    from calflops import calculate_flops

    shape = tuple(input_.shape)
    # calflops returns its counts in the order (flops, macs, params).
    raw_flops, raw_macs, raw_params = calculate_flops(
        model=model,
        input_shape=shape,
        output_as_string=False,
        print_detailed=False,
    )
    formatted = clever_format([raw_flops, raw_params, raw_macs], "%.1f")
    s_flops, s_params, s_macs = formatted
    logging.info(f'Params: {s_params}, FLOPs: {s_flops}, MACs: {s_macs}.')
    return s_flops, s_params, s_macs
if __name__ == '__main__':
    from timm import create_model

    # The profiling helpers report through logging.info(). Without this call
    # those messages are dropped: by default the root logger has no handler
    # and its level is WARNING.
    logging.basicConfig(level=logging.INFO)

    # Alternative model:
    # model = create_model('hf-hub:animetimm/swinv2_base_window8_256.dbv4-full', pretrained=False)
    model = create_model('caformer_b36.sail_in22k_ft_in1k_384', pretrained=False)
    dummy_input = torch.randn(1, 3, 448, 448)
    print(torch_model_profile_via_thop(model, dummy_input))
    # When this script was run, the call below failed with
    # ModuleNotFoundError: No module named 'transformers', raised while
    # executing `from calflops import calculate_flops` (see the traceback).
    print(torch_model_profile_via_calflops(model, dummy_input))
Traceback (most recent call last):
File "/home/narugo1992/.pyenv/versions/3.10.1/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/narugo1992/.pyenv/versions/3.10.1/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/narugo1992/wtf-projects/animetimm/animetimm/utils/profile.py", line 39, in <module>
print(torch_model_profile_via_calflops(model, dummy_input))
File "/home/narugo1992/wtf-projects/animetimm/animetimm/utils/profile.py", line 17, in torch_model_profile_via_calflops
from calflops import calculate_flops
File "/home/narugo1992/wtf-projects/animetimm/venv/lib/python3.10/site-packages/calflops/__init__.py", line 16, in <module>
from .flops_counter_hf import calculate_flops_hf
File "/home/narugo1992/wtf-projects/animetimm/venv/lib/python3.10/site-packages/calflops/flops_counter_hf.py", line 18, in <module>
from transformers import AutoTokenizer
ModuleNotFoundError: No module named 'transformers'
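I suggest moving these kinds of optional dependencies into local scopes, e.g. importing them inside the functions that need them. Currently `calflops/__init__.py` imports `flops_counter_hf`, which imports `transformers` at module level. As a result, anyone who only needs `calculate_flops` for a plain PyTorch model still has to install `transformers`.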
janblumenkamp, p16i, FrzMtrsprt and julian-carpenter
Metadata
Metadata
Assignees
Labels
No labels