Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 81 additions & 20 deletions projects/eudsl-python-extras/mlir/extras/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import inspect
import sys
from functools import update_wrapper
from typing import Optional, List, Union, TypeVar
import typing
from functools import update_wrapper, partial
from typing import Optional, List, Union, TypeVar, get_args
import types

from .. import types
from .. import types as extras_types
from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject
from ..ast.util import copy_func
from ..meta import op_region_builder
from ..util import get_user_code_loc, make_maybe_no_args_decorator
from ..util import (
get_user_code_loc,
make_maybe_no_args_decorator,
)
from ...dialects._ods_common import get_op_result_or_op_results
from ...dialects.func import *
from ...dialects.func import FuncOp, CallOp, ReturnOp, call
from ...ir import (
Attribute,
FlatSymbolRefAttr,
FunctionType,
InsertionPoint,
Expand All @@ -23,6 +29,7 @@
Type,
TypeAttr,
Value,
ShapedType,
)


Expand Down Expand Up @@ -98,6 +105,15 @@ def isalambda(v):
return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__


def _is_valid_type_annotation(r):
return (
isinstance(r, (str, Type, TypeVar))
or isalambda(r)
or type(r) is types.GenericAlias
or (isinstance(r, type) and issubclass(r, Type))
)


def prep_func_types(sig, return_types):
assert not (
not sig.return_annotation is inspect.Signature.empty and len(return_types) > 0
Expand All @@ -111,7 +127,7 @@ def prep_func_types(sig, return_types):
return_types = [return_types]
return_types = list(return_types)
assert all(
isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in return_types
_is_valid_type_annotation(r) for r in return_types
), f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}"

input_types = [
Expand All @@ -120,7 +136,7 @@ def prep_func_types(sig, return_types):
if not p.annotation is inspect.Signature.empty
]
assert all(
isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in input_types
_is_valid_type_annotation(r) for r in input_types
), f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}"
user_loc = get_user_code_loc()
# If ir.Context is none (like for deferred func emit)
Expand Down Expand Up @@ -221,6 +237,49 @@ def maybe_eval_type_data_closure_vals(
return unevaled_type_data()


def evaluate_generic_alias_type(t: types.GenericAlias | typing.Any):
Comment thread
makslevental marked this conversation as resolved.
if isinstance(t, (Type, Attribute, bool, float, int, str)):
return t
if isinstance(t, (tuple, list)):
return t.__class__(map(evaluate_generic_alias_type, t))
if (
not type(t) is types.GenericAlias
and isinstance(t, type)
and issubclass(t, (Type, Attribute))
):
return t.get()
assert type(t) is types.GenericAlias
args = list(get_args(t))
for i, a in enumerate(args):
args[i] = evaluate_generic_alias_type(a)
return t.get(*args)


def evaluate_type_annotation(v, globals_=None, locals_=None):
if isinstance(v, TypeVar):
v = v.__name__
if isinstance(v, str):
t = Type(eval(v, globals_, locals_))
elif isalambda(v):
t = v()
elif isinstance(v, Type):
t = v
elif (
not type(v) is types.GenericAlias
and isinstance(v, type)
and issubclass(v, Type)
):
t = v.get()
elif type(v) is types.GenericAlias:
if issubclass(v.__origin__, Value):
v = get_args(v)[0]
t = evaluate_generic_alias_type(v)
Comment thread
makslevental marked this conversation as resolved.
else:
raise NotImplementedError(f"unsupported type annotation {v=}")

return t


class FuncBase:
def __init__(
self,
Expand Down Expand Up @@ -326,24 +385,23 @@ def _build_input_types(self) -> Union[list[Type], OpView]:
raise ValueError(
f"T is a reserved generic name; use a different one for {locals['T']}"
)
locals["T"] = types
locals["T"] = extras_types
if "S" in locals:
raise ValueError(
f"S is a reserved generic name; use a different one for {locals['S']}"
)
locals["S"] = ShapedType.get_dynamic_size()

# evaluate type annotations (which could be strings or lambdas)
input_types = self.input_types[:]
for i, v in enumerate(input_types):
if isinstance(v, TypeVar):
v = v.__name__
if isinstance(v, str):
input_types[i] = Type(eval(v, self.body_builder.__globals__, locals))
elif isalambda(v):
input_types[i] = v()

return input_types
return list(
map(
partial(
evaluate_type_annotation,
globals_=self.body_builder.__globals__,
locals_=locals,
),
self.input_types,
)
)

def emit(self, *call_args, decl=False, force=False) -> FuncOp:
if self._func_op and not (decl or force):
Expand All @@ -365,7 +423,10 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp:
function_type = TypeAttr.get(self.function_type)
else:
function_type = TypeAttr.get(
FunctionType.get(inputs=input_types, results=self.return_types)
FunctionType.get(
inputs=input_types,
results=list(map(evaluate_type_annotation, self.return_types)),
)
)

self._func_op = self.func_op_ctor(
Expand Down
2 changes: 1 addition & 1 deletion projects/eudsl-python-extras/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def load_requirements(fname):
],
"mlir": ["mlir-python-bindings"],
},
python_requires=">=3.8",
python_requires=">=3.10",
include_package_data=True,
packages=packages,
# lhs is package namespace, rhs is path (relative to this setup.py)
Expand Down
Loading