First mistake
When we want to use kwargs like this, the code doesn't take them into account and just uses some default kwargs instead:
```python
def calculate_flops(model,
                    input_shape=None,
                    transformer_tokenizer=None,
                    args=[],
                    kwargs={
                        'input_ids': input_ids,
                    },
                    forward_mode="forward",
                    include_backPropagation=False,
                    compute_bp_factor=2.0,
                    print_results=True,
                    print_detailed=True,
                    output_as_string=True,
                    output_precision=2,
                    output_unit=None,
                    ignore_modules=None):
```
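(The `input_ids` referenced in the kwargs above is not defined in this issue; it would normally come from the tokenizer. A minimal sketch, assuming a Hugging Face GPT-2 checkpoint:)

```python
# Hypothetical setup, for illustration only: build the `input_ids` tensor that
# is passed through `kwargs`.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer("Hello, world!", return_tensors="pt")["input_ids"]
```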
To fix this, I needed to change the code in the file calflops\flops_counter.py, starting at line 147:

```python
if transformer_tokenizer:
    kwargs = generate_transformer_input(input_shape=None,
                                        model_tokenizer=transformer_tokenizer,
                                        device=device)
```

Here we create new kwargs and forget about the kwargs that were passed as a parameter of calculate_flops. I did the following:
```python
if transformer_tokenizer:
    pass
```
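A less destructive variant of the same idea would be to keep the tokenizer fallback but only use it when no kwargs were passed in (a sketch, assuming generate_transformer_input keeps its current signature; this is not the patch described above):

```python
# Sketch only: generate kwargs from the tokenizer only when none were supplied.
if transformer_tokenizer and not kwargs:
    kwargs = generate_transformer_input(input_shape=None,
                                        model_tokenizer=transformer_tokenizer,
                                        device=device)
```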
And then, instead of:

```python
if kwargs:
    for key, value in kwargs.items():
        kwargs[key] = value.to(device)
```
I wrote:

```python
if kwargs:
    for key, value in kwargs.items():
        print(value)  # debug output
        try:
            kwargs[key] = value.to(device)
        except AttributeError:
            pass
```

The AttributeError guard is needed because not every value in kwargs is a tensor: entries such as 'max_new_tokens' are plain Python ints and have no .to() method. Now it seems to be OK.
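An equivalent, slightly cleaner sketch of the same workaround (assuming the values that need moving are torch tensors) checks the type instead of catching the exception:

```python
import torch

def move_tensor_kwargs_to_device(kwargs, device):
    # Hypothetical helper, not part of calflops: move only tensor values to
    # the target device and leave plain Python values (e.g. max_new_tokens)
    # untouched.
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in kwargs.items()
    }

# Usage inside flops_counter.py would then be:
# kwargs = move_tensor_kwargs_to_device(kwargs, device)
```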
Second mistake
With gpt2-small and gpt2-xl, this code worked fine, but with any LLaMA model it gives errors:
```python
flops_single, macs_single, params_single = calculate_flops(
    model=model,
    kwargs={
        'input_ids': input_ids_single,
        'attention_mask': attention_mask_single,
        'position_ids': position_ids_single,
        'max_new_tokens': 10,
    },
    forward_mode='generate',
    transformer_tokenizer=tokenizer,
    include_backPropagation=False,
    compute_bp_factor=2.0,
    print_results=True,
    print_detailed=True,
    output_as_string=True,
    output_precision=2,
    output_unit=None,
    ignore_modules=None)
```

So, we should use it like this, namely without 'attention_mask' and 'position_ids':
```python
flops_single, macs_single, params_single = calculate_flops(
    model=model,
    kwargs={
        'input_ids': input_ids_single,
        # 'attention_mask': attention_mask_single,
        # 'position_ids': position_ids_single,
        'max_new_tokens': 10,
    },
    forward_mode='generate',
    transformer_tokenizer=tokenizer,
    include_backPropagation=False,
    compute_bp_factor=2.0,
    print_results=True,
    print_detailed=True,
    output_as_string=True,
    output_precision=2,
    output_unit=None,
    ignore_modules=None)
```
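For completeness, the single-example tensors used in the calls above (input_ids_single, attention_mask_single, position_ids_single) are not defined in this issue either; they can be prepared along these lines, assuming a Hugging Face LLaMA checkpoint:

```python
# Hypothetical input preparation; the variable names match the calls above,
# but the checkpoint and prompt are assumptions, not taken from the issue.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

encoded = tokenizer("Hello, world!", return_tensors="pt")
input_ids_single = encoded["input_ids"]
attention_mask_single = encoded["attention_mask"]
position_ids_single = torch.arange(input_ids_single.shape[1]).unsqueeze(0)
```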