Skip to content

Commit c398ff4

Browse files
Arm backend: Make INT+FP default for vgf-backend (#16176)
### Summary Make INT+FP default for vgf. With INT+FP supported, the following issues were solved: - Make sure that pow.Tensor_Scalar is not replaced by pow.Tensor_Tensor during transform_for_annotation. - Make sure that Scalar-ops that also maps to table ops are not replaced by Tensor ops when quantized. - Fix a bug in tosa_supported_operators where nodes with integer outputs were considered ok for partitioning when they shouldn't be. ### Test plan The plan is to move most of the vgf-tests to lower with both INT and FP support. We can then test both the quantized and non-quantized flows. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 0c54fd0 commit c398ff4

File tree

7 files changed

+69
-30
lines changed

7 files changed

+69
-30
lines changed

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Dict, Set, Type, Union
88

99
import torch
10+
from executorch.backends.arm._passes.insert_table_ops import TableOps
1011

1112
from executorch.backends.arm.tosa.specification import get_context_spec
1213
from executorch.backends.transforms.replace_scalar_with_tensor import (
@@ -64,7 +65,6 @@
6465
Union[EdgeOpOverload, torch._ops.OpOverload],
6566
] = _common_ops | {
6667
exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor,
67-
torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor,
6868
}
6969

7070
_int_profile_ops: Dict[
@@ -101,7 +101,15 @@ def call_operator(self, op, args, kwargs, meta):
101101
included_ops |= _fp_profile_ops
102102

103103
if included_ops == {}:
104-
raise ValueError("Profile must support either INT or FP")
104+
raise ValueError("Profile must support at least INT or FP")
105+
106+
if op in TableOps.included_ops():
107+
# Do not handle quantized table ops; forward unchanged.
108+
input_qparams = meta.data.get("input_qparams", {})
109+
output_qparams = meta.data.get("input_qparams", {})
110+
if len(input_qparams) > 0 and len(output_qparams) > 0:
111+
# Do not handle; forward unchanged.
112+
return ExportPass.call_operator(self, op, args, kwargs, meta)
105113

106114
if op in included_ops:
107115
# Include this op based on the current profile.

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
146146
return checker
147147

148148

149+
def _is_integer_dtype(dtype: torch.dtype) -> bool:
150+
return not dtype.is_floating_point and not dtype.is_complex
151+
152+
149153
def _is_quantized_constant(node: torch.fx.Node) -> bool:
150154
if node.target not in (
151155
exir_ops.edge.aten.full_like.default,
@@ -161,7 +165,7 @@ def _is_quantized_constant(node: torch.fx.Node) -> bool:
161165
for user in users:
162166
if user.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
163167
dim_order_dtype = get_first_fake_tensor(user).dtype
164-
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
168+
if not _is_integer_dtype(dim_order_dtype):
165169
return False
166170
else:
167171
return False
@@ -184,10 +188,24 @@ def is_quantized(node: torch.fx.Node) -> bool:
184188
bool: True if the node is quantized, False otherwise.
185189
"""
186190

187-
node_dtype = get_first_fake_tensor(node).dtype
188-
# Integer-like dtype implies the node is already quantized.
189-
if not node_dtype.is_complex and not node_dtype.is_floating_point:
190-
return True
191+
try:
192+
node_dtype = get_first_fake_tensor(node).dtype
193+
# Integer-like dtype implies the node is already quantized as long
194+
# as inputs are not floating-point.
195+
if _is_integer_dtype(node_dtype):
196+
input_nodes = node.all_input_nodes
197+
input_nodes_dtypes = [
198+
get_first_fake_tensor(input_node).dtype for input_node in input_nodes
199+
]
200+
if all(
201+
_is_integer_dtype(input_node_dtype)
202+
for input_node_dtype in input_nodes_dtypes
203+
):
204+
return True
205+
206+
except TypeError:
207+
# Could not determine dtype, fall back to other checks.
208+
pass
191209

192210
# Nodes introduced during lowering that exclusively feed quantized users.
193211
if _is_quantized_constant(node):
@@ -510,7 +528,7 @@ def is_node_supported(
510528

511529
input_quantized = input_quantized or all(
512530
(input_node.target in DQ_OPS)
513-
or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
531+
or _is_integer_dtype(get_first_fake_tensor(input_node).dtype)
514532
for input_node in node.all_input_nodes
515533
)
516534

@@ -519,8 +537,10 @@ def is_node_supported(
519537
return False
520538

521539
all_q_users = all((output_node.target in Q_OPS) for output_node in node.users)
522-
is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
523-
output_quantized = output_quantized or all_q_users or not is_floating_point
540+
output_dtype = get_first_fake_tensor(node).dtype
541+
output_quantized = (
542+
output_quantized or all_q_users or _is_integer_dtype(output_dtype)
543+
)
524544

525545
if not output_quantized:
526546
self.reporter.report_reject(node, "One or more outputs were not quantized.")

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def _match_pattern(
356356
torch.ops.aten.hardswish.default,
357357
torch.ops.aten.hardswish_.default,
358358
torch.ops.aten.full_like.default,
359+
torch.ops.aten.zeros_like.default,
359360
torch.ops.aten.pow.Tensor_Scalar,
360361
torch.ops.aten.gelu.default,
361362
torch.ops.aten.sinh.default,

backends/arm/test/common.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,22 @@ def get_vgf_compile_spec(
170170

171171
if not custom_path:
172172
custom_path = maybe_get_tosa_collate_path()
173+
profiles = []
173174
if "FP" in repr(tosa_spec):
174-
artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_vgf_fp_")
175-
elif "INT" in repr(tosa_spec):
176-
artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_vgf_int_")
177-
else:
175+
profiles.append("fp")
176+
if "INT" in repr(tosa_spec):
177+
profiles.append("int")
178+
if len(profiles) == 0:
178179
raise ValueError(f"Unsupported vgf compile_spec: {repr(tosa_spec)}")
179180

181+
if custom_path is None:
182+
artifact_path = "arm_vgf_"
183+
for profile in profiles:
184+
artifact_path = artifact_path + f"_{profile}"
185+
artifact_path = tempfile.mkdtemp(artifact_path)
186+
else:
187+
artifact_path = custom_path
188+
180189
if not os.path.exists(artifact_path):
181190
os.makedirs(artifact_path, exist_ok=True)
182191

backends/arm/test/ops/test_eq.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def test_eq_scalar_tosa_INT(test_module):
122122

123123

124124
@common.parametrize("test_module", test_data_tensor)
125-
def test_eq_tensor_tosa_INT_a16w8(test_module):
125+
def test_eq_tensor_tosa_INT_16a8w(test_module):
126126
pipeline = TosaPipelineINT[input_t](
127127
test_module(),
128128
test_module().get_inputs(),
@@ -134,7 +134,7 @@ def test_eq_tensor_tosa_INT_a16w8(test_module):
134134

135135

136136
@common.parametrize("test_module", test_data_scalar)
137-
def test_eq_scalar_tosa_INT_a16w8(test_module):
137+
def test_eq_scalar_tosa_INT_16a8w(test_module):
138138
pipeline = TosaPipelineINT[input_t](
139139
test_module(),
140140
test_module().get_inputs(),
@@ -238,7 +238,11 @@ def test_eq_scalar_16a8w_u85_INT16(test_module):
238238
@common.SkipIfNoModelConverter
239239
def test_eq_scalar_vgf_FP_tensor(test_module):
240240
pipeline = VgfPipeline[input_t](
241-
test_module(), test_module().get_inputs(), Equal.aten_op_Tensor, Equal.exir_op
241+
test_module(),
242+
test_module().get_inputs(),
243+
Equal.aten_op_Tensor,
244+
Equal.exir_op,
245+
tosa_version="TOSA-1.0+FP",
242246
)
243247
pipeline.run()
244248

@@ -247,7 +251,11 @@ def test_eq_scalar_vgf_FP_tensor(test_module):
247251
@common.SkipIfNoModelConverter
248252
def test_eq_scalar_vgf_FP(test_module):
249253
pipeline = VgfPipeline[input_t](
250-
test_module(), test_module().get_inputs(), Equal.aten_op_Scalar, Equal.exir_op
254+
test_module(),
255+
test_module().get_inputs(),
256+
Equal.aten_op_Scalar,
257+
Equal.exir_op,
258+
tosa_version="TOSA-1.0+FP",
251259
)
252260
pipeline.run()
253261

backends/arm/test/tester/test_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,7 @@ def __init__(
990990
exir_op: Optional[str | List[str]] = None,
991991
run_on_vulkan_runtime: bool = True,
992992
vgf_compiler_flags: Optional[str] = "",
993-
tosa_version: str = "TOSA-1.0+FP",
993+
tosa_version: str = "TOSA-1.0+INT+FP",
994994
symmetric_io_quantization: bool = False,
995995
per_channel_quantization: bool = True,
996996
use_to_edge_transform_and_lower: bool = True,

backends/arm/vgf/compile_spec.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@ def __init__(
2828
tosa_spec (TosaSpecification | str | None): TOSA specification to
2929
target. Strings are parsed via
3030
:meth:`TosaSpecification.create_from_string`. Defaults to
31-
``"TOSA-1.0+FP"``.
31+
``"TOSA-1.0+FP+INT"``.
3232
compiler_flags (list[str] | None): Optional converter-backend flags.
33-
3433
"""
3534
if tosa_spec is None:
36-
tosa_spec = "TOSA-1.0+FP"
37-
if isinstance(tosa_spec, str):
35+
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP+INT")
36+
elif isinstance(tosa_spec, str):
3837
tosa_spec = TosaSpecification.create_from_string(tosa_spec)
3938

4039
if compiler_flags is None:
@@ -55,13 +54,7 @@ def validate(self):
5554

5655
if "FP" not in tosa_profiles and "INT" not in tosa_profiles:
5756
raise ValueError(
58-
"Arm backend only supports converter-backend for FP or INT. "
59-
f"Invalid TOSA profile: {tosa_profiles}"
60-
)
61-
62-
if len(tosa_profiles) != 1:
63-
raise ValueError(
64-
"For now Arm backend only supports converter-backend for either FP or INT. "
57+
"Arm backend only supports converter-backend for FP and/or INT. "
6558
f"Invalid TOSA profile: {tosa_profiles}"
6659
)
6760

0 commit comments

Comments
 (0)