diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 485aa84886b32..4bc7671b82082 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -39,6 +39,7 @@ class to enable the new behavior. Sometimes adding a new abstract RStruct, RTuple, RType, + RTypeVar, RUnion, RVec, RVoid, @@ -703,7 +704,7 @@ def __init__( self, name: str, arg_types: list[RType], - return_type: RType, # TODO: What about generic? + return_type: RType, var_arg_type: RType | None, truncated_type: RType | None, c_function_name: str | None, @@ -716,6 +717,7 @@ def __init__( is_pure: bool, experimental: bool, dependencies: list[Dependency] | None, + type_params: list[RTypeVar] | None, ) -> None: # Each primitive much have a distinct name, but otherwise they are arbitrary. self.name: Final = name @@ -749,6 +751,7 @@ def __init__( # If this flag is set, the primitive has native integer types and must # be matched using more complex rules. self.is_ambiguous = any(has_fixed_width_int(t) for t in arg_types) + self.type_params = None if not type_params else type_params def __repr__(self) -> str: return f"" @@ -776,11 +779,23 @@ class PrimitiveOp(RegisterOp): code paths for short and long representations. """ - def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1) -> None: + def __init__( + self, + args: list[Value], + desc: PrimitiveDescription, + line: int = -1, + *, + arg_types: list[RType] | None = None, + return_type: RType | None = None, + type_args: list[RType] | None = None, + ) -> None: self.error_kind = desc.error_kind super().__init__(line) self.args = args - self.type = desc.return_type + self.arg_types = arg_types if arg_types is not None else desc.arg_types + self.type = return_type if return_type is not None else desc.return_type + self.is_borrowed = desc.is_borrowed + self.type_args = type_args self.desc = desc def sources(self) -> list[Value]: diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index d0db9f2460a1d..734426ca42de9 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -231,10 +231,15 @@ def visit_call_c(self, op: CallC) -> str: def visit_primitive_op(self, op: PrimitiveOp) -> str: args_str = ", ".join(self.format("%r", arg) for arg in op.args) + if op.type_args: + joined = ", ".join(str(arg) for arg in op.type_args) + type_args = f"[{joined}]" + else: + type_args = "" if op.is_void: - return self.format("%s %s", op.desc.name, args_str) + return self.format("%s%s %s", op.desc.name, type_args, args_str) else: - return self.format("%r = %s %s", op, op.desc.name, args_str) + return self.format("%r = %s%s %s", op, op.desc.name, type_args, args_str) def visit_truncate(self, op: Truncate) -> str: return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type) diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index db29f9e304d8d..9d13ecc83175d 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -152,6 +152,9 @@ def visit_rprimitive(self, typ: RPrimitive, /) -> T: def visit_rinstance(self, typ: RInstance, /) -> T: raise NotImplementedError + def visit_rtypevar(self, typ: RTypeVar, /) -> T: + raise RuntimeError("RTypeVar should not be encountered here") + @abstractmethod def visit_rvec(self, typ: RVec, /) -> T: raise NotImplementedError @@ -747,6 +750,12 @@ def visit_rarray(self, t: RArray) -> str: def visit_rvoid(self, t: RVoid) -> str: assert False, "rvoid in tuple?" + def visit_rtypevar(self, typ: RTypeVar) -> str: + # We need to return something to support generic RTuples, etc. Make sure + # the return value is invalid C so that generic RTuples must be expanded + # before they can be used in IR. + return f"!RTypeVar {typ.id} invalid!" + @final class RTuple(RType): @@ -1013,6 +1022,50 @@ def serialize(self) -> str: return self.name +@final +class RTypeVar(RType): + """Type variable type used for generic primitive ops. + + This allows having generic primitive operations like vec get item, which is + parametrized by the vec item type. + + These types are not valid in any other context outside PrimitiveDescription, + and they will always be substituted during the construction of a PrimitiveOp. + + NOTE: This is not related to mypy's TypeVarType! + """ + + def __init__(self, id: int) -> None: + self.id = id + + @property + def may_be_immortal(self) -> bool: + # RTypeVar must always be substituted before use, so this should never matter. + return False + + def accept(self, visitor: RTypeVisitor[T]) -> T: + return visitor.visit_rtypevar(self) + + def __str__(self) -> str: + return f"" + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> TypeGuard[RTypeVar]: + return isinstance(other, RTypeVar) and other.id == self.id + + def __hash__(self) -> int: + return self.id ^ 12345 + + def serialize(self) -> JsonDict: + return {".class": "RTypeVar", "id": self.id} + + @classmethod + def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RTypeVar: + return RTypeVar(data["id"]) + + @final class RVec(RType): """librt.vecs.vec[T]""" diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index c19eded77464e..c0ad1cf1f8264 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -206,6 +206,7 @@ new_tuple_with_length_op, sequence_tuple_op, ) +from mypyc.rt_expandtype import expand_rtype from mypyc.rt_subtype import is_runtime_subtype from mypyc.sametype import is_same_type from mypyc.subtype import is_subtype @@ -2334,6 +2335,7 @@ def primitive_op( args: list[Value], line: int, result_type: RType | None = None, + type_args: list[RType] | None = None, ) -> Value: """Add a primitive op.""" # Does this primitive map into calling a Python C API @@ -2363,10 +2365,20 @@ def primitive_op( # This primitive gets transformed in a lowering pass to # lower-level IR ops using a custom transform function. + # Evaluate argument and return types for generic primitives + return_type = None + if desc.type_params is not None: + assert type_args is not None, "Generic primitive op requires explicit type arguments" + assert len(type_args) == len(desc.type_params) + arg_types = [expand_rtype(arg_type, type_args) for arg_type in desc.arg_types] + return_type = expand_rtype(desc.return_type, type_args) + else: + arg_types = desc.arg_types + coerced = [] # Coerce fixed number arguments - for i in range(min(len(args), len(desc.arg_types))): - formal_type = desc.arg_types[i] + for i in range(min(len(args), len(arg_types))): + formal_type = arg_types[i] arg = args[i] assert formal_type is not None # TODO arg = self.coerce(arg, formal_type, line) @@ -2374,7 +2386,16 @@ def primitive_op( assert desc.ordering is None assert desc.var_arg_type is None assert not desc.extra_int_constants - target = self.add(PrimitiveOp(coerced, desc, line=line)) + target = self.add( + PrimitiveOp( + coerced, + desc, + line=line, + arg_types=arg_types, + return_type=return_type, + type_args=type_args, + ) + ) if desc.is_borrowed: # If the result is borrowed, force the arguments to be # kept alive afterwards, as otherwise the result might be diff --git a/mypyc/irbuild/vec.py b/mypyc/irbuild/vec.py index bfcfabee45c21..38615ebe16026 100644 --- a/mypyc/irbuild/vec.py +++ b/mypyc/irbuild/vec.py @@ -52,6 +52,7 @@ vec_api_by_item_type, vec_item_type_tags, ) +from mypyc.primitives.librt_vecs_ops import vec_get_item_unsafe_borrow_op, vec_get_item_unsafe_op if TYPE_CHECKING: from mypyc.irbuild.ll_builder import LowLevelIRBuilder @@ -316,9 +317,10 @@ def vec_get_item( ) -> Value: """Generate inlined vec __getitem__ call. - We inline this, since it's simple but performance-critical. + We inline the length and bounds check, since they are simple but + performance-critical. The actual item load is emitted as a generic primitive + op that is lowered later. """ - # TODO: Support more item types # TODO: Support more index types len_val = vec_len(builder, base) index = vec_check_and_adjust_index(builder, len_val, index, line) @@ -328,7 +330,22 @@ def vec_get_item( def vec_get_item_unsafe( builder: LowLevelIRBuilder, base: Value, index: Value, line: int, *, can_borrow: bool = False ) -> Value: - """Get vec item, assuming index is non-negative and within bounds.""" + """Get vec item, assuming index is non-negative and within bounds. + + This emits a generic primitive op that is inlined during lowering. + """ + assert isinstance(base.type, RVec) + if can_borrow: + desc = vec_get_item_unsafe_borrow_op + else: + desc = vec_get_item_unsafe_op + return builder.primitive_op(desc, [base, index], line, type_args=[base.type.item_type]) + + +def vec_get_item_unsafe_lower( + builder: LowLevelIRBuilder, base: Value, index: Value, line: int, *, can_borrow: bool = False +) -> Value: + """Generate the low-level IR for an unsafe vec item load.""" assert isinstance(base.type, RVec) index = as_platform_int(builder, index, line) vtype = base.type diff --git a/mypyc/lower/registry.py b/mypyc/lower/registry.py index dec6a24b9417a..36262c4a011a3 100644 --- a/mypyc/lower/registry.py +++ b/mypyc/lower/registry.py @@ -26,4 +26,4 @@ def wrapper(f: LF) -> LF: # Import various modules that set up global state. -from mypyc.lower import int_ops, list_ops, misc_ops # noqa: F401 +from mypyc.lower import int_ops, list_ops, misc_ops, vec_ops # noqa: F401 diff --git a/mypyc/lower/vec_ops.py b/mypyc/lower/vec_ops.py new file mode 100644 index 0000000000000..768c8e0073af8 --- /dev/null +++ b/mypyc/lower/vec_ops.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from mypyc.ir.ops import Value +from mypyc.irbuild.ll_builder import LowLevelIRBuilder +from mypyc.irbuild.vec import vec_get_item_unsafe_lower +from mypyc.lower.registry import lower_primitive_op + + +@lower_primitive_op("vec_get_item_unsafe") +def vec_get_item_unsafe(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value: + base, index = args + return vec_get_item_unsafe_lower(builder, base, index, line, can_borrow=False) + + +@lower_primitive_op("vec_get_item_unsafe_borrow") +def vec_get_item_unsafe_borrow(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value: + base, index = args + return vec_get_item_unsafe_lower(builder, base, index, line, can_borrow=True) diff --git a/mypyc/primitives/librt_vecs_ops.py b/mypyc/primitives/librt_vecs_ops.py index e4852d5387069..901c779a2bfd2 100644 --- a/mypyc/primitives/librt_vecs_ops.py +++ b/mypyc/primitives/librt_vecs_ops.py @@ -1,13 +1,15 @@ from mypyc.ir.deps import LIBRT_VECS, VECS_EXTRA_OPS from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER from mypyc.ir.rtypes import ( + RTypeVar, RVec, bit_rprimitive, bytes_rprimitive, + int64_rprimitive, object_rprimitive, uint8_rprimitive, ) -from mypyc.primitives.registry import function_op +from mypyc.primitives.registry import custom_primitive_op, function_op # isinstance(obj, vec) isinstance_vec = function_op( @@ -28,3 +30,22 @@ error_kind=ERR_MAGIC, dependencies=[LIBRT_VECS, VECS_EXTRA_OPS], ) + +# Get vec item, assuming the index is valid (no bounds check) +vec_get_item_unsafe_op = custom_primitive_op( + name="vec_get_item_unsafe", + arg_types=[RVec(RTypeVar(0)), int64_rprimitive], + return_type=RTypeVar(0), + error_kind=ERR_NEVER, + type_params=[RTypeVar(0)], +) + +# Like vec_get_item_unsafe, but the result is a borrowed reference +vec_get_item_unsafe_borrow_op = custom_primitive_op( + name="vec_get_item_unsafe_borrow", + arg_types=[RVec(RTypeVar(0)), int64_rprimitive], + is_borrowed=True, + return_type=RTypeVar(0), + error_kind=ERR_NEVER, + type_params=[RTypeVar(0)], +) diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index e22a044d9bb27..22422987b4277 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -41,7 +41,7 @@ from mypyc.ir.deps import Dependency from mypyc.ir.ops import PrimitiveDescription, StealsDescription -from mypyc.ir.rtypes import RType +from mypyc.ir.rtypes import RType, RTypeVar # Error kind for functions that return negative integer on exception. This # is only used for primitives. We translate it away during IR building. @@ -154,6 +154,7 @@ def method_op( is_pure=is_pure, experimental=experimental, dependencies=dependencies, + type_params=None, ) ops.append(desc) return desc @@ -204,6 +205,7 @@ def function_op( is_pure=False, experimental=experimental, dependencies=dependencies, + type_params=None, ) ops.append(desc) return desc @@ -253,6 +255,7 @@ def binary_op( is_pure=False, experimental=False, dependencies=dependencies, + type_params=None, ) ops.append(desc) return desc @@ -313,6 +316,7 @@ def custom_primitive_op( is_pure: bool = False, experimental: bool = False, dependencies: list[Dependency] | None = None, + type_params: list[RTypeVar] | None = None, ) -> PrimitiveDescription: """Define a primitive op that can't be automatically generated based on the AST. @@ -336,6 +340,7 @@ def custom_primitive_op( is_pure=is_pure, experimental=experimental, dependencies=dependencies, + type_params=type_params, ) @@ -380,6 +385,7 @@ def unary_op( is_pure=is_pure, experimental=False, dependencies=dependencies, + type_params=None, ) ops.append(desc) return desc diff --git a/mypyc/rt_expandtype.py b/mypyc/rt_expandtype.py new file mode 100644 index 0000000000000..8537e6777ccb5 --- /dev/null +++ b/mypyc/rt_expandtype.py @@ -0,0 +1,32 @@ +from mypyc.ir.rtypes import ( + RArray, + RInstance, + RPrimitive, + RStruct, + RTuple, + RType, + RTypeVar, + RUnion, + RVec, + RVoid, +) + + +def expand_rtype(typ: RType, type_args: list[RType]) -> RType: + if isinstance(typ, (RPrimitive, RInstance, RVoid)): + # Atomic types can't contain type variables + return typ + elif isinstance(typ, RTypeVar): + return type_args[typ.id] + elif isinstance(typ, RVec): + return RVec(expand_rtype(typ.item_type, type_args)) + elif isinstance(typ, RUnion): + return RUnion([expand_rtype(item, type_args) for item in typ.items]) + elif isinstance(typ, RTuple): + return RTuple([expand_rtype(item, type_args) for item in typ.types]) + elif isinstance(typ, RStruct): + assert False, "Generic RStruct type not supported" + elif isinstance(typ, RArray): + assert False, "Generic RArray type not supported" + else: + assert False, r"Unexpected type {typ!r}" diff --git a/mypyc/test-data/irbuild-vec-i64.test b/mypyc/test-data/irbuild-vec-i64.test index aeab3ed9f2a8b..af69f30924d46 100644 --- a/mypyc/test-data/irbuild-vec-i64.test +++ b/mypyc/test-data/irbuild-vec-i64.test @@ -64,11 +64,7 @@ def f(v, i): r2 :: i64 r3 :: bit r4 :: bool - r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: i64 + r5, r6 :: i64 L0: r0 = v.len r1 = i < r0 :: unsigned @@ -86,15 +82,10 @@ L3: L4: r5 = i L5: - r6 = v.items - r7 = r5 * 8 - r8 = r6 + r7 - r9 = load_mem r8 :: i64* - keep_alive v - return r9 + r6 = vec_get_item_unsafe[i64] v, r5 + return r6 [case testVecI64GetItem_32bit] -# The IR is quite verbose, but it's acceptable since 32-bit targets are not common any more from librt.vecs import vec from mypy_extensions import i64 @@ -110,13 +101,7 @@ def f(v, i): r3 :: i64 r4 :: bit r5 :: bool - r6 :: i64 - r7, r8 :: bit - r9 :: native_int - r10 :: ptr - r11 :: native_int - r12 :: ptr - r13 :: i64 + r6, r7 :: i64 L0: r0 = v.len r1 = extend signed r0: native_int to i64 @@ -135,24 +120,8 @@ L3: L4: r6 = i L5: - r7 = r6 < 2147483648 :: signed - if r7 goto L6 else goto L8 :: bool -L6: - r8 = r6 >= -2147483648 :: signed - if r8 goto L7 else goto L8 :: bool -L7: - r9 = truncate r6: i64 to native_int - goto L9 -L8: - CPyInt32_Overflow() - unreachable -L9: - r10 = v.items - r11 = r9 * 8 - r12 = r10 + r11 - r13 = load_mem r12 :: i64* - keep_alive v - return r13 + r7 = vec_get_item_unsafe[i64] v, r6 + return r7 [case testVecI64Append] from librt.vecs import vec, append @@ -411,7 +380,7 @@ L3: L4: return r1 -[case testVecI64FastComprehensionFromVec] +[case testVecI64FastComprehensionFromVec_64bit] from librt.vecs import vec from mypy_extensions import i64 from typing import List @@ -426,14 +395,11 @@ def f(n, v): r1 :: vec[i64] r2, r3 :: native_int r4 :: bit - r5 :: ptr - r6 :: native_int + r5, x, r6 :: i64 r7 :: ptr - r8, x, r9 :: i64 - r10 :: ptr - r11 :: native_int - r12 :: ptr - r13 :: native_int + r8 :: native_int + r9 :: ptr + r10 :: native_int L0: r0 = v.len r1 = VecI64Api.alloc(r0, r0) @@ -443,21 +409,17 @@ L1: r4 = r2 < r3 :: signed if r4 goto L2 else goto L4 :: bool L2: - r5 = v.items - r6 = r2 * 8 - r7 = r5 + r6 - r8 = load_mem r7 :: i64* - x = r8 - keep_alive v - r9 = x + 1 - r10 = r1.items - r11 = r2 * 8 - r12 = r10 + r11 - set_mem r12, r9 :: i64* + r5 = vec_get_item_unsafe[i64] v, r2 + x = r5 + r6 = x + 1 + r7 = r1.items + r8 = r2 * 8 + r9 = r7 + r8 + set_mem r9, r6 :: i64* keep_alive r1 L3: - r13 = r2 + 1 - r2 = r13 + r10 = r2 + 1 + r2 = r10 goto L1 L4: return r1 @@ -499,7 +461,7 @@ L3: L4: return r1 -[case testVecI64ForLoop] +[case testVecI64ForLoop_64bit] from librt.vecs import vec from mypy_extensions import i64 @@ -514,11 +476,8 @@ def f(v): t :: i64 r0, r1 :: native_int r2 :: bit - r3 :: ptr - r4 :: native_int - r5 :: ptr - r6, x, r7 :: i64 - r8 :: native_int + r3, x, r4 :: i64 + r5 :: native_int L0: t = 0 r0 = 0 @@ -527,17 +486,13 @@ L1: r2 = r0 < r1 :: signed if r2 goto L2 else goto L4 :: bool L2: - r3 = v.items - r4 = r0 * 8 - r5 = r3 + r4 - r6 = load_mem r5 :: i64* - x = r6 - keep_alive v - r7 = t + 1 - t = r7 + r3 = vec_get_item_unsafe[i64] v, r0 + x = r3 + r4 = t + 1 + t = r4 L3: - r8 = r0 + 1 - r0 = r8 + r5 = r0 + 1 + r0 = r5 goto L1 L4: return t @@ -601,11 +556,7 @@ def f(v): r2 :: i64 r3 :: bit r4 :: bool - r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: i64 + r5, r6 :: i64 L0: r0 = v.len r1 = 0 < r0 :: unsigned @@ -623,12 +574,8 @@ L3: L4: r5 = 0 L5: - r6 = v.items - r7 = r5 * 8 - r8 = r6 + r7 - r9 = load_mem r8 :: i64* - keep_alive v - return r9 + r6 = vec_get_item_unsafe[i64] v, r5 + return r6 [case testVecI64Slicing_64bit] from librt.vecs import vec @@ -720,20 +667,16 @@ def inplace(v, n, m): r2 :: i64 r3 :: bit r4 :: bool - r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9, r10 :: i64 - r11 :: native_int - r12 :: bit + r5, r6, r7 :: i64 + r8 :: native_int + r9 :: bit + r10 :: i64 + r11 :: bit + r12 :: bool r13 :: i64 - r14 :: bit - r15 :: bool - r16 :: i64 - r17 :: ptr - r18 :: i64 - r19 :: ptr + r14 :: ptr + r15 :: i64 + r16 :: ptr L0: r0 = v.len r1 = n < r0 :: unsigned @@ -751,32 +694,28 @@ L3: L4: r5 = n L5: - r6 = v.items - r7 = r5 * 8 - r8 = r6 + r7 - r9 = load_mem r8 :: i64* - keep_alive v - r10 = r9 + m - r11 = v.len - r12 = n < r11 :: unsigned - if r12 goto L9 else goto L6 :: bool + r6 = vec_get_item_unsafe[i64] v, r5 + r7 = r6 + m + r8 = v.len + r9 = n < r8 :: unsigned + if r9 goto L9 else goto L6 :: bool L6: - r13 = n + r11 - r14 = r13 < r11 :: unsigned - if r14 goto L8 else goto L7 :: bool + r10 = n + r8 + r11 = r10 < r8 :: unsigned + if r11 goto L8 else goto L7 :: bool L7: - r15 = raise IndexError + r12 = raise IndexError unreachable L8: - r16 = r13 + r13 = r10 goto L10 L9: - r16 = n + r13 = n L10: - r17 = v.items - r18 = r16 * 8 - r19 = r17 + r18 - set_mem r19, r10 :: i64* + r14 = v.items + r15 = r13 * 8 + r16 = r14 + r15 + set_mem r16, r7 :: i64* keep_alive v return 1 diff --git a/mypyc/test-data/irbuild-vec-misc.test b/mypyc/test-data/irbuild-vec-misc.test index c0d4325e38fcc..22037d8597a73 100644 --- a/mypyc/test-data/irbuild-vec-misc.test +++ b/mypyc/test-data/irbuild-vec-misc.test @@ -125,10 +125,7 @@ def get_item_bool(v, i): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: bool + r6 :: bool L0: r0 = v.len r1 = i < r0 :: unsigned @@ -146,12 +143,8 @@ L3: L4: r5 = i L5: - r6 = v.items - r7 = r5 * 1 - r8 = r6 + r7 - r9 = load_mem r8 :: builtins.bool* - keep_alive v - return r9 + r6 = vec_get_item_unsafe[bool] v, r5 + return r6 [case testVecMiscPop] from librt.vecs import vec, pop @@ -209,7 +202,7 @@ L0: r0 = VecFloatApi.slice(v, x, y) return r0 -[case testVecMiscForLoop] +[case testVecMiscForLoop_64bit] from librt.vecs import vec, remove from mypy_extensions import i64, i16 @@ -225,11 +218,8 @@ def for_bool(v): s :: i16 r0, r1 :: native_int r2 :: bit - r3 :: ptr - r4 :: native_int - r5 :: ptr - r6, x, r7 :: i16 - r8 :: native_int + r3, x, r4 :: i16 + r5 :: native_int L0: s = 0 r0 = 0 @@ -238,17 +228,13 @@ L1: r2 = r0 < r1 :: signed if r2 goto L2 else goto L4 :: bool L2: - r3 = v.items - r4 = r0 * 2 - r5 = r3 + r4 - r6 = load_mem r5 :: i16* - x = r6 - keep_alive v - r7 = s + x - s = r7 + r3 = vec_get_item_unsafe[i16] v, r0 + x = r3 + r4 = s + x + s = r4 L3: - r8 = r0 + 1 - r0 = r8 + r5 = r0 + 1 + r0 = r5 goto L1 L4: return s @@ -270,10 +256,7 @@ def get_item_nested(v, i): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: vec[i32] + r6 :: vec[i32] L0: r0 = v.len r1 = i < r0 :: unsigned @@ -291,12 +274,8 @@ L3: L4: r5 = i L5: - r6 = v.items - r7 = r5 * 16 - r8 = r6 + r7 - r9 = load_mem r8 :: vec[i32]* - keep_alive v - return r9 + r6 = vec_get_item_unsafe[vec[i32]] v, r5 + return r6 [case testVecMiscNestedPop_64bit] from librt.vecs import vec, pop diff --git a/mypyc/test-data/irbuild-vec-nested.test b/mypyc/test-data/irbuild-vec-nested.test index 1fe42a880d5b0..dd49d9475812f 100644 --- a/mypyc/test-data/irbuild-vec-nested.test +++ b/mypyc/test-data/irbuild-vec-nested.test @@ -206,10 +206,7 @@ def f(v, n): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: vec[str] + r6 :: vec[str] L0: r0 = v.len r1 = n < r0 :: unsigned @@ -227,12 +224,8 @@ L3: L4: r5 = n L5: - r6 = v.items - r7 = r5 * 16 - r8 = r6 + r7 - r9 = load_mem r8 :: vec[str]* - keep_alive v - return r9 + r6 = vec_get_item_unsafe[vec[str]] v, r5 + return r6 [case testVecNestedI64GetItem_64bit] from librt.vecs import vec @@ -250,10 +243,7 @@ def f(v, n): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: vec[i64] + r6 :: vec[i64] L0: r0 = v.len r1 = n < r0 :: unsigned @@ -271,12 +261,8 @@ L3: L4: r5 = n L5: - r6 = v.items - r7 = r5 * 16 - r8 = r6 + r7 - r9 = load_mem r8 :: vec[i64]* - keep_alive v - return r9 + r6 = vec_get_item_unsafe[vec[i64]] v, r5 + return r6 [case testVecNestedI64GetItemWithBorrow_64bit] from librt.vecs import vec @@ -294,20 +280,13 @@ def f(v, n): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: vec[i64] - r10 :: native_int - r11 :: bit - r12 :: i64 - r13 :: bit - r14 :: bool - r15 :: i64 - r16 :: ptr - r17 :: i64 - r18 :: ptr - r19 :: i64 + r6 :: vec[i64] + r7 :: native_int + r8 :: bit + r9 :: i64 + r10 :: bit + r11 :: bool + r12, r13 :: i64 L0: r0 = v.len r1 = n < r0 :: unsigned @@ -325,32 +304,94 @@ L3: L4: r5 = n L5: - r6 = v.items - r7 = r5 * 16 - r8 = r6 + r7 - r9 = borrow load_mem r8 :: vec[i64]* - r10 = r9.len - r11 = n < r10 :: unsigned - if r11 goto L9 else goto L6 :: bool + r6 = vec_get_item_unsafe_borrow[vec[i64]] v, r5 + r7 = r6.len + r8 = n < r7 :: unsigned + if r8 goto L9 else goto L6 :: bool +L6: + r9 = n + r7 + r10 = r9 < r7 :: unsigned + if r10 goto L8 else goto L7 :: bool +L7: + r11 = raise IndexError + unreachable +L8: + r12 = r9 + goto L10 +L9: + r12 = n +L10: + r13 = vec_get_item_unsafe[i64] r6, r12 + keep_alive v, r5 + return r13 + +[case testVecNestedStrGetItemWithBorrow_64bit] +from librt.vecs import vec +from mypy_extensions import i64 + +class C: + v: vec[vec[str]] + + def f(self, n: i64) -> str: + # The intermediate vec is borrowed, but the result must be owned. + return self.v[n][n] +[out] +def C.f(self, n): + self :: __main__.C + n :: i64 + r0 :: vec[vec[str]] + r1 :: native_int + r2 :: bit + r3 :: i64 + r4 :: bit + r5 :: bool + r6 :: i64 + r7 :: vec[str] + r8 :: native_int + r9 :: bit + r10 :: i64 + r11 :: bit + r12 :: bool + r13 :: i64 + r14 :: str +L0: + r0 = borrow self.v + r1 = r0.len + r2 = n < r1 :: unsigned + if r2 goto L4 else goto L1 :: bool +L1: + r3 = n + r1 + r4 = r3 < r1 :: unsigned + if r4 goto L3 else goto L2 :: bool +L2: + r5 = raise IndexError + unreachable +L3: + r6 = r3 + goto L5 +L4: + r6 = n +L5: + r7 = vec_get_item_unsafe_borrow[vec[str]] r0, r6 + r8 = r7.len + r9 = n < r8 :: unsigned + if r9 goto L9 else goto L6 :: bool L6: - r12 = n + r10 - r13 = r12 < r10 :: unsigned - if r13 goto L8 else goto L7 :: bool + r10 = n + r8 + r11 = r10 < r8 :: unsigned + if r11 goto L8 else goto L7 :: bool L7: - r14 = raise IndexError + r12 = raise IndexError unreachable L8: - r15 = r12 + r13 = r10 goto L10 L9: - r15 = n + r13 = n L10: - r16 = r9.items - r17 = r15 * 8 - r18 = r16 + r17 - r19 = load_mem r18 :: i64* - keep_alive v, r9 - return r19 + r14 = vec_get_item_unsafe[str] r7, r13 + keep_alive self, r0, r6 + return r14 [case testVecDoublyNestedGetItem_64bit] from librt.vecs import vec @@ -368,10 +409,7 @@ def f(v, n): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: vec[vec[str]] + r6 :: vec[vec[str]] L0: r0 = v.len r1 = n < r0 :: unsigned @@ -389,12 +427,8 @@ L3: L4: r5 = n L5: - r6 = v.items - r7 = r5 * 16 - r8 = r6 + r7 - r9 = load_mem r8 :: vec[vec[str]]* - keep_alive v - return r9 + r6 = vec_get_item_unsafe[vec[vec[str]]] v, r5 + return r6 [case testVecNestedCreateWithCap_64bit] from librt.vecs import vec diff --git a/mypyc/test-data/irbuild-vec-t.test b/mypyc/test-data/irbuild-vec-t.test index 63ad14bc2d7a2..ee48d81fc29c8 100644 --- a/mypyc/test-data/irbuild-vec-t.test +++ b/mypyc/test-data/irbuild-vec-t.test @@ -215,10 +215,7 @@ def f(v, n): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: str + r6 :: str L0: r0 = v.len r1 = n < r0 :: unsigned @@ -236,12 +233,8 @@ L3: L4: r5 = n L5: - r6 = v.items - r7 = r5 * 8 - r8 = r6 + r7 - r9 = load_mem r8 :: builtins.str* - keep_alive v - return r9 + r6 = vec_get_item_unsafe[str] v, r5 + return r6 [case testVecTOptionalGetItem_64bit] from librt.vecs import vec @@ -260,10 +253,7 @@ def f(v, n): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: union[str, None] + r6 :: union[str, None] L0: r0 = v.len r1 = n < r0 :: unsigned @@ -281,12 +271,8 @@ L3: L4: r5 = n L5: - r6 = v.items - r7 = r5 * 8 - r8 = r6 + r7 - r9 = load_mem r8 :: union* - keep_alive v - return r9 + r6 = vec_get_item_unsafe[union[str, None]] v, r5 + return r6 [case testNewTPopLast] from typing import Tuple @@ -524,3 +510,52 @@ L0: r1 = r0 r2 = VecTApi.from_iterable(r1, a, 0) return r2 + +[case testVecTBorrowGetItem_64bit] +from librt.vecs import vec +from mypy_extensions import i64 + +class A: + x: str + +def f(v: vec[A], n: i64) -> int: + return len(v[n].x) +[out] +def f(v, n): + v :: vec[__main__.A] + n :: i64 + r0 :: native_int + r1 :: bit + r2 :: i64 + r3 :: bit + r4 :: bool + r5 :: i64 + r6 :: __main__.A + r7 :: str + r8 :: native_int + r9 :: bit + r10 :: short_int +L0: + r0 = v.len + r1 = n < r0 :: unsigned + if r1 goto L4 else goto L1 :: bool +L1: + r2 = n + r0 + r3 = r2 < r0 :: unsigned + if r3 goto L3 else goto L2 :: bool +L2: + r4 = raise IndexError + unreachable +L3: + r5 = r2 + goto L5 +L4: + r5 = n +L5: + r6 = vec_get_item_unsafe_borrow[__main__.A] v, r5 + r7 = r6.x + keep_alive v, r5 + r8 = CPyStr_Size_size_t(r7) + r9 = r8 >= 0 :: signed + r10 = r8 << 1 + return r10 diff --git a/mypyc/test-data/lowering-vec.test b/mypyc/test-data/lowering-vec.test new file mode 100644 index 0000000000000..374d336b98d8e --- /dev/null +++ b/mypyc/test-data/lowering-vec.test @@ -0,0 +1,155 @@ +[case testLowerVecI64GetItem_64bit] +from librt.vecs import vec +from mypy_extensions import i64 + +def f(i: i64) -> i64: + v = vec[i64]() + return v[i] +[out] +def f(i): + i :: i64 + r0, v :: vec[i64] + r1 :: native_int + r2 :: bit + r3 :: i64 + r4 :: bit + r5 :: bool + r6 :: i64 + r7 :: ptr + r8 :: i64 + r9 :: ptr + r10, r11 :: i64 +L0: + r0 = VecI64Api.alloc(0, 0) + if is_error(r0) goto L8 (error at f:5) else goto L1 +L1: + v = r0 + r1 = v.len + r2 = i < r1 :: unsigned + if r2 goto L6 else goto L2 :: bool +L2: + r3 = i + r1 + r4 = r3 < r1 :: unsigned + if r4 goto L5 else goto L9 :: bool +L3: + r5 = raise IndexError + if not r5 goto L8 (error at f:6) else goto L4 :: bool +L4: + unreachable +L5: + r6 = r3 + goto L7 +L6: + r6 = i +L7: + r7 = v.items + r8 = r6 * 8 + r9 = r7 + r8 + r10 = load_mem r9 :: i64* + dec_ref v + return r10 +L8: + r11 = :: i64 + return r11 +L9: + dec_ref v + goto L3 + +[case testLowerVecNestedGetItem_64bit] +from librt.vecs import vec +from mypy_extensions import i64 + +def f(i: i64, j: i64) -> str: + v = vec[vec[str]]([]) + return v[i][j] +[out] +def f(i, j): + i, j :: i64 + r0 :: object + r1 :: ptr + r2 :: vec[vec[str]] + r3 :: ptr + v :: vec[vec[str]] + r4 :: native_int + r5 :: bit + r6 :: i64 + r7 :: bit + r8 :: bool + r9 :: i64 + r10 :: ptr + r11 :: i64 + r12 :: ptr + r13 :: vec[str] + r14 :: native_int + r15 :: bit + r16 :: i64 + r17 :: bit + r18 :: bool + r19 :: i64 + r20 :: ptr + r21 :: i64 + r22 :: ptr + r23, r24 :: str +L0: + r0 = load_address PyUnicode_Type + r1 = r0 + r2 = VecNestedApi.alloc(0, 0, r1, 1) + if is_error(r2) goto L14 (error at f:5) else goto L1 +L1: + r3 = r2.items + v = r2 + r4 = v.len + r5 = i < r4 :: unsigned + if r5 goto L6 else goto L2 :: bool +L2: + r6 = i + r4 + r7 = r6 < r4 :: unsigned + if r7 goto L5 else goto L15 :: bool +L3: + r8 = raise IndexError + if not r8 goto L14 (error at f:6) else goto L4 :: bool +L4: + unreachable +L5: + r9 = r6 + goto L7 +L6: + r9 = i +L7: + r10 = v.items + r11 = r9 * 16 + r12 = r10 + r11 + r13 = borrow load_mem r12 :: vec[str]* + r14 = r13.len + r15 = j < r14 :: unsigned + if r15 goto L12 else goto L8 :: bool +L8: + r16 = j + r14 + r17 = r16 < r14 :: unsigned + if r17 goto L11 else goto L16 :: bool +L9: + r18 = raise IndexError + if not r18 goto L14 (error at f:6) else goto L10 :: bool +L10: + unreachable +L11: + r19 = r16 + goto L13 +L12: + r19 = j +L13: + r20 = r13.items + r21 = r19 * 8 + r22 = r20 + r21 + r23 = load_mem r22 :: builtins.str* + dec_ref v + return r23 +L14: + r24 = :: str + return r24 +L15: + dec_ref v + goto L3 +L16: + dec_ref v + goto L9 diff --git a/mypyc/test-data/refcount.test b/mypyc/test-data/refcount.test index 918c84ee3b0ad..cefa050e8f0e9 100644 --- a/mypyc/test-data/refcount.test +++ b/mypyc/test-data/refcount.test @@ -1606,12 +1606,9 @@ def f(v): t :: i64 r0, r1 :: native_int r2 :: bit - r3 :: ptr - r4 :: native_int - r5 :: ptr - r6, s :: str - r7 :: None - r8 :: native_int + r3, s :: str + r4 :: None + r5 :: native_int L0: t = 0 r0 = 0 @@ -1620,16 +1617,13 @@ L1: r2 = r0 < r1 :: signed if r2 goto L2 else goto L4 :: bool L2: - r3 = v.items - r4 = r0 * 8 - r5 = r3 + r4 - r6 = load_mem r5 :: builtins.str* - s = r6 - r7 = g(s) + r3 = vec_get_item_unsafe[str] v, r0 + s = r3 + r4 = g(s) dec_ref s L3: - r8 = r0 + 1 - r0 = r8 + r5 = r0 + 1 + r0 = r5 goto L1 L4: return t @@ -1684,11 +1678,7 @@ def C.f(self, x): r3 :: i64 r4 :: bit r5 :: bool - r6 :: i64 - r7 :: ptr - r8 :: i64 - r9 :: ptr - r10 :: i64 + r6, r7 :: i64 L0: r0 = borrow self.v r1 = r0.len @@ -1707,11 +1697,8 @@ L3: L4: r6 = x L5: - r7 = r0.items - r8 = r6 * 8 - r9 = r7 + r8 - r10 = load_mem r9 :: i64* - return r10 + r7 = vec_get_item_unsafe[i64] r0, r6 + return r7 [case testVecI64LenBorrowVec_64bit] from librt.vecs import vec @@ -1771,10 +1758,7 @@ def f(v, n): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: str + r6 :: str L0: r0 = v.len r1 = n < r0 :: unsigned @@ -1792,11 +1776,8 @@ L3: L4: r5 = n L5: - r6 = v.items - r7 = r5 * 8 - r8 = r6 + r7 - r9 = load_mem r8 :: builtins.str* - return r9 + r6 = vec_get_item_unsafe[str] v, r5 + return r6 [case testVecNestedGetItem_64bit] from librt.vecs import vec @@ -1817,10 +1798,7 @@ def f(v, n): r6 :: bit r7 :: bool r8 :: i64 - r9 :: ptr - r10 :: i64 - r11 :: ptr - r12, vv :: vec[str] + r9, vv :: vec[str] L0: r0 = load_address PyUnicode_Type r1 = r0 @@ -1841,12 +1819,9 @@ L3: L4: r8 = n L5: - r9 = r2.items - r10 = r8 * 16 - r11 = r9 + r10 - r12 = load_mem r11 :: vec[str]* + r9 = vec_get_item_unsafe[vec[str]] r2, r8 dec_ref r2 - vv = r12 + vv = r9 dec_ref vv return 1 L6: @@ -1869,20 +1844,13 @@ def f(v, n, m): r3 :: bit r4 :: bool r5 :: i64 - r6 :: ptr - r7 :: i64 - r8 :: ptr - r9 :: vec[i64] - r10 :: native_int - r11 :: bit - r12 :: i64 - r13 :: bit - r14 :: bool - r15 :: i64 - r16 :: ptr - r17 :: i64 - r18 :: ptr - r19 :: i64 + r6 :: vec[i64] + r7 :: native_int + r8 :: bit + r9 :: i64 + r10 :: bit + r11 :: bool + r12, r13 :: i64 L0: r0 = v.len r1 = n < r0 :: unsigned @@ -1900,31 +1868,25 @@ L3: L4: r5 = n L5: - r6 = v.items - r7 = r5 * 16 - r8 = r6 + r7 - r9 = borrow load_mem r8 :: vec[i64]* - r10 = r9.len - r11 = m < r10 :: unsigned - if r11 goto L9 else goto L6 :: bool + r6 = vec_get_item_unsafe_borrow[vec[i64]] v, r5 + r7 = r6.len + r8 = m < r7 :: unsigned + if r8 goto L9 else goto L6 :: bool L6: - r12 = m + r10 - r13 = r12 < r10 :: unsigned - if r13 goto L8 else goto L7 :: bool + r9 = m + r7 + r10 = r9 < r7 :: unsigned + if r10 goto L8 else goto L7 :: bool L7: - r14 = raise IndexError + r11 = raise IndexError unreachable L8: - r15 = r12 + r12 = r9 goto L10 L9: - r15 = m + r12 = m L10: - r16 = r9.items - r17 = r15 * 8 - r18 = r16 + r17 - r19 = load_mem r18 :: i64* - return r19 + r13 = vec_get_item_unsafe[i64] r6, r12 + return r13 [case testVecPop] from librt.vecs import vec, pop, append diff --git a/mypyc/test/test_expand_rtype.py b/mypyc/test/test_expand_rtype.py new file mode 100644 index 0000000000000..90acc61f625d9 --- /dev/null +++ b/mypyc/test/test_expand_rtype.py @@ -0,0 +1,48 @@ +import unittest + +from mypyc.ir.class_ir import ClassIR +from mypyc.ir.rtypes import ( + RInstance, + RTuple, + RTypeVar, + RUnion, + RVec, + int_rprimitive, + str_rprimitive, + void_rtype, +) +from mypyc.rt_expandtype import expand_rtype + + +class TestExpandRType(unittest.TestCase): + def test_trivial(self) -> None: + assert expand_rtype(str_rprimitive, []) == str_rprimitive + assert expand_rtype(str_rprimitive, [int_rprimitive]) == str_rprimitive + assert expand_rtype(void_rtype, []) == void_rtype + + def test_instance(self) -> None: + inst = RInstance(ClassIR("A", "__main__")) + assert expand_rtype(inst, [int_rprimitive]) == inst + + def test_simple_expansion(self) -> None: + assert expand_rtype(RTypeVar(0), [str_rprimitive]) == str_rprimitive + + def test_tuple_expansion(self) -> None: + assert expand_rtype( + RTuple([RTypeVar(0), RTypeVar(1)]), [str_rprimitive, int_rprimitive] + ) == RTuple([str_rprimitive, int_rprimitive]) + + def test_union_expansion(self) -> None: + assert expand_rtype( + RUnion([RTypeVar(0), RTypeVar(1)]), [str_rprimitive, int_rprimitive] + ) == RUnion([str_rprimitive, int_rprimitive]) + + def test_vec_expansion(self) -> None: + assert expand_rtype(RVec(RTypeVar(0)), [str_rprimitive]) == RVec(str_rprimitive) + + def test_nested_expansion(self) -> None: + typ = RUnion([RTuple([RVec(RTypeVar(0)), RTypeVar(1)]), RVec(RVec(RTypeVar(0)))]) + expected = RUnion( + [RTuple([RVec(str_rprimitive), int_rprimitive]), RVec(RVec(str_rprimitive))] + ) + assert expand_rtype(typ, [str_rprimitive, int_rprimitive]) == expected diff --git a/mypyc/test/test_lowering.py b/mypyc/test/test_lowering.py index e27b4e77eea8a..3eb4698c68793 100644 --- a/mypyc/test/test_lowering.py +++ b/mypyc/test/test_lowering.py @@ -28,7 +28,7 @@ class TestLowering(MypycDataSuite): - files = ["lowering-int.test", "lowering-list.test"] + files = ["lowering-int.test", "lowering-list.test", "lowering-vec.test"] base_path = test_temp_dir def run_case(self, testcase: DataDrivenTestCase) -> None: