From 1b4707cfd1d6b5eb9623f3e51aa31c8d4696655c Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 20 Feb 2026 11:41:56 +0100 Subject: [PATCH 01/11] Tracer prototype --- src/gt4py/next/ffront/field_operator_ast.py | 12 +++++ .../ffront/foast_passes/type_deduction.py | 26 +++++++++- src/gt4py/next/ffront/foast_pretty_printer.py | 2 + src/gt4py/next/ffront/foast_to_gtir.py | 9 ++++ src/gt4py/next/ffront/foast_to_past.py | 8 ++-- src/gt4py/next/ffront/func_to_foast.py | 31 ++++++++++-- .../next/ffront/past_passes/type_deduction.py | 2 +- src/gt4py/next/iterator/builtins.py | 3 +- .../next/iterator/transforms/pass_manager.py | 3 ++ .../iterator/transforms/unroll_map_tuple.py | 47 +++++++++++++++++++ .../iterator/type_system/type_synthesizer.py | 13 +++++ src/gt4py/next/type_system/type_info.py | 8 ++++ .../next/type_system/type_specifications.py | 9 ++++ .../next/type_system/type_translation.py | 22 +++++++-- tests/next_tests/integration_tests/cases.py | 11 +++++ .../ffront_tests/test_execution.py | 30 ++++++++++++ 16 files changed, 222 insertions(+), 14 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/unroll_map_tuple.py diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index fa5bc4889f..f4950ea402 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -113,6 +113,18 @@ class TupleExpr(Expr): elts: list[Expr] +# TODO: give a good error for tuple(... for el in iter if ...) so that users understand that and why we don't support conditionals +# TODO: should this have SymbolTableTrait since target declares a new symbol. Write test that has two comprehensions using the same target name. +class TupleComprehension(Expr): + """ + tuple(element_expr for target in iterable) + """ + + element_expr: Expr + target: DataSymbol # TODO: how about `tuple(el1+el2 for el1, el2 in var_arg)`? + iterable: Expr + + class UnaryOp(Expr): op: dialect_ast_enums.UnaryOperator operand: Expr diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 68bf108a0a..e545f9e002 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +import collections from typing import Any, Optional, TypeAlias, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast @@ -501,6 +501,10 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri f"Tuples need to be indexed with literal integers, got '{node.index}'.", ) from ex new_type = types[index] + case ts.VarArgType(element_type=element_type): + new_type = ( + element_type # TODO: we only temporarily allow any index for vararg types + ) case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: raise errors.DSLError( @@ -747,6 +751,26 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx new_type = ts.TupleType(types=[element.type for element in new_elts]) return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location) + def visit_TupleComprehension( + self, node: foast.TupleComprehension, **kwargs: Any + ) -> foast.TupleComprehension: + symtable: collections.ChainMap = kwargs["symtable"] # todo annotation + iterable = self.visit(node.iterable, **kwargs) + target = self.visit(node.target, **kwargs) + assert isinstance(iterable.type, ts.VarArgType) + target.type = iterable.type.element_type + element_expr = self.visit( + node.element_expr, + **{**kwargs, "symtable": symtable.new_child({node.target.id: target})}, + ) + return foast.TupleComprehension( + element_expr=element_expr, + target=target, + iterable=iterable, + location=node.location, + type=ts.VarArgType(element_type=element_expr.type), + ) + def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: new_func = self.visit(node.func, **kwargs) new_args = self.visit(node.args, **kwargs) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 8b2e369501..77495d78f7 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -118,6 +118,8 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o TupleExpr = as_fmt("({', '.join(elts)}{',' if len(elts)==1 else ''})") + TupleComprehension = as_fmt("tuple(({element_expr} for {target} in {iterable}))") + UnaryOp = as_fmt("{op}{operand}") def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3825072cb7..2e587c346e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -257,6 +257,15 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) + def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr: + return im.call( + im.call("map_tuple")( + im.lambda_(self.visit(node.target, **kwargs))( + self.visit(node.element_expr, **kwargs) + ) + ) + )(self.visit(node.iterable, **kwargs)) + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 05b080b70b..c37cba5a78 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -21,7 +21,7 @@ from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.type_system import type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -113,9 +113,9 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef: *partial_program_type.definition.kw_only_args.keys(), ] assert isinstance(type_, ts.CallableType) - assert arg_types[-1] == type_info.return_type( - type_, with_args=list(arg_types), with_kwargs=kwarg_types - ) + # assert arg_types[-1] == type_info.return_type( + # type_, with_args=list(arg_types), with_kwargs=kwarg_types + # ) assert args_names[-1] == "out" params_decl: list[past.Symbol] = [ diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ced0ff3905..adefa7ba9e 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -337,7 +337,12 @@ def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name: - return foast.Name(id=node.id, location=self.get_location(node)) + loc = self.get_location(node) + if isinstance(node.ctx, ast.Store): + return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None)) + else: + assert isinstance(node.ctx, ast.Load) + return foast.Name(id=node.id, location=loc) def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp: return foast.UnaryOp( @@ -469,8 +474,10 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: - if len(node.args) > 0: - arg = node.args[0] + (arg,) = ( + node.args + ) # note for review: the change here is unrelated to the actual pr and just a small cleanup + if node.func.id == "tuple": if not ( isinstance(arg, ast.Constant) or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) @@ -484,9 +491,25 @@ def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call: - # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if isinstance(node.func, ast.Name): func_name = self._func_name(node) + + if func_name == "tuple": + (gen_expr,) = node.args + assert ( + len(gen_expr.generators) == 1 + ) # we don't support (... for ... in ... for ... in ...) + assert ( + gen_expr.generators[0].ifs == [] + ) # we don't support if conditions in comprehensions + return foast.TupleComprehension( + element_expr=self.visit(gen_expr.elt, **kwargs), + target=self.visit(gen_expr.generators[0].target, **kwargs), + iterable=self.visit(gen_expr.generators[0].iter, **kwargs), + location=self.get_location(node), + ) + + # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if func_name in fbuiltins.TYPE_BUILTIN_NAMES: self._verify_builtin_type_constructor(node) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 9d021ceb51..530d407459 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: operator_return_type = type_info.return_type( new_func.type, with_args=arg_types, with_kwargs=kwarg_types ) - if operator_return_type != new_kwargs["out"].type: + if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type): raise ValueError( "Expected keyword argument 'out' to be of " f"type '{operator_return_type}', got " diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index e54c6ea3d7..7b24c91884 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -498,7 +498,8 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "map_", + "map_tuple", + "map_", # TODO: rename to map_list "named_range", "neighbors", "reduce", diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 08ca9d94e0..4102790129 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -24,6 +24,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, + unroll_map_tuple, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -179,6 +180,7 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = infer_domain.infer_program( ir, @@ -293,6 +295,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py new file mode 100644 index 0000000000..66f96d66fa --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses + +from gt4py import eve +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass +class UnrollMapTuple(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): + return cls(uids=uids).visit(program) + + def visit_FunCall(self, node: itir.Expr): + node = self.generic_visit(node) + + if cpm.is_call_to(node.fun, "map_tuple"): + # TODO: we have to duplicate the function here since the domain inference can not handle them yet + f = node.fun.args[0] + tup = node.args[0] + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + tup_ref = next(self.uids["_ump"]) + + result = im.let(tup_ref, tup)( + im.make_tuple( + *(im.call(f)(im.tuple_get(i, tup_ref)) for i in range(len(tup.type.types))) + ) + ) + itir_inference.reinfer(result) + + return result + return node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6d77c70375..4406dd9aa8 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -633,6 +633,19 @@ def applied_map( return applied_map +@_register_builtin_type_synthesizer +def map_tuple(op: TypeSynthesizer) -> TypeSynthesizer: + @type_synthesizer + def applied_map( + arg: ts.TupleType, offset_provider_type: common.OffsetProviderType + ) -> ts.TupleType: + return ts.TupleType( + types=[op(arg_, offset_provider_type=offset_provider_type) for arg_ in arg.types] + ) + + return applied_map + + @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @type_synthesizer diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index eb70d15947..69fccd33da 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -566,6 +566,14 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: or issubclass(type_class(to_type), symbol_type.constraint) ): return True + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.VarArgType): + return is_concretizable(symbol_type.element_type, to_type.element_type) + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.TupleType): + if len(to_type.types) == 0 or ( + all(type_ == to_type.types[0] for type_ in to_type.types) + and is_concretizable(symbol_type.element_type, to_type.types[0]) + ): + return True elif is_concrete(symbol_type): return symbol_type == to_type return False diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 59ac40f0f3..409138d593 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -148,6 +148,15 @@ def __len__(self) -> int: return len(self.types) +class VarArgType(DataType): + """Represents a variable number of arguments of the same type.""" + + element_type: DataType # TODO: maybe also support different DataTypes + + def __str__(self) -> str: + return f"VarArg[{self.element_type}]" + + class AnyPythonType: """Marker type representing any Python type which cannot be used for instantiation. diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 0f145e04aa..0ca020625a 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -180,8 +180,12 @@ def from_type_hint( case builtins.tuple: if not args: raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") - if Ellipsis in args: - raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") + if len(args) == 2 and args[1] is Ellipsis: + return ts.VarArgType(element_type=from_type_hint_same_ns(args[0])) + elif Ellipsis in args: + raise ValueError( + f"Vararg tuple annotation '{type_hint}' cannot have more than one argument." + ) tuple_types = [from_type_hint_same_ns(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) return ts.TupleType(types=tuple_types) @@ -321,7 +325,19 @@ def from_value(value: Any) -> ts.TypeSpec: return UnknownPythonObject(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) - symbol_type = from_type_hint(type_) + if type_ == type[tuple]: + # TODO: this special casing here is not nice, but infer_type is also called on the annotations where + # we don't want to allow unparameterized tuples (or do we?). + symbol_type = ts.ConstructorType( + definition=ts.FunctionType( + pos_only_args=[ts.DeferredType(constraint=None)], + pos_or_kw_args={}, + kw_only_args={}, + returns=ts.DeferredType(constraint=ts.VarArgType), + ) + ) + else: + symbol_type = from_type_hint(type_) if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 78e6c62781..e723c963de 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -603,6 +603,15 @@ def _allocate_from_type( for t in types ) ) + case ts.VarArgType(element_type=element_type): + return tuple( + ( + _allocate_from_type( + case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy + ) + for t in [element_type] * 3 # TODO: revisit + ) + ) case ts.NamedCollectionType(types=types) as named_collection_type_spec: container_constructor = ( named_collections.make_named_collection_constructor_from_type_spec( @@ -648,6 +657,8 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> return sum([get_param_size(t, sizes=sizes) for t in types]) case ts.NamedCollectionType(types=types): return sum([get_param_size(t, sizes=sizes) for t in types]) + case ts.VarArgType(element_type=element_type): + return get_param_size(ts.TupleType(types=[element_type] * 3), sizes) # TODO: revisit case _: raise TypeError(f"Can not get size for parameter of type '{param_type}'.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8060d5bb36..14f14b3ffb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -336,6 +336,36 @@ def testee(a: tuple[cases.IField, cases.IJField]) -> cases.IJField: ) +@pytest.mark.uses_tuple_args +def test_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IFloatField, ...], factor: float + ) -> tuple[cases.IFloatField, ...]: + return tuple(tracer * factor for tracer in tracers) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t), + ) + + +@pytest.mark.uses_tuple_args +def test_tuple_vararg(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IFloatField, ...], factor: float + ) -> tuple[cases.IFloatField, cases.IFloatField]: + return tracers[0] * factor, tracers[1] * factor + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t[:2]), + ) + + @pytest.mark.uses_tuple_args @pytest.mark.xfail(reason="Iterator of tuple approach in lowering does not allow this.") def test_tuple_arg_with_unpromotable_dims(unstructured_case): From 0dfc80cccb4503214a26a6b0cd44c3444de98649 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 30 Apr 2026 15:26:35 +0200 Subject: [PATCH 02/11] Add support for nested tuples --- src/gt4py/next/ffront/field_operator_ast.py | 4 +- .../ffront/foast_passes/type_deduction.py | 51 +++++++++++++++++-- src/gt4py/next/ffront/foast_to_gtir.py | 14 ++++- src/gt4py/next/ffront/func_to_foast.py | 9 +++- .../ffront_tests/test_execution.py | 32 ++++++++++-- .../ffront_tests/test_func_to_foast.py | 22 ++++++++ 6 files changed, 119 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 8ee216b96b..aa32a3d939 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -9,7 +9,6 @@ from __future__ import annotations import functools -from typing import Any, Generic, TypeAlias, TypeVar, Union from gt4py import eve from gt4py.eve import ( @@ -22,6 +21,7 @@ datamodels, utils as eve_utils, ) +from gt4py.eve.extended_typing import Any, Generic, TypeAlias, TypeVar, Union from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.type_definitions import StrEnum from gt4py.next.ffront import dialect_ast_enums, type_specifications as ts_ffront @@ -131,7 +131,7 @@ class TupleComprehension(Expr): """ element_expr: Expr - target: DataSymbol # TODO: how about `tuple(el1+el2 for el1, el2 in var_arg)`? + target: Any # should be: MaybeNestedInTuple[DataSymbol] but this has a problem in eve iterable: Expr diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 2b33d54cca..f7b7d148f0 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -9,7 +9,6 @@ import textwrap from typing import Any, Optional, Sequence, TypeAlias, TypeVar, cast - import gt4py.next.ffront.field_operator_ast as foast from gt4py import eve from gt4py.eve import NodeTranslator, NodeVisitor, traits @@ -25,6 +24,7 @@ from gt4py.next.ffront.foast_passes import utils as foast_utils from gt4py.next.iterator import builtins from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +from gt4py.next.utils import tree_map OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) @@ -685,18 +685,59 @@ def visit_TupleComprehension( symtable: collections.ChainMap = kwargs["symtable"] # todo annotation iterable = self.visit(node.iterable, **kwargs) target = self.visit(node.target, **kwargs) - assert isinstance(iterable.type, ts.VarArgType) - target.type = iterable.type.element_type + if isinstance(iterable.type, ts.TupleType): + if len(iterable.type.types) > 0 and not all( + t == iterable.type.types[0] for t in iterable.type.types + ): + raise errors.DSLError( + iterable.location, + "Not implemented. All elements of the iterable in a tuple comprehensions must have the same type.", + ) + element_type = iterable.type.types[0] + elif isinstance(iterable.type, ts.VarArgType): + element_type = iterable.type.element_type + else: + raise errors.DSLError( + iterable.location, + f"Iterable in generator expression must be a tuple, got '{iterable.type}'.", + ) + + new_syms = {} + + @tree_map(with_path_arg=True) + def process_target(target_el: foast.Symbol, path: tuple[int, ...]): + try: + new_syms[target_el.id] = target_el + type_ = element_type + for i in path: + if not isinstance(type_, ts.TupleType) or len(type_.types) <= i: + raise IndexError() + type_ = type_.types[i] + target_el.type = type_ + except IndexError: + raise errors.DSLError( + target_el.location, f"Cannot unpack non-iterable '{type_}' object." + ) + + process_target(target) + element_expr = self.visit( node.element_expr, - **{**kwargs, "symtable": symtable.new_child({node.target.id: target})}, + **{**kwargs, "symtable": symtable.new_child(new_syms)}, ) + + if isinstance(iterable.type, ts.TupleType): + return_type = ts.TupleType(types=[element_expr.type] * len(iterable.type.types)) + else: + assert isinstance(iterable.type, ts.VarArgType) + return_type = ts.VarArgType(element_type=element_expr.type) + return foast.TupleComprehension( element_expr=element_expr, target=target, iterable=iterable, location=node.location, - type=ts.VarArgType(element_type=element_expr.type), + type=return_type, ) def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2e587c346e..349ff62804 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -8,6 +8,7 @@ import dataclasses +import functools from typing import Any, Callable, Optional from gt4py import eve @@ -258,10 +259,19 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr: + sym = next(self.uid_generator["__tuple_compr"]) + targets = [ + self.visit(target, **kwargs) for target in utils.flatten_nested_tuple(node.target) + ] + target_vals = utils.tree_map( + lambda _, path: functools.reduce(lambda el, i: im.tuple_get(i, el), path, sym), + with_path_arg=True, + )(node.target) + return im.call( im.call("map_tuple")( - im.lambda_(self.visit(node.target, **kwargs))( - self.visit(node.element_expr, **kwargs) + im.lambda_(sym)( + im.let(*zip(targets, target_vals))(self.visit(node.element_expr, **kwargs)) ) ) )(self.visit(node.iterable, **kwargs)) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index adefa7ba9e..ae8b41145b 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -502,9 +502,16 @@ def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call: assert ( gen_expr.generators[0].ifs == [] ) # we don't support if conditions in comprehensions + + def parse_target(target: ast.Tuple | ast.Name) -> tuple[foast.Name]: + if isinstance(target, ast.Tuple): + return tuple(parse_target(el) for el in target.elts) + assert isinstance(target, ast.Name) + return self.visit(target, **kwargs) + return foast.TupleComprehension( element_expr=self.visit(gen_expr.elt, **kwargs), - target=self.visit(gen_expr.generators[0].target, **kwargs), + target=parse_target(gen_expr.generators[0].target), iterable=self.visit(gen_expr.generators[0].iter, **kwargs), location=self.get_location(node), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 986fa5f5cb..2861d1aed4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -339,11 +339,24 @@ def testee(a: tuple[cases.IField, cases.IJField]) -> cases.IJField: @pytest.mark.uses_tuple_args -def test_tuple_comprehension(cartesian_case): +def test_fixed_len_tuple_comprehension(cartesian_case): @gtx.field_operator def testee( - tracers: tuple[cases.IFloatField, ...], factor: float - ) -> tuple[cases.IFloatField, ...]: + tracers: tuple[cases.IField, cases.IField], factor: int32 + ) -> tuple[cases.IField, cases.IField]: + return tuple(tracer * factor for tracer in tracers) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t), + ) + + +@pytest.mark.uses_tuple_args +def test_var_len_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee(tracers: tuple[cases.IField, ...], factor: int32) -> tuple[cases.IField, ...]: return tuple(tracer * factor for tracer in tracers) cases.verify_with_default_data( @@ -353,6 +366,19 @@ def testee( ) +@pytest.mark.uses_tuple_args +def test_nested_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee(nested_tuple: tuple[tuple[int32, cases.IField], ...]) -> tuple[cases.IField, ...]: + return tuple(factor * tracer for factor, tracer in nested_tuple) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t: tuple(f * el for f, el in t), + ) + + @pytest.mark.uses_tuple_args def test_tuple_vararg(cartesian_case): @gtx.field_operator diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 57c2a8be3a..106b627bb2 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -447,3 +447,25 @@ def tuple_index_failure( with pytest.raises(errors.DSLError, match=r"need .* literal"): _ = FieldOperatorParser.apply_to_function(tuple_index_failure) + + +def test_tuple_compr_non_tuple_iterable_failure(): + def testee(arg: float): + return tuple(_ for _ in arg) + + with pytest.raises( + errors.DSLError, + match=re.escape("Iterable in generator expression must be a tuple, got 'float64'."), + ): + _ = FieldOperatorParser.apply_to_function(testee) + + +def test_tuple_compr_unpacking_failure(): + def testee(arg: tuple[int32, ...]): + return tuple(a * b for a, b in arg) + + with pytest.raises( + errors.DSLError, + match=re.escape("Cannot unpack non-iterable 'int32' object."), + ): + _ = FieldOperatorParser.apply_to_function(testee) From b771d6690dff7651ec861ad1dd43afd1c5482286 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 May 2026 12:46:42 +0200 Subject: [PATCH 03/11] Improve testing, typing, formatting --- .../ffront/foast_passes/type_deduction.py | 5 +-- src/gt4py/next/ffront/func_to_foast.py | 36 ++++++++++++------- .../iterator/transforms/unroll_map_tuple.py | 2 +- .../ffront_tests/test_execution.py | 2 +- .../ffront_tests/test_func_to_foast.py | 11 ++++++ 5 files changed, 39 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index f7b7d148f0..a2dca69f80 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -705,7 +705,7 @@ def visit_TupleComprehension( new_syms = {} @tree_map(with_path_arg=True) - def process_target(target_el: foast.Symbol, path: tuple[int, ...]): + def process_target(target_el: foast.Symbol, path: tuple[int, ...]) -> None: try: new_syms[target_el.id] = target_el type_ = element_type @@ -717,7 +717,7 @@ def process_target(target_el: foast.Symbol, path: tuple[int, ...]): except IndexError: raise errors.DSLError( target_el.location, f"Cannot unpack non-iterable '{type_}' object." - ) + ) from None process_target(target) @@ -726,6 +726,7 @@ def process_target(target_el: foast.Symbol, path: tuple[int, ...]): **{**kwargs, "symtable": symtable.new_child(new_syms)}, ) + return_type: ts.TupleType | ts.VarArgType if isinstance(iterable.type, ts.TupleType): return_type = ts.TupleType(types=[element_expr.type] * len(iterable.type.types)) else: diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ae8b41145b..cb5246a7a4 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -11,9 +11,9 @@ import ast import textwrap import typing -from typing import Any, Type import gt4py.eve as eve +from gt4py.eve.extended_typing import Any, NestedTuple, Type from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, @@ -336,7 +336,7 @@ def visit_Return(self, node: ast.Return, **kwargs: Any) -> foast.Return: def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) - def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name: + def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.DataSymbol | foast.Name: loc = self.get_location(node) if isinstance(node.ctx, ast.Store): return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None)) @@ -474,6 +474,7 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: + assert isinstance(node.func, ast.Name) (arg,) = ( node.args ) # note for review: the change here is unrelated to the actual pr and just a small cleanup @@ -481,29 +482,38 @@ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: if not ( isinstance(arg, ast.Constant) or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) + or isinstance(arg, ast.GeneratorExp) ): raise errors.DSLError( self.get_location(node), - f"'{self._func_name(node)}()' only takes literal arguments.", + f"'{self._func_name(node)}()' only takes literal arguments or a generator expression.", ) def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. - def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call: + def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call | foast.TupleComprehension: if isinstance(node.func, ast.Name): func_name = self._func_name(node) - if func_name == "tuple": - (gen_expr,) = node.args - assert ( - len(gen_expr.generators) == 1 - ) # we don't support (... for ... in ... for ... in ...) - assert ( - gen_expr.generators[0].ifs == [] - ) # we don't support if conditions in comprehensions + if ( + func_name == "tuple" + and len(node.args) == 1 + and isinstance(gen_expr := node.args[0], ast.GeneratorExp) + ): + if len(gen_expr.generators) != 1: + raise errors.DSLError( + self.get_location(node), + "Nested generator expressions are not supported.", + ) + if gen_expr.generators[0].ifs != []: + raise errors.DSLError( + self.get_location(node), + "Conditionals are not supported in generator expressions as they size of " + "the result can only be deduced at runtime.", + ) - def parse_target(target: ast.Tuple | ast.Name) -> tuple[foast.Name]: + def parse_target(target: ast.expr) -> NestedTuple[foast.Name]: if isinstance(target, ast.Tuple): return tuple(parse_target(el) for el in target.elts) assert isinstance(target, ast.Name) diff --git a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py index 66f96d66fa..b47f5ba7d7 100644 --- a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py +++ b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py @@ -25,7 +25,7 @@ class UnrollMapTuple(eve.NodeTranslator): def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): return cls(uids=uids).visit(program) - def visit_FunCall(self, node: itir.Expr): + def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) if cpm.is_call_to(node.fun, "map_tuple"): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 2861d1aed4..917eef27c9 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -367,7 +367,7 @@ def testee(tracers: tuple[cases.IField, ...], factor: int32) -> tuple[cases.IFie @pytest.mark.uses_tuple_args -def test_nested_tuple_comprehension(cartesian_case): +def test_multi_target_tuple_comprehension(cartesian_case): @gtx.field_operator def testee(nested_tuple: tuple[tuple[int32, cases.IField], ...]) -> tuple[cases.IField, ...]: return tuple(factor * tracer for factor, tracer in nested_tuple) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 106b627bb2..d231fc58ec 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -460,6 +460,17 @@ def testee(arg: float): _ = FieldOperatorParser.apply_to_function(testee) +def test_nested_tuple_compr_failure(): + def testee(nested_tuple: tuple[tuple[gtx.Field[[TDim], float64], ...], ...], factor: int32): + return tuple(grandchild * factor for child in nested_tuple for grandchild in child) + + with pytest.raises( + errors.DSLError, + match=re.escape("Nested generator expressions are not supported."), + ): + _ = FieldOperatorParser.apply_to_function(testee) + + def test_tuple_compr_unpacking_failure(): def testee(arg: tuple[int32, ...]): return tuple(a * b for a, b in arg) From 08ed490b1fc1ca1d9b30cde3472c710eacf723cb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 6 May 2026 15:12:02 +0200 Subject: [PATCH 04/11] Fix name shadowing --- src/gt4py/next/ffront/field_operator_ast.py | 22 +++++++++--- .../ffront/foast_passes/type_deduction.py | 21 +++++------ src/gt4py/next/ffront/foast_pretty_printer.py | 8 +++-- src/gt4py/next/ffront/foast_to_gtir.py | 35 +++++++++++-------- src/gt4py/next/ffront/func_to_foast.py | 9 +++-- .../ffront_tests/test_execution.py | 30 ++++++++++++++++ 6 files changed, 89 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index aa32a3d939..8043964cf2 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -123,18 +123,32 @@ class TupleExpr(Expr): elts: list[Expr] -# TODO: give a good error for tuple(... for el in iter if ...) so that users understand that and why we don't support conditionals -# TODO: should this have SymbolTableTrait since target declares a new symbol. Write test that has two comprehensions using the same target name. +# TODO(tehrengruber): extend this to supported nested tuple comprehension. +# e.g. `tuple(element_expr for child in nested_tuple for grand_child in child)` +# would be represented by: +# ``` +# class TupleComprehension(Expr): # ruff: noqa: ERA001 +# inner: TupleComprehensionMapper | NestedTupleCompr # ruff: noqa: ERA001 +# class NestedTupleCompr(Expr, SymbolTableTrait): # ruff: noqa: ERA001 +# params: tuple[DataSymbol] # ruff: noqa: ERA001 +# body: TupleComprehension # ruff: noqa: ERA001 +# ``` class TupleComprehension(Expr): """ tuple(element_expr for target in iterable) """ - element_expr: Expr - target: Any # should be: MaybeNestedInTuple[DataSymbol] but this has a problem in eve + inner: TupleComprehensionMapper iterable: Expr +# this is essentially a lambda, the difference is for a lambda we might not know the type of the +# args, therefor this is named differently at the moment. +class TupleComprehensionMapper(LocatedNode, SymbolTableTrait): + target: Any # should be: NestedInTuple[DataSymbol] but this has a problem in eve + element_expr: Expr + + class UnaryOp(Expr): op: dialect_ast_enums.UnaryOperator operand: Expr diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index a2dca69f80..160dc05e4a 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -5,7 +5,6 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import collections import textwrap from typing import Any, Optional, Sequence, TypeAlias, TypeVar, cast @@ -682,9 +681,8 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx def visit_TupleComprehension( self, node: foast.TupleComprehension, **kwargs: Any ) -> foast.TupleComprehension: - symtable: collections.ChainMap = kwargs["symtable"] # todo annotation + target = self.visit(node.inner.target, **kwargs) iterable = self.visit(node.iterable, **kwargs) - target = self.visit(node.target, **kwargs) if isinstance(iterable.type, ts.TupleType): if len(iterable.type.types) > 0 and not all( t == iterable.type.types[0] for t in iterable.type.types @@ -702,29 +700,25 @@ def visit_TupleComprehension( f"Iterable in generator expression must be a tuple, got '{iterable.type}'.", ) - new_syms = {} + inner_kwargs = {"symtable": node.inner.annex.symtable, **kwargs} @tree_map(with_path_arg=True) def process_target(target_el: foast.Symbol, path: tuple[int, ...]) -> None: try: - new_syms[target_el.id] = target_el type_ = element_type for i in path: if not isinstance(type_, ts.TupleType) or len(type_.types) <= i: raise IndexError() type_ = type_.types[i] - target_el.type = type_ + return self.visit(target_el, refine_type=type_, **inner_kwargs) except IndexError: raise errors.DSLError( target_el.location, f"Cannot unpack non-iterable '{type_}' object." ) from None - process_target(target) + new_target = process_target(target) - element_expr = self.visit( - node.element_expr, - **{**kwargs, "symtable": symtable.new_child(new_syms)}, - ) + element_expr = self.visit(node.inner.element_expr, **inner_kwargs) return_type: ts.TupleType | ts.VarArgType if isinstance(iterable.type, ts.TupleType): @@ -734,8 +728,9 @@ def process_target(target_el: foast.Symbol, path: tuple[int, ...]) -> None: return_type = ts.VarArgType(element_type=element_expr.type) return foast.TupleComprehension( - element_expr=element_expr, - target=target, + inner=foast.TupleComprehensionMapper( + target=new_target, element_expr=element_expr, location=node.location + ), iterable=iterable, location=node.location, type=return_type, diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 77495d78f7..273ecaacaf 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -118,10 +118,14 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o TupleExpr = as_fmt("({', '.join(elts)}{',' if len(elts)==1 else ''})") - TupleComprehension = as_fmt("tuple(({element_expr} for {target} in {iterable}))") - UnaryOp = as_fmt("{op}{operand}") + def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> str: + element_expr = self.visit(node.inner.element_expr, **kwargs) + target = self.visit(node.inner.target, **kwargs) + iterable = self.visit(node.iterable, **kwargs) + return f"tuple(({element_expr} for {target} in {iterable}))" + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: if node.op is dialect_ast_enums.UnaryOperator.NOT: op = "not " diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 349ff62804..7383ff9d9e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -259,22 +259,27 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr: - sym = next(self.uid_generator["__tuple_compr"]) - targets = [ - self.visit(target, **kwargs) for target in utils.flatten_nested_tuple(node.target) - ] - target_vals = utils.tree_map( - lambda _, path: functools.reduce(lambda el, i: im.tuple_get(i, el), path, sym), - with_path_arg=True, - )(node.target) - - return im.call( - im.call("map_tuple")( - im.lambda_(sym)( - im.let(*zip(targets, target_vals))(self.visit(node.element_expr, **kwargs)) - ) + target = self.visit(node.inner.target, **kwargs) + element_expr = self.visit(node.inner.element_expr, **kwargs) + + # e.g. `(... for el1, el2 in ...)` -> `(let el1 = t[0], el2[1] ... for t in ...)` + if isinstance(target, tuple): + flat_targets = utils.flatten_nested_tuple(target) + new_target = next(self.uid_generator["__tuple_comprh"]) + flat_targets_vals = utils.flatten_nested_tuple( + utils.tree_map( + lambda _, path: functools.reduce( + lambda el, i: im.tuple_get(i, el), path, new_target + ), + with_path_arg=True, + )(target) ) - )(self.visit(node.iterable, **kwargs)) + target = new_target + element_expr = im.let(*zip(flat_targets, flat_targets_vals))(element_expr) + + return im.call(im.call("map_tuple")(im.lambda_(target)(element_expr)))( + self.visit(node.iterable, **kwargs) + ) def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index cb5246a7a4..e4126546c0 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -519,9 +519,14 @@ def parse_target(target: ast.expr) -> NestedTuple[foast.Name]: assert isinstance(target, ast.Name) return self.visit(target, **kwargs) + target = parse_target(gen_expr.generators[0].target) + return foast.TupleComprehension( - element_expr=self.visit(gen_expr.elt, **kwargs), - target=parse_target(gen_expr.generators[0].target), + inner=foast.TupleComprehensionMapper( + target=target, + element_expr=self.visit(gen_expr.elt, **kwargs), + location=self.get_location(node), + ), iterable=self.visit(gen_expr.generators[0].iter, **kwargs), location=self.get_location(node), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 917eef27c9..27812ef5d1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -366,6 +366,36 @@ def testee(tracers: tuple[cases.IField, ...], factor: int32) -> tuple[cases.IFie ) +@pytest.mark.uses_tuple_args +def test_nested_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee( + vals: tuple[tuple[cases.IField, ...], ...], factor: int32 + ) -> tuple[tuple[cases.IField, ...], ...]: + return tuple(tuple(grand_child * factor for grand_child in child) for child in vals) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(tuple(grand_child * f for grand_child in child) for child in t), + ) + + +@pytest.mark.uses_tuple_args +def test_nested_tuple_comprehension_shadowing_names(cartesian_case): + @gtx.field_operator + def testee( + vals: tuple[tuple[cases.IField, ...], ...], factor: int32 + ) -> tuple[tuple[cases.IField, ...], ...]: + return tuple(tuple(child * factor for child in child) for child in vals) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(tuple(child * f for child in child) for child in t), + ) + + @pytest.mark.uses_tuple_args def test_multi_target_tuple_comprehension(cartesian_case): @gtx.field_operator From 9e23d2d94a3ff85a5929d957800b1b056e3874a2 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 6 May 2026 15:14:54 +0200 Subject: [PATCH 05/11] Cleanup --- src/gt4py/next/ffront/func_to_foast.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index e4126546c0..f2148119b6 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -475,9 +475,7 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: def _verify_builtin_type_constructor(self, node: ast.Call) -> None: assert isinstance(node.func, ast.Name) - (arg,) = ( - node.args - ) # note for review: the change here is unrelated to the actual pr and just a small cleanup + (arg,) = node.args if node.func.id == "tuple": if not ( isinstance(arg, ast.Constant) From 97235d94800729552b13443542e6578ff87016e0 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 May 2026 11:14:59 +0200 Subject: [PATCH 06/11] More fixes and commentary --- src/gt4py/next/ffront/field_operator_ast.py | 8 ++- src/gt4py/next/ffront/foast_to_past.py | 9 +-- .../iterator/transforms/unroll_map_tuple.py | 2 +- .../next/type_system/type_specifications.py | 2 +- .../next/type_system/type_translation.py | 57 +++++++++++-------- 5 files changed, 45 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 8043964cf2..e493d6b5c8 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -21,7 +21,7 @@ datamodels, utils as eve_utils, ) -from gt4py.eve.extended_typing import Any, Generic, TypeAlias, TypeVar, Union +from gt4py.eve.extended_typing import Any, Generic, NestedTuple, TypeAlias, TypeVar, Union from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.type_definitions import StrEnum from gt4py.next.ffront import dialect_ast_enums, type_specifications as ts_ffront @@ -136,6 +136,10 @@ class TupleExpr(Expr): class TupleComprehension(Expr): """ tuple(element_expr for target in iterable) + + Note: The structure here differs from the one in the python ast. Here we group target and + element expression, in order to cleanly nest by the symbols being introduced, whereas in + the python ast target and iterable are grouped into generator nodes. """ inner: TupleComprehensionMapper @@ -145,7 +149,7 @@ class TupleComprehension(Expr): # this is essentially a lambda, the difference is for a lambda we might not know the type of the # args, therefor this is named differently at the moment. class TupleComprehensionMapper(LocatedNode, SymbolTableTrait): - target: Any # should be: NestedInTuple[DataSymbol] but this has a problem in eve + target: NestedTuple[DataSymbol] element_expr: Expr diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index c37cba5a78..84c753c12d 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -21,7 +21,7 @@ from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info, type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -113,9 +113,10 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef: *partial_program_type.definition.kw_only_args.keys(), ] assert isinstance(type_, ts.CallableType) - # assert arg_types[-1] == type_info.return_type( - # type_, with_args=list(arg_types), with_kwargs=kwarg_types - # ) + return_type = type_info.return_type( + type_, with_args=list(arg_types), with_kwargs=kwarg_types + ) + assert type_info.is_concretizable(return_type, arg_types[-1]) assert args_names[-1] == "out" params_decl: list[past.Symbol] = [ diff --git a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py index b47f5ba7d7..0fb86693f4 100644 --- a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py +++ b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py @@ -29,7 +29,7 @@ def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) if cpm.is_call_to(node.fun, "map_tuple"): - # TODO: we have to duplicate the function here since the domain inference can not handle them yet + # TODO: we have to duplicate the function as domain inference only supports direct calls f = node.fun.args[0] tup = node.args[0] itir_inference.reinfer(tup) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 409138d593..1016eba34e 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -151,7 +151,7 @@ def __len__(self) -> int: class VarArgType(DataType): """Represents a variable number of arguments of the same type.""" - element_type: DataType # TODO: maybe also support different DataTypes + element_type: DataType def __str__(self) -> str: return f"VarArg[{self.element_type}]" diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 1d7a9aa2f7..c74c8e7c24 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -75,6 +75,16 @@ def make_constructor_type(type_spec: ts.TypeSpec) -> ts.ConstructorType: ) ) + case ts.DeferredType(constraint=ts.TupleType): + return ts.ConstructorType( + definition=ts.FunctionType( + pos_only_args=[ts.DeferredType(constraint=None)], + pos_or_kw_args={}, + kw_only_args={}, + returns=ts.DeferredType(constraint=ts.VarArgType), + ) + ) + case ts.NamedCollectionType() as named_tuple_type: type_ = pkgutil.resolve_name(named_tuple_type.original_python_type) pos_or_kw_args = {k: t for k, t in zip(type_spec.keys, type_spec.types)} @@ -132,7 +142,7 @@ def canonicalize_type_hint( *, globalns: Optional[dict[str, Any]] = None, localns: Optional[dict[str, Any]] = None, -) -> tuple[Any, tuple[Any, ...], tuple[Any, ...]]: +) -> tuple[Any, tuple[Any, ...] | None, tuple[Any, ...]]: """ Canonicalize python type annotations as a tuple of (canonical_type, type_args, annotated_extra_args). """ @@ -158,7 +168,8 @@ def canonicalize_type_hint( type_hint = xtyping.eval_forward_ref(type_hint, globalns=globalns, localns=localns) canonical_type = typing.get_origin(type_hint) or type_hint - args = typing.get_args(type_hint) + # in order to distinguish tuple and tuple[()] the former returns None here + args = typing.get_args(type_hint) if typing.get_origin(type_hint) else None return canonical_type, args, tuple(extra_args) @@ -179,21 +190,29 @@ def from_type_hint( match canonical_type: case builtins.tuple: - if not args: - raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") - if len(args) == 2 and args[1] is Ellipsis: + if isinstance(args, tuple) and not any(arg is Ellipsis for arg in args): + tuple_types = [from_type_hint_same_ns(arg) for arg in args] + assert all(isinstance(elem, ts.DataType) for elem in tuple_types) + return ts.TupleType(types=tuple_types) + elif isinstance(args, tuple) and len(args) == 2 and args[1] is Ellipsis: return ts.VarArgType(element_type=from_type_hint_same_ns(args[0])) - elif Ellipsis in args: + elif args is None: + # TODO(tehrengruber): We use `DeferredType` until we have an actual representation + # for a generic type. + return ts.DeferredType(constraint=ts.TupleType) + else: raise ValueError( - f"Vararg tuple annotation '{type_hint}' cannot have more than one argument." + f"Tuple annotation '{type_hint}' must either, " + f"be a list of concrete arguments (e.g. 'tuple[int]'), " + f"be a variadic tuple (e.g. 'tuple[int, ...]'), ", + "or have no arguments (e.g. tuple).", ) - tuple_types = [from_type_hint_same_ns(arg) for arg in args] - assert all(isinstance(elem, ts.DataType) for elem in tuple_types) - return ts.TupleType(types=tuple_types) case common.Field: - if (n_args := len(args)) != 2: - raise ValueError(f"Field type requires two arguments, got {n_args}: '{type_hint}'.") + if args is None or len(args) != 2: + raise ValueError( + f"Field type requires two arguments, got {len(args or ())}: '{type_hint}'." + ) dims: list[common.Dimension] = [] dim_arg, dtype_arg = args dim_arg = ( @@ -330,19 +349,7 @@ def from_value(value: Any) -> ts.TypeSpec: return NamespaceProxy(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) - if type_ == type[tuple]: - # TODO: this special casing here is not nice, but infer_type is also called on the annotations where - # we don't want to allow unparameterized tuples (or do we?). - symbol_type = ts.ConstructorType( - definition=ts.FunctionType( - pos_only_args=[ts.DeferredType(constraint=None)], - pos_or_kw_args={}, - kw_only_args={}, - returns=ts.DeferredType(constraint=ts.VarArgType), - ) - ) - else: - symbol_type = from_type_hint(type_) + symbol_type = from_type_hint(type_) if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type From dd30833d7a4eae7638652a948339108b2c02667e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 May 2026 11:22:45 +0200 Subject: [PATCH 07/11] Fix failing tests --- src/gt4py/next/iterator/builtins.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 7b24c91884..40daed7ee5 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -412,6 +412,11 @@ def concat_where(*args): raise BackendNotSelectedError() +@builtin_dispatch +def map_tuple(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def get_domain_range(*args): raise BackendNotSelectedError() From 15b233ecab74dfcaea75c78ecf72a1df3a4699ff Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 May 2026 11:30:04 +0200 Subject: [PATCH 08/11] Add test for calling a fo from a tuple comprehension --- .../feature_tests/ffront_tests/test_execution.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 27812ef5d1..1b0f4a0d3a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -366,6 +366,22 @@ def testee(tracers: tuple[cases.IField, ...], factor: int32) -> tuple[cases.IFie ) +@pytest.mark.uses_tuple_args +def test_tuple_comprehension_other_fo(cartesian_case): + @gtx.field_operator + def inner(tracer: cases.IField, factor: int32) -> cases.IField: + return tracer * factor + + @gtx.field_operator + def testee(tracers: tuple[cases.IField, ...], factor: int32) -> tuple[cases.IField, ...]: + return tuple(inner(tracer, factor) for tracer in tracers) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t), + ) + @pytest.mark.uses_tuple_args def test_nested_tuple_comprehension(cartesian_case): @gtx.field_operator From 4f0b5d45255ac0b82daa9f1cbdd8c3469d06eb6c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 May 2026 11:30:41 +0200 Subject: [PATCH 09/11] Fix format --- .../feature_tests/ffront_tests/test_execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 1b0f4a0d3a..d851e7b15c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -382,6 +382,7 @@ def testee(tracers: tuple[cases.IField, ...], factor: int32) -> tuple[cases.IFie ref=lambda t, f: tuple(el * f for el in t), ) + @pytest.mark.uses_tuple_args def test_nested_tuple_comprehension(cartesian_case): @gtx.field_operator From b27f80bdea786a5289ae30328a0b8e6d258c8098 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 May 2026 11:31:24 +0200 Subject: [PATCH 10/11] Fix format --- src/gt4py/next/ffront/field_operator_ast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index e493d6b5c8..51dd7f6707 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -21,7 +21,7 @@ datamodels, utils as eve_utils, ) -from gt4py.eve.extended_typing import Any, Generic, NestedTuple, TypeAlias, TypeVar, Union +from gt4py.eve.extended_typing import Any, Generic, TypeAlias, TypeVar, Union from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.type_definitions import StrEnum from gt4py.next.ffront import dialect_ast_enums, type_specifications as ts_ffront @@ -149,7 +149,7 @@ class TupleComprehension(Expr): # this is essentially a lambda, the difference is for a lambda we might not know the type of the # args, therefor this is named differently at the moment. class TupleComprehensionMapper(LocatedNode, SymbolTableTrait): - target: NestedTuple[DataSymbol] + target: Any # TODO(tehrengruber): should be NestedTuple[DataSymbol], but this breaks in eve element_expr: Expr From 24f4c90fe8cbee281d9e2d5f08af37aeecd2f13a Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 May 2026 13:58:11 +0200 Subject: [PATCH 11/11] Small fix --- src/gt4py/next/ffront/func_to_foast.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index f2148119b6..245a554965 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -476,16 +476,15 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: def _verify_builtin_type_constructor(self, node: ast.Call) -> None: assert isinstance(node.func, ast.Name) (arg,) = node.args - if node.func.id == "tuple": - if not ( - isinstance(arg, ast.Constant) - or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) - or isinstance(arg, ast.GeneratorExp) - ): - raise errors.DSLError( - self.get_location(node), - f"'{self._func_name(node)}()' only takes literal arguments or a generator expression.", - ) + if not ( + isinstance(arg, ast.Constant) + or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) + or (node.func.id == "tuple" and isinstance(arg, ast.GeneratorExp)) + ): + raise errors.DSLError( + self.get_location(node), + f"'{self._func_name(node)}()' only takes literal arguments or a generator expression.", + ) def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly.