Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class to enable the new behavior. Sometimes adding a new abstract
RStruct,
RTuple,
RType,
RTypeVar,
RUnion,
RVec,
RVoid,
Expand Down Expand Up @@ -703,7 +704,7 @@ def __init__(
self,
name: str,
arg_types: list[RType],
return_type: RType, # TODO: What about generic?
return_type: RType,
var_arg_type: RType | None,
truncated_type: RType | None,
c_function_name: str | None,
Expand All @@ -716,6 +717,7 @@ def __init__(
is_pure: bool,
experimental: bool,
dependencies: list[Dependency] | None,
type_params: list[RTypeVar] | None,
) -> None:
# Each primitive much have a distinct name, but otherwise they are arbitrary.
self.name: Final = name
Expand Down Expand Up @@ -749,6 +751,7 @@ def __init__(
# If this flag is set, the primitive has native integer types and must
# be matched using more complex rules.
self.is_ambiguous = any(has_fixed_width_int(t) for t in arg_types)
self.type_params = None if not type_params else type_params

def __repr__(self) -> str:
return f"<PrimitiveDescription {self.name!r}: {self.arg_types}>"
Expand Down Expand Up @@ -776,11 +779,23 @@ class PrimitiveOp(RegisterOp):
code paths for short and long representations.
"""

def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1) -> None:
def __init__(
self,
args: list[Value],
desc: PrimitiveDescription,
line: int = -1,
*,
arg_types: list[RType] | None = None,
return_type: RType | None = None,
type_args: list[RType] | None = None,
) -> None:
self.error_kind = desc.error_kind
super().__init__(line)
self.args = args
self.type = desc.return_type
self.arg_types = arg_types if arg_types is not None else desc.arg_types
self.type = return_type if return_type is not None else desc.return_type
self.is_borrowed = desc.is_borrowed
self.type_args = type_args
self.desc = desc

def sources(self) -> list[Value]:
Expand Down
9 changes: 7 additions & 2 deletions mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,15 @@ def visit_call_c(self, op: CallC) -> str:

def visit_primitive_op(self, op: PrimitiveOp) -> str:
args_str = ", ".join(self.format("%r", arg) for arg in op.args)
if op.type_args:
joined = ", ".join(str(arg) for arg in op.type_args)
type_args = f"[{joined}]"
else:
type_args = ""
if op.is_void:
return self.format("%s %s", op.desc.name, args_str)
return self.format("%s%s %s", op.desc.name, type_args, args_str)
else:
return self.format("%r = %s %s", op, op.desc.name, args_str)
return self.format("%r = %s%s %s", op, op.desc.name, type_args, args_str)

def visit_truncate(self, op: Truncate) -> str:
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
Expand Down
53 changes: 53 additions & 0 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def visit_rprimitive(self, typ: RPrimitive, /) -> T:
def visit_rinstance(self, typ: RInstance, /) -> T:
raise NotImplementedError

def visit_rtypevar(self, typ: RTypeVar, /) -> T:
raise RuntimeError("RTypeVar should not be encountered here")

@abstractmethod
def visit_rvec(self, typ: RVec, /) -> T:
raise NotImplementedError
Expand Down Expand Up @@ -747,6 +750,12 @@ def visit_rarray(self, t: RArray) -> str:
def visit_rvoid(self, t: RVoid) -> str:
assert False, "rvoid in tuple?"

def visit_rtypevar(self, typ: RTypeVar) -> str:
# We need to return something to support generic RTuples, etc. Make sure
# the return value is invalid C so that generic RTuples must be expanded
# before they can be used in IR.
return f"!RTypeVar {typ.id} invalid!"


@final
class RTuple(RType):
Expand Down Expand Up @@ -1013,6 +1022,50 @@ def serialize(self) -> str:
return self.name


@final
class RTypeVar(RType):
"""Type variable type used for generic primitive ops.

This allows having generic primitive operations like vec get item, which is
parametrized by the vec item type.

These types are not valid in any other context outside PrimitiveDescription,
and they will always be substituted during the construction of a PrimitiveOp.

NOTE: This is not related to mypy's TypeVarType!
"""

def __init__(self, id: int) -> None:
self.id = id

@property
def may_be_immortal(self) -> bool:
# RTypeVar must always be substituted before use, so this should never matter.
return False

def accept(self, visitor: RTypeVisitor[T]) -> T:
return visitor.visit_rtypevar(self)

def __str__(self) -> str:
return f"<RTypeVar {self.id}>"

def __repr__(self) -> str:
return f"<RTypeVar {self.id}>"

def __eq__(self, other: object) -> TypeGuard[RTypeVar]:
return isinstance(other, RTypeVar) and other.id == self.id

def __hash__(self) -> int:
return self.id ^ 12345

def serialize(self) -> JsonDict:
return {".class": "RTypeVar", "id": self.id}

@classmethod
def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RTypeVar:
return RTypeVar(data["id"])


@final
class RVec(RType):
"""librt.vecs.vec[T]"""
Expand Down
27 changes: 24 additions & 3 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@
new_tuple_with_length_op,
sequence_tuple_op,
)
from mypyc.rt_expandtype import expand_rtype
from mypyc.rt_subtype import is_runtime_subtype
from mypyc.sametype import is_same_type
from mypyc.subtype import is_subtype
Expand Down Expand Up @@ -2334,6 +2335,7 @@ def primitive_op(
args: list[Value],
line: int,
result_type: RType | None = None,
type_args: list[RType] | None = None,
) -> Value:
"""Add a primitive op."""
# Does this primitive map into calling a Python C API
Expand Down Expand Up @@ -2363,18 +2365,37 @@ def primitive_op(
# This primitive gets transformed in a lowering pass to
# lower-level IR ops using a custom transform function.

# Evaluate argument and return types for generic primitives
return_type = None
if desc.type_params is not None:
assert type_args is not None, "Generic primitive op requires explicit type arguments"
assert len(type_args) == len(desc.type_params)
arg_types = [expand_rtype(arg_type, type_args) for arg_type in desc.arg_types]
return_type = expand_rtype(desc.return_type, type_args)
else:
arg_types = desc.arg_types

coerced = []
# Coerce fixed number arguments
for i in range(min(len(args), len(desc.arg_types))):
formal_type = desc.arg_types[i]
for i in range(min(len(args), len(arg_types))):
formal_type = arg_types[i]
arg = args[i]
assert formal_type is not None # TODO
arg = self.coerce(arg, formal_type, line)
coerced.append(arg)
assert desc.ordering is None
assert desc.var_arg_type is None
assert not desc.extra_int_constants
target = self.add(PrimitiveOp(coerced, desc, line=line))
target = self.add(
PrimitiveOp(
coerced,
desc,
line=line,
arg_types=arg_types,
return_type=return_type,
type_args=type_args,
)
)
if desc.is_borrowed:
# If the result is borrowed, force the arguments to be
# kept alive afterwards, as otherwise the result might be
Expand Down
23 changes: 20 additions & 3 deletions mypyc/irbuild/vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
vec_api_by_item_type,
vec_item_type_tags,
)
from mypyc.primitives.librt_vecs_ops import vec_get_item_unsafe_borrow_op, vec_get_item_unsafe_op

if TYPE_CHECKING:
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
Expand Down Expand Up @@ -316,9 +317,10 @@ def vec_get_item(
) -> Value:
"""Generate inlined vec __getitem__ call.
We inline this, since it's simple but performance-critical.
We inline the length and bounds check, since they are simple but
performance-critical. The actual item load is emitted as a generic primitive
op that is lowered later.
"""
# TODO: Support more item types
# TODO: Support more index types
len_val = vec_len(builder, base)
index = vec_check_and_adjust_index(builder, len_val, index, line)
Expand All @@ -328,7 +330,22 @@ def vec_get_item(
def vec_get_item_unsafe(
builder: LowLevelIRBuilder, base: Value, index: Value, line: int, *, can_borrow: bool = False
) -> Value:
"""Get vec item, assuming index is non-negative and within bounds."""
"""Get vec item, assuming index is non-negative and within bounds.
This emits a generic primitive op that is inlined during lowering.
"""
assert isinstance(base.type, RVec)
if can_borrow:
desc = vec_get_item_unsafe_borrow_op
else:
desc = vec_get_item_unsafe_op
return builder.primitive_op(desc, [base, index], line, type_args=[base.type.item_type])


def vec_get_item_unsafe_lower(
builder: LowLevelIRBuilder, base: Value, index: Value, line: int, *, can_borrow: bool = False
) -> Value:
"""Generate the low-level IR for an unsafe vec item load."""
assert isinstance(base.type, RVec)
index = as_platform_int(builder, index, line)
vtype = base.type
Expand Down
2 changes: 1 addition & 1 deletion mypyc/lower/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def wrapper(f: LF) -> LF:


# Import various modules that set up global state.
from mypyc.lower import int_ops, list_ops, misc_ops # noqa: F401
from mypyc.lower import int_ops, list_ops, misc_ops, vec_ops # noqa: F401
18 changes: 18 additions & 0 deletions mypyc/lower/vec_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from mypyc.ir.ops import Value
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
from mypyc.irbuild.vec import vec_get_item_unsafe_lower
from mypyc.lower.registry import lower_primitive_op


@lower_primitive_op("vec_get_item_unsafe")
def vec_get_item_unsafe(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
base, index = args
return vec_get_item_unsafe_lower(builder, base, index, line, can_borrow=False)


@lower_primitive_op("vec_get_item_unsafe_borrow")
def vec_get_item_unsafe_borrow(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
base, index = args
return vec_get_item_unsafe_lower(builder, base, index, line, can_borrow=True)
23 changes: 22 additions & 1 deletion mypyc/primitives/librt_vecs_ops.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from mypyc.ir.deps import LIBRT_VECS, VECS_EXTRA_OPS
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
from mypyc.ir.rtypes import (
RTypeVar,
RVec,
bit_rprimitive,
bytes_rprimitive,
int64_rprimitive,
object_rprimitive,
uint8_rprimitive,
)
from mypyc.primitives.registry import function_op
from mypyc.primitives.registry import custom_primitive_op, function_op

# isinstance(obj, vec)
isinstance_vec = function_op(
Expand All @@ -28,3 +30,22 @@
error_kind=ERR_MAGIC,
dependencies=[LIBRT_VECS, VECS_EXTRA_OPS],
)

# Get vec item, assuming the index is valid (no bounds check)
vec_get_item_unsafe_op = custom_primitive_op(
name="vec_get_item_unsafe",
arg_types=[RVec(RTypeVar(0)), int64_rprimitive],
return_type=RTypeVar(0),
error_kind=ERR_NEVER,
type_params=[RTypeVar(0)],
)

# Like vec_get_item_unsafe, but the result is a borrowed reference
vec_get_item_unsafe_borrow_op = custom_primitive_op(
name="vec_get_item_unsafe_borrow",
arg_types=[RVec(RTypeVar(0)), int64_rprimitive],
is_borrowed=True,
return_type=RTypeVar(0),
error_kind=ERR_NEVER,
type_params=[RTypeVar(0)],
)
8 changes: 7 additions & 1 deletion mypyc/primitives/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from mypyc.ir.deps import Dependency
from mypyc.ir.ops import PrimitiveDescription, StealsDescription
from mypyc.ir.rtypes import RType
from mypyc.ir.rtypes import RType, RTypeVar

# Error kind for functions that return negative integer on exception. This
# is only used for primitives. We translate it away during IR building.
Expand Down Expand Up @@ -154,6 +154,7 @@ def method_op(
is_pure=is_pure,
experimental=experimental,
dependencies=dependencies,
type_params=None,
)
ops.append(desc)
return desc
Expand Down Expand Up @@ -204,6 +205,7 @@ def function_op(
is_pure=False,
experimental=experimental,
dependencies=dependencies,
type_params=None,
)
ops.append(desc)
return desc
Expand Down Expand Up @@ -253,6 +255,7 @@ def binary_op(
is_pure=False,
experimental=False,
dependencies=dependencies,
type_params=None,
)
ops.append(desc)
return desc
Expand Down Expand Up @@ -313,6 +316,7 @@ def custom_primitive_op(
is_pure: bool = False,
experimental: bool = False,
dependencies: list[Dependency] | None = None,
type_params: list[RTypeVar] | None = None,
) -> PrimitiveDescription:
"""Define a primitive op that can't be automatically generated based on the AST.

Expand All @@ -336,6 +340,7 @@ def custom_primitive_op(
is_pure=is_pure,
experimental=experimental,
dependencies=dependencies,
type_params=type_params,
)


Expand Down Expand Up @@ -380,6 +385,7 @@ def unary_op(
is_pure=is_pure,
experimental=False,
dependencies=dependencies,
type_params=None,
)
ops.append(desc)
return desc
Expand Down
Loading
Loading