diff --git a/projects/eudsl-python-extras/mlir/extras/dialects/func.py b/projects/eudsl-python-extras/mlir/extras/dialects/func.py index 89568e56..29a7b5a1 100644 --- a/projects/eudsl-python-extras/mlir/extras/dialects/func.py +++ b/projects/eudsl-python-extras/mlir/extras/dialects/func.py @@ -3,17 +3,23 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import inspect import sys -from functools import update_wrapper -from typing import Optional, List, Union, TypeVar +import typing +from functools import update_wrapper, partial +from typing import Optional, List, Union, TypeVar, get_args +import types -from .. import types +from .. import types as extras_types from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject from ..ast.util import copy_func from ..meta import op_region_builder -from ..util import get_user_code_loc, make_maybe_no_args_decorator +from ..util import ( + get_user_code_loc, + make_maybe_no_args_decorator, +) from ...dialects._ods_common import get_op_result_or_op_results -from ...dialects.func import * +from ...dialects.func import FuncOp, CallOp, ReturnOp, call from ...ir import ( + Attribute, FlatSymbolRefAttr, FunctionType, InsertionPoint, @@ -23,6 +29,7 @@ Type, TypeAttr, Value, + ShapedType, ) @@ -98,6 +105,15 @@ def isalambda(v): return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__ +def _is_valid_type_annotation(r): + return ( + isinstance(r, (str, Type, TypeVar)) + or isalambda(r) + or type(r) is types.GenericAlias + or (isinstance(r, type) and issubclass(r, Type)) + ) + + def prep_func_types(sig, return_types): assert not ( not sig.return_annotation is inspect.Signature.empty and len(return_types) > 0 @@ -111,7 +127,7 @@ def prep_func_types(sig, return_types): return_types = [return_types] return_types = list(return_types) assert all( - isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in return_types + _is_valid_type_annotation(r) for r in return_types ), f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}" input_types = [ @@ -120,7 +136,7 @@ def prep_func_types(sig, return_types): if not p.annotation is inspect.Signature.empty ] assert all( - isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in input_types + _is_valid_type_annotation(r) for r in input_types ), f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}" user_loc = get_user_code_loc() # If ir.Context is none (like for deferred func emit) @@ -221,6 +237,49 @@ def maybe_eval_type_data_closure_vals( return unevaled_type_data() +def evaluate_generic_alias_type(t: types.GenericAlias | typing.Any): + if isinstance(t, (Type, Attribute, bool, float, int, str)): + return t + if isinstance(t, (tuple, list)): + return t.__class__(map(evaluate_generic_alias_type, t)) + if ( + not type(t) is types.GenericAlias + and isinstance(t, type) + and issubclass(t, (Type, Attribute)) + ): + return t.get() + assert type(t) is types.GenericAlias + args = list(get_args(t)) + for i, a in enumerate(args): + args[i] = evaluate_generic_alias_type(a) + return t.get(*args) + + +def evaluate_type_annotation(v, globals_=None, locals_=None): + if isinstance(v, TypeVar): + v = v.__name__ + if isinstance(v, str): + t = Type(eval(v, globals_, locals_)) + elif isalambda(v): + t = v() + elif isinstance(v, Type): + t = v + elif ( + not type(v) is types.GenericAlias + and isinstance(v, type) + and issubclass(v, Type) + ): + t = v.get() + elif type(v) is types.GenericAlias: + if issubclass(v.__origin__, Value): + v = get_args(v)[0] + t = evaluate_generic_alias_type(v) + else: + raise NotImplementedError(f"unsupported type annotation {v=}") + + return t + + class FuncBase: def __init__( self, @@ -326,24 +385,23 @@ def _build_input_types(self) -> Union[list[Type], OpView]: raise ValueError( f"T is a reserved generic name; use a different one for {locals['T']}" ) - locals["T"] = types + locals["T"] = extras_types if "S" in locals: raise ValueError( f"S is a reserved generic name; use a different one for {locals['S']}" ) locals["S"] = ShapedType.get_dynamic_size() - # evaluate type annotations (which could be strings or lambdas) - input_types = self.input_types[:] - for i, v in enumerate(input_types): - if isinstance(v, TypeVar): - v = v.__name__ - if isinstance(v, str): - input_types[i] = Type(eval(v, self.body_builder.__globals__, locals)) - elif isalambda(v): - input_types[i] = v() - - return input_types + return list( + map( + partial( + evaluate_type_annotation, + globals_=self.body_builder.__globals__, + locals_=locals, + ), + self.input_types, + ) + ) def emit(self, *call_args, decl=False, force=False) -> FuncOp: if self._func_op and not (decl or force): @@ -365,7 +423,10 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp: function_type = TypeAttr.get(self.function_type) else: function_type = TypeAttr.get( - FunctionType.get(inputs=input_types, results=self.return_types) + FunctionType.get( + inputs=input_types, + results=list(map(evaluate_type_annotation, self.return_types)), + ) ) self._func_op = self.func_op_ctor( diff --git a/projects/eudsl-python-extras/setup.py b/projects/eudsl-python-extras/setup.py index 4d4765a5..ce288888 100644 --- a/projects/eudsl-python-extras/setup.py +++ b/projects/eudsl-python-extras/setup.py @@ -58,7 +58,7 @@ def load_requirements(fname): ], "mlir": ["mlir-python-bindings"], }, - python_requires=">=3.8", + python_requires=">=3.10", include_package_data=True, packages=packages, # lhs is package namespace, rhs is path (relative to this setup.py) diff --git a/projects/eudsl-python-extras/tests/dialect/test_func.py b/projects/eudsl-python-extras/tests/dialect/test_func.py index 7d9408aa..65d5f666 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_func.py +++ b/projects/eudsl-python-extras/tests/dialect/test_func.py @@ -6,7 +6,22 @@ import mlir.extras.types as T import pytest -from mlir.ir import FunctionType +from mlir.ir import ( + ComplexType, + F32Type, + F64Type, + FunctionType, + IndexType, + IntegerAttr, + IntegerType, + MemRefType, + OpaqueType, + RankedTensorType, + UnrankedMemRefType, + UnrankedTensorType, + Value, + VectorType, +) from mlir.extras.context import mlir_mod_ctx, RAIIMLIRContextModule from mlir.extras.dialects.arith import constant @@ -184,3 +199,232 @@ def demo_fun1(a, b): # CHECK: } filecheck_with_comments(ctx.module) + + +def test_integer_index_type_annotations(ctx: MLIRContext): + @func + def f_i32(x: IntegerType[32]) -> IntegerType[32]: ... + + @func + def f_i64(x: IntegerType[64]) -> IntegerType[64]: ... + + @func + def f_index(x: IndexType) -> IndexType: ... + + ctx.module.operation.verify() + + # CHECK: func.func private @f_i32(i32) -> i32 + # CHECK: func.func private @f_i64(i64) -> i64 + # CHECK: func.func private @f_index(index) -> index + + filecheck_with_comments(ctx.module) + + +def test_complex_type_annotations(ctx: MLIRContext): + @func + def f_complex_f32(x: ComplexType[F32Type]) -> ComplexType[F32Type]: ... + + @func + def f_complex_f64(x: ComplexType[F64Type]) -> ComplexType[F64Type]: ... + + @func + def f_complex_i32( + x: ComplexType[IntegerType[32]], + ) -> ComplexType[IntegerType[32]]: ... + + ctx.module.operation.verify() + + # CHECK: func.func private @f_complex_f32(complex) -> complex + # CHECK: func.func private @f_complex_f64(complex) -> complex + # CHECK: func.func private @f_complex_i32(complex) -> complex + + filecheck_with_comments(ctx.module) + + +def test_vector_type_annotations(ctx: MLIRContext): + @func + def f_vec_1d(x: VectorType[[4], F32Type]) -> VectorType[[4], F32Type]: ... + + @func + def f_vec_2d(x: VectorType[[2, 3], F32Type]) -> VectorType[[2, 3], F32Type]: ... + + @func + def f_vec_i32( + x: VectorType[[8], IntegerType[32]], + ) -> VectorType[[8], IntegerType[32]]: ... + + ctx.module.operation.verify() + + # CHECK: func.func private @f_vec_1d(vector<4xf32>) -> vector<4xf32> + # CHECK: func.func private @f_vec_2d(vector<2x3xf32>) -> vector<2x3xf32> + # CHECK: func.func private @f_vec_i32(vector<8xi32>) -> vector<8xi32> + + filecheck_with_comments(ctx.module) + + +def test_tensor_type_annotations(ctx: MLIRContext): + @func + def f_ranked_tensor( + x: RankedTensorType[[2, 3], F32Type], + ) -> RankedTensorType[[2, 3], F32Type]: ... + + @func + def f_unranked_tensor( + x: UnrankedTensorType[F32Type], + ) -> UnrankedTensorType[F32Type]: ... + + ctx.module.operation.verify() + + # CHECK: func.func private @f_ranked_tensor(tensor<2x3xf32>) -> tensor<2x3xf32> + # CHECK: func.func private @f_unranked_tensor(tensor<*xf32>) -> tensor<*xf32> + + filecheck_with_comments(ctx.module) + + +def test_memref_type_annotations(ctx: MLIRContext): + @func + def f_memref( + x: MemRefType[[2, 3], F32Type], + ) -> MemRefType[[2, 3], F32Type]: ... + + @func + def f_unranked_memref( + x: UnrankedMemRefType[F32Type, IntegerAttr[IntegerType[64], 2]], + ) -> UnrankedMemRefType[F32Type, IntegerAttr[IntegerType[64], 2]]: ... + + ctx.module.operation.verify() + + # CHECK: func.func private @f_memref(memref<2x3xf32>) -> memref<2x3xf32> + # CHECK: func.func private @f_unranked_memref(memref<*xf32, 2>) -> memref<*xf32, 2> + + filecheck_with_comments(ctx.module) + + +def test_opaque_type_annotation(ctx: MLIRContext): + @func + def f_opaque(x: OpaqueType["tensor", "bob"]) -> OpaqueType["tensor", "bob"]: ... + + ctx.module.operation.verify() + + # CHECK: func.func private @f_opaque(!tensor.bob) -> !tensor.bob + + filecheck_with_comments(ctx.module) + + +def test_type_annotations_with_body(ctx: MLIRContext): + @func + def f_f32(x: F32Type): + return x + + @func + def f_i32(x: IntegerType[32]): + return x + + @func + def f_index(x: IndexType): + return x + + @func + def f_complex(x: ComplexType[F32Type]): + return x + + @func + def f_vector(x: VectorType[[4], F32Type]): + return x + + @func + def f_ranked_tensor(x: RankedTensorType[[2, 3], F32Type]): + return x + + @func + def f_memref(x: MemRefType[[2, 3], F32Type]): + return x + + f_f32.emit() + f_i32.emit() + f_index.emit() + f_complex.emit() + f_vector.emit() + f_ranked_tensor.emit() + f_memref.emit() + + ctx.module.operation.verify() + + # CHECK: func.func @f_f32(%[[V:.*]]: f32) -> f32 { + # CHECK: func.func @f_i32(%[[V:.*]]: i32) -> i32 { + # CHECK: func.func @f_index(%[[V:.*]]: index) -> index { + # CHECK: func.func @f_complex(%[[V:.*]]: complex) -> complex { + # CHECK: func.func @f_vector(%[[V:.*]]: vector<4xf32>) -> vector<4xf32> { + # CHECK: func.func @f_ranked_tensor(%[[V:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32> { + # CHECK: func.func @f_memref(%[[V:.*]]: memref<2x3xf32>) -> memref<2x3xf32> { + + filecheck_with_comments(ctx.module) + + +def test_multiple_arg_type_annotations(ctx: MLIRContext): + @func + def f_two_args(x: IntegerType[32], y: F32Type) -> F32Type: ... + + @func + def f_three_args( + a: IntegerType[32], + b: VectorType[[4], F32Type], + c: MemRefType[[2, 3], F32Type], + ) -> VectorType[[4], F32Type]: ... + + @func + def f_same_type_args( + x: RankedTensorType[[2, 3], F32Type], + y: RankedTensorType[[2, 3], F32Type], + ) -> RankedTensorType[[2, 3], F32Type]: ... + + ctx.module.operation.verify() + + # CHECK: func.func private @f_two_args(i32, f32) -> f32 + # CHECK: func.func private @f_three_args(i32, vector<4xf32>, memref<2x3xf32>) -> vector<4xf32> + # CHECK: func.func private @f_same_type_args(tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + + filecheck_with_comments(ctx.module) + + +def test_multiple_arg_with_body(ctx: MLIRContext): + @func + def f_return_first(x: IntegerType[32], y: IntegerType[64]): + return x + + @func + def f_mixed_args( + a: F32Type, + b: VectorType[[4], F32Type], + c: MemRefType[[2, 3], F32Type], + ): + return b + + f_return_first.emit() + f_mixed_args.emit() + + ctx.module.operation.verify() + + # CHECK: func.func @f_return_first(%[[A:.*]]: i32, %[[B:.*]]: i64) -> i32 { + # CHECK: func.func @f_mixed_args(%[[A:.*]]: f32, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: memref<2x3xf32>) -> vector<4xf32> { + + filecheck_with_comments(ctx.module) + + +def test_multiple_arg_with_body_with_value(ctx: MLIRContext): + + @func + def f_mixed_args( + a: Value[F32Type], + b: Value[VectorType[[4], F32Type]], + c: Value[MemRefType[[2, 3], F32Type]], + ): + return b + + f_mixed_args.emit() + + ctx.module.operation.verify() + + # CHECK: func.func @f_mixed_args(%[[A:.*]]: f32, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: memref<2x3xf32>) -> vector<4xf32> { + + filecheck_with_comments(ctx.module) diff --git a/projects/eudsl-python-extras/tests/dialect/test_memref.py b/projects/eudsl-python-extras/tests/dialect/test_memref.py index d2b93637..387de183 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_memref.py +++ b/projects/eudsl-python-extras/tests/dialect/test_memref.py @@ -846,3 +846,18 @@ def test_reinterpret_cast_nonzero_dynamic_offset(ctx: MLIRContext): # CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [%[[C3]]], sizes: [6, 1], strides: [1, 1] : memref<2x3xf32> to memref<6x1xf32, strided<[1, 1], offset: ?>> filecheck_with_comments(ctx.module) + + +def test_reinterpret_cast_zero_sized_to_dynamic(ctx: MLIRContext): + input = alloc((0,), T.f32()) + c0 = constant(0, index=True) + c1 = constant(1, index=True) + reinterpret_cast(input, offsets=[c0], sizes=[c1], strides=[c1]) + + # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<0xf32> + # CHECK: %[[C0:.*]] = arith.constant 0 : index + # CHECK: %[[C1:.*]] = arith.constant 1 : index + # CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [%[[C0]]], sizes: [%[[C1]]], strides: [%[[C1]]] : memref<0xf32> to memref> + + filecheck_with_comments(ctx.module) + diff --git a/projects/eudsl-python-extras/tests/dialect/test_types.py b/projects/eudsl-python-extras/tests/dialect/test_types.py deleted file mode 100644 index 6d79ed0a..00000000 --- a/projects/eudsl-python-extras/tests/dialect/test_types.py +++ /dev/null @@ -1,43 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import mlir.extras.types as T -import pytest -from mlir.extras.types import tensor, memref, vector - -from mlir.extras.dialects.memref import alloc -from mlir.extras.dialects.tensor import S, empty - -# noinspection PyUnresolvedReferences -from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext - - -def test_shaped_types(ctx: MLIRContext): - t = tensor(S, 3, S, T.f64()) - assert repr(t) == "RankedTensorType(tensor)" - ut = tensor(T.f64()) - assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)" - t = tensor(S, 3, S, element_type=T.f64()) - assert repr(t) == "RankedTensorType(tensor)" - ut = tensor(element_type=T.f64()) - assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)" - - m = memref(S, 3, S, T.f64()) - assert repr(m) == "MemRefType(memref)" - um = memref(T.f64()) - assert repr(um) == "UnrankedMemRefType(memref<*xf64>)" - m = memref(S, 3, S, element_type=T.f64()) - assert repr(m) == "MemRefType(memref)" - um = memref(element_type=T.f64()) - assert repr(um) == "UnrankedMemRefType(memref<*xf64>)" - - v = vector(3, 3, 3, T.f64()) - assert repr(v) == "VectorType(vector<3x3x3xf64>)" - - -def test_n_elements(ctx: MLIRContext): - ten = empty(1, 2, 3, 4, T.i32()) - assert ten.n_elements == 1 * 2 * 3 * 4 - - mem = alloc((1, 2, 3, 4), T.i32()) - assert mem.n_elements == 1 * 2 * 3 * 4 diff --git a/projects/eudsl-python-extras/tests/test_types.py b/projects/eudsl-python-extras/tests/test_types.py new file mode 100644 index 00000000..1c9d1871 --- /dev/null +++ b/projects/eudsl-python-extras/tests/test_types.py @@ -0,0 +1,280 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import mlir.extras.types as T +from mlir.extras.types import tensor, memref, vector +from mlir.extras.dialects.func import evaluate_generic_alias_type + +from mlir.extras.dialects.memref import alloc +from mlir.extras.dialects.tensor import S, empty +from mlir.ir import ( + IntegerType, + IndexType, + NoneType, + F16Type, + BF16Type, + F32Type, + F64Type, + FloatTF32Type, + Float4E2M1FNType, + Float6E2M3FNType, + Float6E3M2FNType, + Float8E3M4Type, + Float8E4M3Type, + Float8E4M3FNType, + Float8E5M2Type, + Float8E4M3FNUZType, + Float8E4M3B11FNUZType, + Float8E5M2FNUZType, + Float8E8M0FNUType, + ComplexType, + VectorType, + RankedTensorType, + UnrankedTensorType, + MemRefType, + UnrankedMemRefType, + FunctionType, + OpaqueType, + UnitAttr, + BoolAttr, + StringAttr, + FlatSymbolRefAttr, + SymbolRefAttr, + FloatAttr, + IntegerAttr, + ArrayAttr, + StridedLayoutAttr, + DenseBoolArrayAttr, + DenseI8ArrayAttr, + DenseI16ArrayAttr, + DenseI32ArrayAttr, + DenseI64ArrayAttr, + DenseF32ArrayAttr, + DenseF64ArrayAttr, +) + +# noinspection PyUnresolvedReferences +from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext + + +def test_shaped_types(ctx: MLIRContext): + t = tensor(S, 3, S, T.f64()) + assert repr(t) == "RankedTensorType(tensor)" + ut = tensor(T.f64()) + assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)" + t = tensor(S, 3, S, element_type=T.f64()) + assert repr(t) == "RankedTensorType(tensor)" + ut = tensor(element_type=T.f64()) + assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)" + + m = memref(S, 3, S, T.f64()) + assert repr(m) == "MemRefType(memref)" + um = memref(T.f64()) + assert repr(um) == "UnrankedMemRefType(memref<*xf64>)" + m = memref(S, 3, S, element_type=T.f64()) + assert repr(m) == "MemRefType(memref)" + um = memref(element_type=T.f64()) + assert repr(um) == "UnrankedMemRefType(memref<*xf64>)" + + v = vector(3, 3, 3, T.f64()) + assert repr(v) == "VectorType(vector<3x3x3xf64>)" + + +def test_n_elements(ctx: MLIRContext): + ten = empty(1, 2, 3, 4, T.i32()) + assert ten.n_elements == 1 * 2 * 3 * 4 + + mem = alloc((1, 2, 3, 4), T.i32()) + assert mem.n_elements == 1 * 2 * 3 * 4 + + +# evaluate_generic_type tests: subclass form (no args) and GenericAlias[args] -> .get(args) + + +def test_evaluate_generic_type_nullary_types(ctx: MLIRContext): + assert evaluate_generic_alias_type(IndexType) == IndexType.get() + assert evaluate_generic_alias_type(NoneType) == NoneType.get() + assert evaluate_generic_alias_type(F16Type) == F16Type.get() + assert evaluate_generic_alias_type(BF16Type) == BF16Type.get() + assert evaluate_generic_alias_type(F32Type) == F32Type.get() + assert evaluate_generic_alias_type(F64Type) == F64Type.get() + assert evaluate_generic_alias_type(FloatTF32Type) == FloatTF32Type.get() + assert evaluate_generic_alias_type(Float4E2M1FNType) == Float4E2M1FNType.get() + assert evaluate_generic_alias_type(Float6E2M3FNType) == Float6E2M3FNType.get() + assert evaluate_generic_alias_type(Float6E3M2FNType) == Float6E3M2FNType.get() + assert evaluate_generic_alias_type(Float8E3M4Type) == Float8E3M4Type.get() + assert evaluate_generic_alias_type(Float8E4M3Type) == Float8E4M3Type.get() + assert evaluate_generic_alias_type(Float8E4M3FNType) == Float8E4M3FNType.get() + assert evaluate_generic_alias_type(Float8E5M2Type) == Float8E5M2Type.get() + assert evaluate_generic_alias_type(Float8E4M3FNUZType) == Float8E4M3FNUZType.get() + assert ( + evaluate_generic_alias_type(Float8E4M3B11FNUZType) + == Float8E4M3B11FNUZType.get() + ) + assert evaluate_generic_alias_type(Float8E5M2FNUZType) == Float8E5M2FNUZType.get() + assert evaluate_generic_alias_type(Float8E8M0FNUType) == Float8E8M0FNUType.get() + + +def test_evaluate_generic_type_integer_type(ctx: MLIRContext): + assert evaluate_generic_alias_type(IntegerType[32]) == IntegerType.get(32) + assert evaluate_generic_alias_type(IntegerType[64]) == IntegerType.get(64) + assert evaluate_generic_alias_type(IntegerType[1]) == IntegerType.get(1) + + +def test_evaluate_generic_type_complex_type(ctx: MLIRContext): + assert evaluate_generic_alias_type(ComplexType[F32Type]) == ComplexType.get( + F32Type.get() + ) + assert evaluate_generic_alias_type(ComplexType[F64Type]) == ComplexType.get( + F64Type.get() + ) + assert evaluate_generic_alias_type(ComplexType[IntegerType[32]]) == ComplexType.get( + IntegerType.get(32) + ) + + +def test_evaluate_generic_type_vector_type(ctx: MLIRContext): + assert evaluate_generic_alias_type(VectorType[[2, 3], F32Type]) == VectorType.get( + [2, 3], F32Type.get() + ) + assert evaluate_generic_alias_type( + VectorType[[3, 3, 3], F64Type] + ) == VectorType.get([3, 3, 3], F64Type.get()) + + +def test_evaluate_generic_type_tensor_types(ctx: MLIRContext): + assert evaluate_generic_alias_type( + RankedTensorType[[2, 3], F32Type] + ) == RankedTensorType.get([2, 3], F32Type.get()) + assert evaluate_generic_alias_type( + UnrankedTensorType[F32Type] + ) == UnrankedTensorType.get(F32Type.get()) + assert evaluate_generic_alias_type( + UnrankedTensorType[F64Type] + ) == UnrankedTensorType.get(F64Type.get()) + + +def test_evaluate_generic_type_memref_types(ctx: MLIRContext): + assert evaluate_generic_alias_type(MemRefType[[2, 3], F32Type]) == MemRefType.get( + [2, 3], F32Type.get() + ) + assert evaluate_generic_alias_type( + UnrankedMemRefType[F32Type, IntegerAttr[IntegerType[64], 2]] + ) == UnrankedMemRefType.get(F32Type.get(), IntegerAttr.get(IntegerType.get(64), 2)) + + +def test_evaluate_generic_type_function_type(ctx: MLIRContext): + assert evaluate_generic_alias_type(FunctionType[[], []]) == FunctionType.get([], []) + assert evaluate_generic_alias_type( + FunctionType[[F32Type.get(), F64Type.get()], [IndexType.get()]] + ) == FunctionType.get([F32Type.get(), F64Type.get()], [IndexType.get()]) + + +def test_evaluate_generic_type_opaque_type(ctx: MLIRContext): + assert evaluate_generic_alias_type(OpaqueType["tensor", "bob"]) == OpaqueType.get( + "tensor", "bob" + ) + assert evaluate_generic_alias_type( + OpaqueType["foobar", "mytype"] + ) == OpaqueType.get("foobar", "mytype") + + +def test_evaluate_generic_type_unit_attr(ctx: MLIRContext): + assert evaluate_generic_alias_type(UnitAttr) == UnitAttr.get() + + +def test_evaluate_generic_type_bool_attr(ctx: MLIRContext): + assert evaluate_generic_alias_type(BoolAttr[True]) == BoolAttr.get(True) + assert evaluate_generic_alias_type(BoolAttr[False]) == BoolAttr.get(False) + + +def test_evaluate_generic_type_string_attr(ctx: MLIRContext): + assert evaluate_generic_alias_type(StringAttr["hello"]) == StringAttr.get("hello") + assert evaluate_generic_alias_type(StringAttr["foobar"]) == StringAttr.get("foobar") + + +def test_evaluate_generic_type_integer_attr(ctx: MLIRContext): + assert evaluate_generic_alias_type( + IntegerAttr[IntegerType[32], 42] + ) == IntegerAttr.get(IntegerType.get(32), 42) + assert evaluate_generic_alias_type( + IntegerAttr[IntegerType[64], 0] + ) == IntegerAttr.get(IntegerType.get(64), 0) + + +def test_evaluate_generic_type_float_attr(ctx: MLIRContext): + assert evaluate_generic_alias_type(FloatAttr[F32Type, 42.0]) == FloatAttr.get( + F32Type.get(), 42.0 + ) + assert evaluate_generic_alias_type(FloatAttr[F64Type, 1.5]) == FloatAttr.get( + F64Type.get(), 1.5 + ) + + +def test_evaluate_generic_type_flat_symbol_ref_attr(ctx: MLIRContext): + assert evaluate_generic_alias_type( + FlatSymbolRefAttr["symbol"] + ) == FlatSymbolRefAttr.get("symbol") + assert evaluate_generic_alias_type( + FlatSymbolRefAttr["foobar"] + ) == FlatSymbolRefAttr.get("foobar") + + +def test_evaluate_generic_type_symbol_ref_attr(ctx: MLIRContext): + assert evaluate_generic_alias_type( + SymbolRefAttr[["symbol1", "symbol2"]] + ) == SymbolRefAttr.get(["symbol1", "symbol2"]) + + +def test_evaluate_generic_type_strided_layout_attr(ctx: MLIRContext): + assert evaluate_generic_alias_type( + StridedLayoutAttr[42, [5, 7, 13]] + ) == StridedLayoutAttr.get(42, [5, 7, 13]) + assert evaluate_generic_alias_type( + StridedLayoutAttr[0, [1, 2]] + ) == StridedLayoutAttr.get(0, [1, 2]) + + +def test_evaluate_generic_type_array_attr(ctx: MLIRContext): + items = [StringAttr.get("a"), StringAttr.get("b")] + assert evaluate_generic_alias_type(ArrayAttr[items]) == ArrayAttr.get(items) + assert evaluate_generic_alias_type(ArrayAttr[[]]) == ArrayAttr.get([]) + + +def test_evaluate_generic_type_dense_bool_array_attr(ctx: MLIRContext): + assert evaluate_generic_alias_type( + DenseBoolArrayAttr[[True, False, True]] + ) == DenseBoolArrayAttr.get([True, False, True]) + + +def test_evaluate_generic_type_dense_int_array_attrs(ctx: MLIRContext): + assert evaluate_generic_alias_type( + DenseI8ArrayAttr[[1, 2, 3]] + ) == DenseI8ArrayAttr.get([1, 2, 3]) + assert evaluate_generic_alias_type( + DenseI16ArrayAttr[[4, 5, 6]] + ) == DenseI16ArrayAttr.get([4, 5, 6]) + assert evaluate_generic_alias_type( + DenseI32ArrayAttr[[6, 7, 8]] + ) == DenseI32ArrayAttr.get([6, 7, 8]) + assert evaluate_generic_alias_type( + DenseI64ArrayAttr[[8, 9, 10]] + ) == DenseI64ArrayAttr.get([8, 9, 10]) + + +def test_evaluate_generic_type_dense_float_array_attrs(ctx: MLIRContext): + assert evaluate_generic_alias_type( + DenseF32ArrayAttr[[1.0, 2.0, 3.0]] + ) == DenseF32ArrayAttr.get([1.0, 2.0, 3.0]) + assert evaluate_generic_alias_type( + DenseF64ArrayAttr[[4.0, 5.0, 6.0]] + ) == DenseF64ArrayAttr.get([4.0, 5.0, 6.0]) + + +def test_evaluate_generic_type_dense_float_array_attrs_tuple(ctx: MLIRContext): + assert evaluate_generic_alias_type( + DenseF32ArrayAttr[(1.0, 2.0, 3.0),] + ) == DenseF32ArrayAttr.get((1.0, 2.0, 3.0)) + assert evaluate_generic_alias_type( + DenseF64ArrayAttr[(4.0, 5.0, 6.0),] + ) == DenseF64ArrayAttr.get((4.0, 5.0, 6.0))