Skip to content

Commit 2441917

Browse files
Arm backend: Add U55 and U85 tests for deit_tiny (#16145)
Adds tests of deit_tiny for U55 and U85. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 670bc11 commit 2441917

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

backends/arm/test/models/test_deit_tiny_arm.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import logging
7-
86
from typing import Tuple
97

108
import timm # type: ignore[import-untyped]
@@ -14,6 +12,8 @@
1412
from executorch.backends.arm.test import common
1513

1614
from executorch.backends.arm.test.tester.test_pipeline import (
15+
EthosU55PipelineINT,
16+
EthosU85PipelineINT,
1717
TosaPipelineFP,
1818
TosaPipelineINT,
1919
VgfPipeline,
@@ -25,10 +25,8 @@
2525
)
2626
from torchvision import transforms # type: ignore[import-untyped]
2727

28-
logger = logging.getLogger(__name__)
29-
30-
3128
deit_tiny = timm.models.deit.deit_tiny_patch16_224(pretrained=True)
29+
3230
deit_tiny.eval()
3331

3432
normalize = transforms.Normalize(
@@ -63,6 +61,37 @@ def test_deit_tiny_tosa_INT():
6361
pipeline.run()
6462

6563

64+
def test_deit_tiny_u55_INT():
65+
pipeline = EthosU55PipelineINT[input_t](
66+
deit_tiny,
67+
model_inputs,
68+
aten_ops=[],
69+
exir_ops=[],
70+
use_to_edge_transform_and_lower=True,
71+
atol=1.5,
72+
qtol=1,
73+
)
74+
# Multiple partitions
75+
pipeline.pop_stage("check_count.exir")
76+
# Don't run inference as model is too large for Corstone-300
77+
pipeline.pop_stage("run_method_and_compare_outputs")
78+
pipeline.run()
79+
80+
81+
@common.XfailIfNoCorstone320
82+
def test_deit_tiny_u85_INT():
83+
pipeline = EthosU85PipelineINT[input_t](
84+
deit_tiny,
85+
model_inputs,
86+
aten_ops=[],
87+
exir_ops=[],
88+
use_to_edge_transform_and_lower=True,
89+
atol=1.5,
90+
qtol=1,
91+
)
92+
pipeline.run()
93+
94+
6695
@common.SkipIfNoModelConverter
6796
def test_deit_tiny_vgf_INT():
6897
pipeline = VgfPipeline[input_t](

0 commit comments

Comments
 (0)