Skip to content

Commit 97251b5

Browse files
committed
Fix inception input in cls sample
1 parent 119851b commit 97251b5

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

samples/classification.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
'squeezenet': models.squeezenet1_0,
1616
'densenet': models.densenet161,
1717
'inception': models.inception_v3,
18-
'convnext_base': models.convnext_base}
18+
'convnext_base': models.convnext_base,
19+
'vit_b_16': models.vit_b_16}
1920

2021
if version.parse(torchvision.__version__) > version.parse('0.15'):
2122
pt_models['vit_b_16'] = models.vit_b_16
@@ -42,7 +43,12 @@
4243
if torch.cuda.is_available():
4344
net.cuda(device=args.device)
4445

45-
macs, params = get_model_complexity_info(net, (3, 224, 224),
46+
if args.model == 'inception':
47+
input_res = (3, 299, 299)
48+
else:
49+
input_res = (3, 224, 224)
50+
51+
macs, params = get_model_complexity_info(net, input_res,
4652
as_strings=True,
4753
backend=args.backend,
4854
print_per_layer_stat=True,

0 commit comments

Comments
 (0)