Skip to content
Open
32 changes: 31 additions & 1 deletion src/gt4py/next/ffront/field_operator_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from __future__ import annotations

import functools
from typing import Any, Generic, TypeAlias, TypeVar, Union

from gt4py import eve
from gt4py.eve import (
Expand All @@ -22,6 +21,7 @@
datamodels,
utils as eve_utils,
)
from gt4py.eve.extended_typing import Any, Generic, TypeAlias, TypeVar, Union
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.type_definitions import StrEnum
from gt4py.next.ffront import dialect_ast_enums, type_specifications as ts_ffront
Expand Down Expand Up @@ -123,6 +123,36 @@ class TupleExpr(Expr):
elts: list[Expr]


# TODO(tehrengruber): extend this to supported nested tuple comprehension.
# e.g. `tuple(element_expr for child in nested_tuple for grand_child in child)`
# would be represented by:
# ```
# class TupleComprehension(Expr): # ruff: noqa: ERA001
# inner: TupleComprehensionMapper | NestedTupleCompr # ruff: noqa: ERA001
# class NestedTupleCompr(Expr, SymbolTableTrait): # ruff: noqa: ERA001
# params: tuple[DataSymbol] # ruff: noqa: ERA001
# body: TupleComprehension # ruff: noqa: ERA001
# ```
class TupleComprehension(Expr):
"""
tuple(element_expr for target in iterable)

Note: The structure here differs from the one in the python ast. Here we group target and
element expression, in order to cleanly nest by the symbols being introduced, whereas in
the python ast target and iterable are grouped into generator nodes.
"""

inner: TupleComprehensionMapper
iterable: Expr


# this is essentially a lambda, the difference is for a lambda we might not know the type of the
# args, therefor this is named differently at the moment.
class TupleComprehensionMapper(LocatedNode, SymbolTableTrait):
target: Any # TODO(tehrengruber): should be NestedTuple[DataSymbol], but this breaks in eve
element_expr: Expr


class UnaryOp(Expr):
op: dialect_ast_enums.UnaryOperator
operand: Expr
Expand Down
64 changes: 63 additions & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import textwrap
from typing import Any, Optional, Sequence, TypeAlias, TypeVar, cast

Expand All @@ -24,6 +23,7 @@
from gt4py.next.ffront.foast_passes import utils as foast_utils
from gt4py.next.iterator import builtins
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
from gt4py.next.utils import tree_map


OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode)
Expand Down Expand Up @@ -428,6 +428,10 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri
f"Tuples need to be indexed with literal integers, got '{node.index}'.",
) from ex
new_type = types[index]
case ts.VarArgType(element_type=element_type):
new_type = (
element_type # TODO: we only temporarily allow any index for vararg types
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for direct access to tracers[0] * factor, tracers[1] * factor, which I personally think is an anti pattern. I left it here until we take a decision on this. We could also make it an optional feature. One of the disadvantages is that it is not possible to fully type check the field operator at definition time, since the tuple length is only known at call / compile time. The user will then get an error in unroll_map_tuple.

)
case ts.OffsetType(source=source, target=(target1, target2)):
if not target2.kind == DimensionKind.LOCAL:
raise errors.DSLError(
Expand Down Expand Up @@ -674,6 +678,64 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx
new_type = ts.TupleType(types=[element.type for element in new_elts])
return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location)

def visit_TupleComprehension(
self, node: foast.TupleComprehension, **kwargs: Any
) -> foast.TupleComprehension:
target = self.visit(node.inner.target, **kwargs)
iterable = self.visit(node.iterable, **kwargs)
if isinstance(iterable.type, ts.TupleType):
if len(iterable.type.types) > 0 and not all(
t == iterable.type.types[0] for t in iterable.type.types
):
raise errors.DSLError(
iterable.location,
"Not implemented. All elements of the iterable in a tuple comprehensions must have the same type.",
)
element_type = iterable.type.types[0]
elif isinstance(iterable.type, ts.VarArgType):
element_type = iterable.type.element_type
else:
raise errors.DSLError(
iterable.location,
f"Iterable in generator expression must be a tuple, got '{iterable.type}'.",
)

inner_kwargs = {"symtable": node.inner.annex.symtable, **kwargs}

@tree_map(with_path_arg=True)
def process_target(target_el: foast.Symbol, path: tuple[int, ...]) -> None:
try:
type_ = element_type
for i in path:
if not isinstance(type_, ts.TupleType) or len(type_.types) <= i:
raise IndexError()
type_ = type_.types[i]
return self.visit(target_el, refine_type=type_, **inner_kwargs)
except IndexError:
raise errors.DSLError(
target_el.location, f"Cannot unpack non-iterable '{type_}' object."
) from None

new_target = process_target(target)

element_expr = self.visit(node.inner.element_expr, **inner_kwargs)

return_type: ts.TupleType | ts.VarArgType
if isinstance(iterable.type, ts.TupleType):
return_type = ts.TupleType(types=[element_expr.type] * len(iterable.type.types))
else:
assert isinstance(iterable.type, ts.VarArgType)
return_type = ts.VarArgType(element_type=element_expr.type)

return foast.TupleComprehension(
inner=foast.TupleComprehensionMapper(
target=new_target, element_expr=element_expr, location=node.location
),
iterable=iterable,
location=node.location,
type=return_type,
)

def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call:
new_func = self.visit(node.func, **kwargs)
new_args = self.visit(node.args, **kwargs)
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/ffront/foast_pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o

UnaryOp = as_fmt("{op}{operand}")

def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> str:
element_expr = self.visit(node.inner.element_expr, **kwargs)
target = self.visit(node.inner.target, **kwargs)
iterable = self.visit(node.iterable, **kwargs)
return f"tuple(({element_expr} for {target} in {iterable}))"

def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str:
if node.op is dialect_ast_enums.UnaryOperator.NOT:
op = "not "
Expand Down
24 changes: 24 additions & 0 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


import dataclasses
import functools
from typing import Any, Callable, Optional

from gt4py import eve
Expand Down Expand Up @@ -257,6 +258,29 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr:
def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr:
return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts])

def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr:
target = self.visit(node.inner.target, **kwargs)
element_expr = self.visit(node.inner.element_expr, **kwargs)

# e.g. `(... for el1, el2 in ...)` -> `(let el1 = t[0], el2[1] ... for t in ...)`
if isinstance(target, tuple):
flat_targets = utils.flatten_nested_tuple(target)
new_target = next(self.uid_generator["__tuple_comprh"])
flat_targets_vals = utils.flatten_nested_tuple(
utils.tree_map(
lambda _, path: functools.reduce(
lambda el, i: im.tuple_get(i, el), path, new_target
),
with_path_arg=True,
)(target)
)
target = new_target
element_expr = im.let(*zip(flat_targets, flat_targets_vals))(element_expr)

return im.call(im.call("map_tuple")(im.lambda_(target)(element_expr)))(
self.visit(node.iterable, **kwargs)
)

def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
# TODO(tehrengruber): extend iterator ir to support unary operators
dtype = type_info.extract_dtype(node.type)
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/ffront/foast_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef:
*partial_program_type.definition.kw_only_args.keys(),
]
assert isinstance(type_, ts.CallableType)
assert arg_types[-1] == type_info.return_type(
return_type = type_info.return_type(
type_, with_args=list(arg_types), with_kwargs=kwarg_types
)
assert type_info.is_concretizable(return_type, arg_types[-1])
assert args_names[-1] == "out"

params_decl: list[past.Symbol] = [
Expand Down
72 changes: 57 additions & 15 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import ast
import textwrap
import typing
from typing import Any, Type

import gt4py.eve as eve
from gt4py.eve.extended_typing import Any, NestedTuple, Type
from gt4py.next import errors
from gt4py.next.ffront import (
dialect_ast_enums,
Expand Down Expand Up @@ -336,8 +336,13 @@ def visit_Return(self, node: ast.Return, **kwargs: Any) -> foast.Return:
def visit_Expr(self, node: ast.Expr) -> foast.Expr:
return self.visit(node.value)

def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name:
return foast.Name(id=node.id, location=self.get_location(node))
def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.DataSymbol | foast.Name:
loc = self.get_location(node)
if isinstance(node.ctx, ast.Store):
return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None))
else:
assert isinstance(node.ctx, ast.Load)
return foast.Name(id=node.id, location=loc)

def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp:
return foast.UnaryOp(
Expand Down Expand Up @@ -469,24 +474,61 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator:
return foast.CompareOperator.NOTEQ

def _verify_builtin_type_constructor(self, node: ast.Call) -> None:
if len(node.args) > 0:
arg = node.args[0]
if not (
isinstance(arg, ast.Constant)
or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant))
):
raise errors.DSLError(
self.get_location(node),
f"'{self._func_name(node)}()' only takes literal arguments.",
)
assert isinstance(node.func, ast.Name)
(arg,) = node.args
if not (
isinstance(arg, ast.Constant)
or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant))
or (node.func.id == "tuple" and isinstance(arg, ast.GeneratorExp))
):
raise errors.DSLError(
self.get_location(node),
f"'{self._func_name(node)}()' only takes literal arguments or a generator expression.",
)

def _func_name(self, node: ast.Call) -> str:
return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly.

def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call:
# TODO(tehrengruber): is this still needed or redundant with the checks in type deduction?
def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call | foast.TupleComprehension:
if isinstance(node.func, ast.Name):
func_name = self._func_name(node)

if (
func_name == "tuple"
and len(node.args) == 1
and isinstance(gen_expr := node.args[0], ast.GeneratorExp)
):
if len(gen_expr.generators) != 1:
raise errors.DSLError(
self.get_location(node),
"Nested generator expressions are not supported.",
)
if gen_expr.generators[0].ifs != []:
raise errors.DSLError(
self.get_location(node),
"Conditionals are not supported in generator expressions as they size of "
"the result can only be deduced at runtime.",
)

def parse_target(target: ast.expr) -> NestedTuple[foast.Name]:
if isinstance(target, ast.Tuple):
return tuple(parse_target(el) for el in target.elts)
assert isinstance(target, ast.Name)
return self.visit(target, **kwargs)

target = parse_target(gen_expr.generators[0].target)

return foast.TupleComprehension(
inner=foast.TupleComprehensionMapper(
target=target,
element_expr=self.visit(gen_expr.elt, **kwargs),
location=self.get_location(node),
),
iterable=self.visit(gen_expr.generators[0].iter, **kwargs),
location=self.get_location(node),
)

# TODO(tehrengruber): is this still needed or redundant with the checks in type deduction?
if func_name in fbuiltins.TYPE_BUILTIN_NAMES:
self._verify_builtin_type_constructor(node)

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call:
operator_return_type = type_info.return_type(
new_func.type, with_args=arg_types, with_kwargs=kwarg_types
)
if operator_return_type != new_kwargs["out"].type:
if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type):
raise ValueError(
"Expected keyword argument 'out' to be of "
f"type '{operator_return_type}', got "
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,11 @@ def concat_where(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def map_tuple(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def get_domain_range(*args):
raise BackendNotSelectedError()
Expand Down Expand Up @@ -498,7 +503,8 @@ def get_domain_range(*args):
"lift",
"make_const_list",
"make_tuple",
"map_",
"map_tuple",
"map_", # TODO: rename to map_list
"named_range",
"neighbors",
"reduce",
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
prune_empty_concat_where,
remove_broadcast,
symbol_ref_utils,
unroll_map_tuple,
)
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
Expand Down Expand Up @@ -177,6 +178,7 @@ def apply_common_transforms(
) # domain inference does not support dynamic offsets yet
ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids)

ir = infer_domain.infer_program(
ir,
Expand Down Expand Up @@ -291,6 +293,7 @@ def apply_fieldview_transforms(

ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids)
ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program

ir = infer_domain.infer_program(
Expand Down
Loading
Loading