From 8306055714d6b1da289e064e60cdcc9851ff7021 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Tue, 26 May 2026 23:11:17 +0900 Subject: [PATCH 1/2] [Sim] Remove dead bit-pack formula for bool in dump_args MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit dump_args:109 had `(N+7)//8` for torch.bool which computed a bit-packed byte count — inconsistent with the rest of the data path which uses 1 byte per bool throughout: - write_arg writes tensor.untyped_storage() as raw bytes (N bytes for N bools) - C wrapper load_arg reads N * sizeof(uint8_t) bytes - C wrapper malloc uses bits=8 for bool (mlir_caller_codegen.py:101) - DTYPE_TO_C[torch.bool] = "uint8_t" - spike sees the raw N bytes The (N+7)//8 line was a vestige of an abandoned i1-storage experiment; the wrapper C ABI is structurally byte-aligned and cannot accept bit-packed bool shapes. Plus array_size is unused at the caller (`_, file_path = self.dump_args(...)` at line 134), so this dead line was silently inconsistent. Co-Authored-By: Claude Opus 4.7 (1M context) --- Simulator/simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Simulator/simulator.py b/Simulator/simulator.py index 2b9f05be..9e1da76f 100644 --- a/Simulator/simulator.py +++ b/Simulator/simulator.py @@ -106,7 +106,7 @@ def dump_args(self, args, arg_attributes, load_path, dump_path): array_size = [] file_path = [] for (arg_name, arg_attribute), arg in zip(arg_attributes, args): - size = arg_attribute[2] if arg_attribute[1] != torch.bool else (arg_attribute[2] + 7) // 8 + size = arg_attribute[2] array_size.append(size) if MLIRKernelArgs.is_mlir_arg_in(arg_attribute[0]): index = self.write_arg(arg, load_path, arg_name) From bd7e98f0dba6393f737d4f8e596e0d79332625d5 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Tue, 26 May 2026 23:11:53 +0900 Subject: [PATCH 2/2] [Frontend] Fix #238: replace var_info dict with typed CSEVariable attributes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue #238 was the visible symptom (silent bool/uint8 -> int8 downcast via the lossy `MLIR_TO_DTYPE[var_info[1]]` round-trip at mlir_codegen_backend.py:1535) of a deeper architectural smell: PyTorchSim maintained a parallel `self.var_info` dict tracking `[vec_size, mlir_dtype_string]` per csevar, duplicating type info that already lives on Inductor's `CSEVariable.dtype`. The lossy MLIR->torch round-trip was the only place this duplication actively caused corruption, but collapsing the two systems is the structural fix. Core changes: - New type: - `MLIRCSEVariable(common.CSEVariable)` carries `vec_size: int` and inherits `dtype: Optional[torch.dtype]`. `mlir_dtype` is a derived @property from `dtype` via `DTYPE_TO_MLIR`. There is no separate predicate/mask subclass: `torch.bool` maps to MLIR `"i1"` directly (DTYPE_TO_MLIR[torch.bool] = "i1"). MLIR-to-LLVM lowering pads i1 storage to bytes, matching the wrapper C ABI (`uint8_t*`, one byte per element). The wrapper architecturally cannot accept bit-packed i1 storage (mlir_caller_codegen.py uses sizeof(ctype) loads), so the `memref<...xi1>` -> `i8`-backed pipeline is the natural fit. - `OpResult(vec_size, dtype)` frozen dataclass replaces the legacy `[vec_size, mlir_dtype_string]` ret_info list. `OpResult.from_var` and `OpResult.from_mlir` are classmethod constructors. - `INDEX_DTYPE` singleton sentinel for MLIR `index` type (no torch equivalent). `MLIR_TO_DTYPE["index"] = INDEX_DTYPE` and `DTYPE_TO_MLIR[INDEX_DTYPE] = "index"` so the dicts are bijective for all known types — clearer than overloading `None`. - `MLIRCSE(common.CSE)` extends Inductor's CSE with a `vec_size` axis: - `newvar` / `namedvar` construct `MLIRCSEVariable` directly, bypassing the kernel-side `V.kernel.create_cse_var` hook (which is no longer needed). - `generate(buffer, code, *, vec_size=N, dtype=X, ...)` plumbs `vec_size` to `newvar` via a transient instance attribute, calling `super().generate(...)` for the rest. No need to reimplement the upstream generate body. - Handler proxy (mlir_common.py CSEProxy) rewritten to expect `(code, OpResult|None)` from ops. Single uniform path: `target_cse.generate(buf, code, dtype=ret.dtype, vec_size=ret.vec_size)` — no post-hoc attribute assignment. - All ops in mlir_ops.py, mlir_template.py, mlir_sort_template.py return `(code, OpResult)` (or `OpResult.from_var` / `OpResult.from_mlir` helpers). Legacy `[size, mlir_str]` shape gone. - `register_var_info` / `register_var_cse` deleted. Six previously-named csevars (`compute_idx`, `itervar_cses`, `init_iter`, `reduce_loop_idx`, `idx_step_index`, `idx_base`) now use `cse.namedvar(..., dtype=..., vec_size=...)` directly. `make_named_csevar` wrapper removed. - ~108 read sites of `var_info[v][...]` migrated to attribute access (`v.vec_size`, `v.mlir_dtype`). `var_info[v][1] == "i1"` patterns collapse to `v.dtype == torch.bool` since the mask subclass is gone. - `self.var_info` dict removed entirely. - Issue #238 fix at mlir_codegen_backend.py:1535: csevar = self.cse.varname_map[target_dim] dtype = csevar.dtype No more round-trip; the torch dtype set at csevar construction is preserved end-to-end. Files touched: mlir_common.py (foundation), mlir_codegen_backend.py (#238 site + read migration + memory-entry call sites), mlir_ops.py (ops layer ret_info migration), mlir_template.py + mlir_sort_template.py (template ops + named csevar sites). Sample-verified: test_add, test_softmax, test_sort (i1 mask path via cmp), test_matmul, test_layernorm, test_indirect_access (#238 critical path), test_expert_mask, test_transcendental, test_reduce. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../mlir/mlir_codegen_backend.py | 80 ++-- PyTorchSimFrontend/mlir/mlir_common.py | 199 ++++++++- PyTorchSimFrontend/mlir/mlir_ops.py | 418 +++++++----------- PyTorchSimFrontend/mlir/mlir_sort_template.py | 10 +- PyTorchSimFrontend/mlir/mlir_template.py | 27 +- 5 files changed, 415 insertions(+), 319 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 9c20311c..55979969 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -295,17 +295,17 @@ def __init__(self, kernel_group, reason=None): self.header.writeline(" return p;") self.header.writeline("}") self.header.writeline("void __wrap_free(void *ptr) { return; }") - self.reduction_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") - self.spad_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="spad") - self.apply_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="apply") - self.mask_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="mask") - self.iterator_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="iter") - self.init_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init") - self.init_vec_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init_vec") - self.const_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="const") - self.alloc_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="alloc") - self.indexed_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="indexed_op") - self.map_cse = common.CSE("#", self.suffix, name_prefix="map") + self.reduction_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") + self.spad_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="spad") + self.apply_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="apply") + self.mask_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="mask") + self.iterator_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="iter") + self.init_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="init") + self.init_vec_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="init_vec") + self.const_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="const") + self.alloc_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="alloc") + self.indexed_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="indexed_op") + self.map_cse = mlir_common.MLIRCSE("#", self.suffix, name_prefix="map") self.global_vars_dict = dict() self.reduction_vars = dict() self.consts = dict() @@ -549,7 +549,12 @@ def load(self, name: str, index: sympy.Expr): else: # FIXME. Any good idea? out = sram_var - self.register_var_info(out, [compute_vec_size, mlir_dtype]) + # `out` is the spad memref reference (an MLIRCSEVariable from + # spad_cse.generate). Annotate it with the load's compute-vec + # size and torch dtype so downstream attribute reads (vec_size, + # mlir_dtype) reflect the load shape. + out.vec_size = compute_vec_size + out.dtype = dtype self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] return out @@ -593,11 +598,11 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs) sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector store instruction - _, operand_type = self.var_info[value] + _, operand_type = value.vec_size, value.mlir_dtype if mlir_dtype != operand_type: value = ops.to_dtype(value, mlir_dtype) - if compute_vec_size < self.var_info[value][0]: + if compute_vec_size < value.vec_size: with self.override_buffer_cse(buffer=self.stores): value = ops.extract_strided_slice(value, compute_vec_size) @@ -644,6 +649,9 @@ def reduction(self, dtype, src_dtype, reduction_type, value): init = self.get_const_cse(reduction_init(reduction_type, dtype), type_name) init_vec = init if vec_len == 1 else ops.broadcast(init, vec_len) + # The outermost acc (reduction_depth == 0) carries the final reduced + # shape; inner accumulators stay at default vec_size=1 until lowered. + outer_reduction_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel() acc_var_list = [] iter_var_list = [] for reduction_depth in range(self.get_nr_rdim()): @@ -651,7 +659,10 @@ def reduction(self, dtype, src_dtype, reduction_type, value): reduction_key = src_dtype, reduction_type, value, reduction_depth acc_init_var = init_vec if reduction_depth == 0 else iter_var_list[-1] - acc = self.reduction_cse.generate(self.loads, f"reduction {reduction_key}", write=False) + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", + write=False, dtype=dtype, vec_size=outer_reduction_size, + ) iterator = self.iterator_cse.generate(self.loads, f"reduction {reduction_key}", write=False) acc_var_list.append(acc) iter_var_list.append(iterator) @@ -664,8 +675,10 @@ def reduction(self, dtype, src_dtype, reduction_type, value): # Note: reduction body is inner most loop body. So it doesn't need reduction depth. body_key = src_dtype, reduction_type, value body_acc = self.reduction_cse.generate(self.compute, f"reduction {body_key}body_acc", write=False) - body_iter_arg = self.iterator_cse.generate(self.compute, f"reduction {body_key}body_iter_arg", write=False) - self.register_var_info(body_iter_arg, [vec_len, type_name]) + body_iter_arg = self.iterator_cse.generate( + self.compute, f"reduction {body_key}body_iter_arg", + write=False, dtype=dtype, vec_size=vec_len, + ) acc_var_list.append(body_acc) # Reduction body codegen @@ -683,9 +696,8 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.affine_yield[acc] = reduced_shape, reduction_depth # Final reduction - reduction_size = self.kernel_group.tile_desc.get_numel_per_lane() // self.kernel_group.tile_desc.get_reduction_numel() - acc = acc_var_list[0] # Set outermost acc var - self.register_var_info(acc, [reduction_size, type_name]) + reduction_size = outer_reduction_size # already attached to acc_var_list[0] + acc = acc_var_list[0] # outermost acc var (already typed at creation) assert(vec_len % reduction_size==0) # Prepare init value @@ -794,7 +806,7 @@ def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index): with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): vlane_offset = ops.vlane_offset(vlane_vec, vlane_vec, attributes={"vlane_offset": offset}, comment="vlane offset") - if compute_vec_size < self.var_info[vlane_offset][0]: + if compute_vec_size < vlane_offset.vec_size: vlane_offset = ops.extract_strided_slice(vlane_offset, compute_vec_size) vlane_offset = ops.index_cast(vlane_offset, "index") dim = ops.add(dim, vlane_offset) @@ -874,7 +886,7 @@ def index_expr(self, index, dtype): # Initialize base vector if not self.base_vector_initialized: - init_iter = self.register_var_cse("init_iter", 1, "index") + init_iter = self.cse.namedvar("init_iter", dtype=mlir_common.INDEX_DTYPE) parallel_map = f"affine.parallel (%{init_iter}) = ({0}) to ({compute_vec_size}) {{ // Base vector initializer" self.spad_buffer.writeline(parallel_map) with self.spad_buffer.indent(): @@ -1479,8 +1491,12 @@ def get_const_cse(self, value, dtype="index") -> common.CSEVariable: value = int(value) key = str(value)+dtype if key not in self.consts: - self.consts[key] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") - self.register_var_info(self.consts[key], [1, dtype]) + # MLIR_TO_DTYPE maps "index" -> INDEX_DTYPE sentinel (not + # torch.int64, which would make mlir_dtype derive to "i64"). + self.consts[key] = self.const_cse.generate( + self.const_buffer, f"arith.constant {value} : {dtype}", + dtype=mlir_common.MLIR_TO_DTYPE.get(dtype), + ) return self.consts[key] def get_tag_cse(self, value=None, shape="memref<1xi32>"): @@ -1531,15 +1547,17 @@ def convert_indirect_indexing(self, index :sympy.Expr): if target_dim in self.spad_buffer_dict: sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim] else: - # FIXME. - var_info = [v for k, v in self.var_info.items() if str(k) == target_dim][0] - dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]] + # Issue #238: read torch dtype directly from the csevar's attribute + # rather than round-tripping the MLIR string through MLIR_TO_DTYPE + # (which silently downcasts bool/uint8 to int8). + csevar = self.cse.varname_map[target_dim] + dtype = csevar.dtype local_tile_desc = self.kernel_group.tile_desc tile_numel_per_lane = local_tile_desc.get_numel_per_lane() - tile_shape = local_tile_desc.get_mlir_shape(var_info[1]) + tile_shape = local_tile_desc.get_mlir_shape(csevar.mlir_dtype) tile_vec = local_tile_desc.get_compute_vec_size() - vshape = f"vector<{var_info[0]}x{var_info[1]}>" + vshape = f"vector<{csevar.vec_size}x{csevar.mlir_dtype}>" sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim) self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] @@ -1559,7 +1577,7 @@ def convert_indirect_indexing(self, index :sympy.Expr): if "tmp" not in str(arg): continue if arg.is_Mul and arg.args[0].is_number: - coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1] + coeff_dtype = spad_vars[str(arg.args[1])].mlir_dtype coeff = self.get_const_cse(int(arg.args[0]), coeff_dtype) spad_vars[str(arg.args[1])] = ops.mul(spad_vars[str(arg.args[1])], coeff) index = index.replace(arg, 0) @@ -1577,7 +1595,7 @@ def convert_indirect_indexing(self, index :sympy.Expr): ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape) # FIXME. Maybe require fine grain compute... # Conversion - mlir_dtype = self.var_info[spad_vars[first_dim]][1] + mlir_dtype = spad_vars[first_dim].mlir_dtype with self.override_buffer_cse(buffer=target_dma_buffers): out = ops._load(1, mlir_dtype, sram_var, sram_index_var, tile_shape) if mlir_dtype != "index": diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 734ca967..0ae09b8d 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -40,6 +40,28 @@ schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +class _IndexDType: + """Sentinel for MLIR ``index`` type values. + + MLIR's ``index`` is a platform-dependent integer with no exact + ``torch.dtype`` equivalent (it's not strictly i64 — mapping it to + ``torch.int64`` would make ``MLIRCSEVariable.mlir_dtype`` derive to + ``"i64"`` and break consumers that expect the literal ``"index"`` + keyword in MLIR text). + + Use ``INDEX_DTYPE`` (the singleton instance) wherever a csevar's + ``dtype`` is the MLIR index type. Compared by identity. + """ + + __slots__ = () + + def __repr__(self) -> str: + return "INDEX_DTYPE" + + +INDEX_DTYPE = _IndexDType() + + DTYPE_TO_MLIR = { torch.float32: "f32", torch.float64: "f64", @@ -49,8 +71,15 @@ torch.int16: "i16", torch.int8: "i8", torch.uint8: "i8", + # torch.bool maps to "i8" for *storage* (memref/SRAM), matching the + # wrapper-C ABI (uint8_t*, byte-aligned). The `i1` MLIR type is used + # only for predicate SSA values (cmp results), where it lowers to + # RISC-V V-extension mask registers via `vlm.v`. Predicate-producing + # ops set `mlir_dtype="i1"` explicitly on their OpResult; otherwise + # bool csevars derive `"i8"` from this table. torch.bool: "i8", torch.bfloat16: "bf16", + INDEX_DTYPE: "index", } MLIR_TO_DTYPE = { @@ -61,7 +90,11 @@ "i32": torch.int32, "i16": torch.int16, "i8": torch.int8, + # "i1" -> torch.bool: enables `OpResult.from_mlir("i1")` for predicate + # ops (cmp/logical) that compose their type as an MLIR string. + "i1": torch.bool, "bf16": torch.bfloat16, + "index": INDEX_DTYPE, } DTYPE_TO_C = { @@ -94,7 +127,126 @@ def get_dtype_nbytes(dtype): mlir_dtype = DTYPE_TO_MLIR.get(dtype) if mlir_dtype is None or mlir_dtype not in MLIR_TO_BIT: raise NotImplementedError(f"Unsupported dtype for precision calculation: {dtype}") - return MLIR_TO_BIT[mlir_dtype] // 8 + # MLIR_TO_BIT["i1"] = 1 (semantic bit width). Storage rounds up to + # at least 1 byte (the wrapper C ABI is byte-aligned; LLVM lowers i1 + # storage as a padded byte). + return max(1, MLIR_TO_BIT[mlir_dtype] // 8) + + +class MLIRCSEVariable(common.CSEVariable): + """An MLIR SSA value with attached vec_size and torch dtype. + + ``dtype`` is one of: a ``torch.dtype``, :data:`INDEX_DTYPE` sentinel + (for MLIR ``index``), or ``None`` (not yet typed). ``mlir_dtype`` is + normally derived via :data:`DTYPE_TO_MLIR`. + + ``torch.bool`` maps to MLIR ``"i8"`` (storage; wrapper-C ABI is + ``uint8_t*``, byte-aligned). Predicate SSA values (cmp/logical + results) use MLIR ``"i1"`` instead; they pass ``mlir_dtype="i1"`` + explicitly so the override below takes effect. ``"i1"`` storage is + forbidden because RISC-V V-extension ``vlm.v`` interprets memory as + bit-packed, mismatching the wrapper-C 1-byte-per-bool convention. + """ + + def __init__(self, name, bounds, dtype=None, *, vec_size=1, mlir_dtype=None): + super().__init__(name, bounds, dtype=dtype) + self.vec_size = vec_size + # ``None`` means derive from ``dtype``. Set explicitly only when + # the MLIR type differs from the torch-dtype default (the bool + # storage-vs-predicate split). + self._mlir_dtype_override: Optional[str] = mlir_dtype + + @property + def mlir_dtype(self) -> str: + if self._mlir_dtype_override is not None: + return self._mlir_dtype_override + # ``dtype=None`` is treated as MLIR index for defensive backward + # compatibility, but new code should pass :data:`INDEX_DTYPE` + # explicitly to make intent clear. + return DTYPE_TO_MLIR.get(self.dtype, "index") + + +@dataclass(frozen=True) +class OpResult: + """Ops handler return-info payload, replacing the legacy + ``[vec_size, mlir_dtype_string]`` list. The proxy uses this to + instantiate the resulting :class:`MLIRCSEVariable`. ``dtype= + INDEX_DTYPE`` means MLIR ``index`` type; ``None`` means unknown. + + ``mlir_dtype`` overrides the derived value for the storage-vs- + predicate bool split: cmp/logical ops set + ``OpResult(dtype=torch.bool, mlir_dtype="i1")`` to mark the result + as an i1 predicate even though bool storage maps to ``"i8"``. + """ + + vec_size: int + dtype: Optional[torch.dtype] + mlir_dtype: Optional[str] = None + + @classmethod + def from_var(cls, var) -> "OpResult": + """Mirror an existing csevar's type info (vec_size + dtype + mlir_dtype).""" + assert isinstance(var, MLIRCSEVariable), \ + f"OpResult.from_var expects MLIRCSEVariable, got {type(var).__name__}" + return cls(vec_size=var.vec_size, dtype=var.dtype, + mlir_dtype=var._mlir_dtype_override) + + @classmethod + def from_mlir(cls, vec_size: int, mlir_str: str) -> "OpResult": + """Build from an MLIR dtype string. ``"i8"`` maps to ``torch.int8`` + (legacy default; bool/uint8 callers should construct OpResult + directly with the correct torch dtype). ``"i1"`` returns dtype= + torch.bool + ``mlir_dtype="i1"`` override (predicate semantics). + ``"index"`` maps to :data:`INDEX_DTYPE`. + """ + dtype = MLIR_TO_DTYPE.get(mlir_str) + # For "i1" the MLIR text type must be preserved as an override, + # because DTYPE_TO_MLIR[torch.bool] = "i8" (storage default). + override = "i1" if mlir_str == "i1" else None + return cls(vec_size=vec_size, dtype=dtype, mlir_dtype=override) + + +class MLIRCSE(common.CSE): + """``common.CSE`` that allocates :class:`MLIRCSEVariable` directly and + plumbs a ``vec_size`` axis. + + ``newvar`` / ``namedvar`` construct :class:`MLIRCSEVariable` themselves + (bypassing the Inductor ``V.kernel.create_cse_var`` hook), so the + kernel doesn't need to override ``create_cse_var``. ``generate`` only + needs to thread ``vec_size`` through to the ``newvar`` call, which it + does via an instance attribute around a ``super().generate()`` call. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._pending_vec_size = 1 + self._pending_mlir_dtype: Optional[str] = None + + def newvar(self, bounds=ValueRanges.unknown(), dtype=None): + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = MLIRCSEVariable(var_name, bounds, dtype=dtype, + vec_size=self._pending_vec_size, + mlir_dtype=self._pending_mlir_dtype) + self.varname_map[var_name] = var + return var + + def namedvar(self, name, bounds=ValueRanges.unknown(), dtype=None, *, + vec_size=1, mlir_dtype=None): + assert name not in self.varname_map, f"duplicate name: {name}" + var = MLIRCSEVariable(name, bounds, dtype=dtype, vec_size=vec_size, + mlir_dtype=mlir_dtype) + self.varname_map[name] = var + return var + + def generate(self, buffer, expr, *, vec_size=1, mlir_dtype=None, **kwargs): + self._pending_vec_size = vec_size + self._pending_mlir_dtype = mlir_dtype + try: + return super().generate(buffer, expr, **kwargs) + finally: + self._pending_vec_size = 1 + self._pending_mlir_dtype = None + DTYPE_LOWP_FP = [ torch.bfloat16, @@ -637,12 +789,12 @@ def __init__(self, kernel_group, reason=None): # Code buffer self.vector_compute = IndentedBuffer() self.reductions_suffix = IndentedBuffer() - self.cse = common.CSE(self.newvar_prefix, self.suffix) - # MLIR SSA tracker - self.var_info = {} # MLIR variable info + self.cse = MLIRCSE(self.newvar_prefix, self.suffix) + # All MLIR type/size info now lives on MLIRCSEVariable / MaskCSEVariable + # attributes (vec_size, mlir_dtype, dtype). No parallel var_info dict. self.buffer_types : dict = None # format: dtype, numel, size, stride # Create compute idx - self.compute_idx = self.register_var_cse("compute_idx", 1, "index") + self.compute_idx = self.cse.namedvar("compute_idx", dtype=INDEX_DTYPE) self.compute_body_loop = LoopLevel(self.compute_idx, 1) self.prologue_compute_body_loop = LoopLevel(self.compute_idx, 1) self.recodegen = reason # spad overflow, tile size, vlane stride @@ -668,7 +820,7 @@ def set_ranges(self, lengths, reduction_lengths, index_names=None): assert len(index_names) == len(self.ranges), f"Index names length mismatch: {len(index_names)} != {len(self.ranges)}" self.itervars = [sympy.Symbol(str(n)) for n in index_names] - self.itervar_cses = {str(index) : self.register_var_cse(str(index), 1, "index") for index in self.itervars} + self.itervar_cses = {str(index) : self.cse.namedvar(str(index), dtype=INDEX_DTYPE) for index in self.itervars} self.reduction_depth = len(lengths) return ( self.itervars[: self.reduction_depth], @@ -889,14 +1041,6 @@ def is_scalar(self, name): def roundup_vectorlane(self, size, amp=1): return ((size + self.vector_lane - 1) // self.vector_lane) * self.vector_lane * amp - def register_var_cse(self, name, size, dtype): - var = self.create_cse_var(name, ValueRanges.unknown()) - self.register_var_info(var, [size, dtype]) - return var - - def register_var_info(self, var, var_info): - self.var_info[var] = var_info - def rename_indexing(self, index) -> sympy.Expr: # adds the necessary kernel args for index expressions # and renames variables in index expressions to kernel arg names @@ -951,22 +1095,29 @@ class CSEProxy: @staticmethod def __getattr__(name: str) -> Callable[..., common.CSEVariable]: # type: ignore[misc] def inner(*args, **kwargs): - code, ret_info = getattr(parent_handler, name)(*args, **kwargs) + code, ret = getattr(parent_handler, name)(*args, **kwargs) target_buffer = self.target_buffer_override.get() target_cse = self.target_cse_override.get() if isinstance(code, common.DeferredLine): target_buffer.writeline(code) return None - else: - csevar = target_cse.generate( - target_buffer, - code, - bounds=ValueRanges.unknown(), - assignment=(ret_info[0] is not None) + if ret is None: + # void op (e.g. store) + target_cse.generate( + target_buffer, code, + bounds=ValueRanges.unknown(), assignment=False, ) - if ret_info[0] is not None: - self.register_var_info(csevar, ret_info) - csevar.update_on_args(name, args, kwargs) + return None + assert isinstance(ret, OpResult), ( + f"op {name!r} must return (code, OpResult|None); got {type(ret).__name__}" + ) + csevar = target_cse.generate( + target_buffer, code, + bounds=ValueRanges.unknown(), assignment=True, + dtype=ret.dtype, vec_size=ret.vec_size, + mlir_dtype=ret.mlir_dtype, + ) + csevar.update_on_args(name, args, kwargs) return csevar return inner diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py index 217129e8..0c041c34 100644 --- a/PyTorchSimFrontend/mlir/mlir_ops.py +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -5,6 +5,7 @@ from torch._inductor.codegen import common from torch._inductor.virtualized import V, _ops as ops from . import mlir_common +from .mlir_common import OpResult, MLIRCSEVariable warnings.filterwarnings('ignore', message='undefined OpHandler\\..*, please add missing op schema') @@ -70,11 +71,10 @@ def constant(value, src_type, *args, **kwargs): value = format(float(value), ".20f") elif src_type[0] == "i": value = int(float(value)) - return format_mlir_op(f'arith.constant {value}', src_type, **kwargs), [1, src_type] - + return format_mlir_op(f'arith.constant {value}', src_type, **kwargs), OpResult.from_mlir(1, src_type) @staticmethod def broadcast(operand, target_size, *args, **kwargs): - src_size, dtype = V.kernel.var_info[operand] + src_size, dtype = operand.vec_size, operand.mlir_dtype src_shape = f"vector<{src_size}x{dtype}>" if src_size > 1 else dtype dst_shape = f"vector<{target_size}x{dtype}>" @@ -98,11 +98,10 @@ def broadcast(operand, target_size, *args, **kwargs): shape = f"{src_shape} to {dst_shape}" else: raise ValueError(f"Invalid source size: {src_size}") - return format_mlir_op(op_str, shape, **kwargs), [target_size, dtype] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(target_size, dtype) @staticmethod def broadcast_unflat(operand, target_size, *args, **kwargs): - src_size, dtype = V.kernel.var_info[operand] + src_size, dtype = operand.vec_size, operand.mlir_dtype outer_dim = target_size // src_size src_shape = f"vector<{src_size}x{dtype}>" @@ -110,8 +109,7 @@ def broadcast_unflat(operand, target_size, *args, **kwargs): op_str = f"vector.broadcast %{operand}" shape = f"{src_shape} to {dst_shape}" - return format_mlir_op(op_str, shape, **kwargs), [target_size, dtype] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(target_size, dtype) def load_seed(self, *args, **kwargs): raise NotImplementedError @@ -130,32 +128,30 @@ def masked(mask, body, other, *args, tile_size=16, dtype="f32", ninf_declared=Fa result = body() val = ops.constant(other, dtype, *args, **kwargs) result = ops.where(mask, result, val) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def where(condition, operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) - cond_type = V.kernel.var_info[condition] - operand_type = V.kernel.var_info[operand1] + cond_type = (condition.vec_size, condition.mlir_dtype) + operand_type = (operand1.vec_size, operand1.mlir_dtype) condition = ops.to_bool(condition) if cond_type[0] < tile_size: condition = ops.broadcast(condition, tile_size) elif cond_type[0] > tile_size: operand1 = ops.broadcast(operand1, cond_type[0]) operand2 = ops.broadcast(operand2, cond_type[0]) - tile_size, ret_type = V.kernel.var_info[operand1] + tile_size, ret_type = operand1.vec_size, operand1.mlir_dtype shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type cond_shape = f"vector<{tile_size}xi1>" if tile_size > 1 else "" op_str = f"arith.select %{condition}, %{operand1}, %{operand2}" shape = f"{cond_shape}, {shape}" - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def to_dtype(operand, dst_mlir_dtype, *args, **kwargs): # Extract source information - src_mlir_dtype = V.kernel.var_info[operand][1] - tile_size = V.kernel.var_info[operand][0] + src_mlir_dtype = operand.mlir_dtype + tile_size = operand.vec_size # Normalize destination type (Torch dtype -> MLIR string) if isinstance(dst_mlir_dtype, torch.dtype): @@ -168,13 +164,11 @@ def to_dtype(operand, dst_mlir_dtype, *args, **kwargs): if dst_mlir_dtype == "index": # If source is already index, return as is; otherwise cast if src_mlir_dtype == "index": - return operand, [tile_size, "index"] - return ops.index_cast(operand, "index"), [tile_size, "index"] - + return operand, OpResult(vec_size=tile_size, dtype=mlir_common.INDEX_DTYPE) + return ops.index_cast(operand, "index"), OpResult(vec_size=tile_size, dtype=mlir_common.INDEX_DTYPE) # Early return if types are identical if src_mlir_dtype == dst_mlir_dtype: - return operand, [tile_size, dst_mlir_dtype] - + return operand, OpResult.from_mlir(tile_size, dst_mlir_dtype) dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_dtype] src_bits = mlir_common.MLIR_TO_BIT[src_mlir_dtype] shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype @@ -198,7 +192,7 @@ def to_dtype(operand, dst_mlir_dtype, *args, **kwargs): # Use arith.trunci for integer truncation op_str = f"arith.trunci %{operand} : {src_shape} to {shape}" else: - return operand, [tile_size, dst_mlir_dtype] + return operand, OpResult.from_mlir(tile_size, dst_mlir_dtype) # Case D: Float -> Float (Extension / Truncation) elif src_type_char == "f" and dst_type_char == "f": if dst_bits > src_bits: @@ -207,20 +201,18 @@ def to_dtype(operand, dst_mlir_dtype, *args, **kwargs): # Corrected 'trunf' to 'truncf' op_str = f"arith.truncf %{operand} : {src_shape} to {shape}" else: - return operand, [tile_size, dst_mlir_dtype] + return operand, OpResult.from_mlir(tile_size, dst_mlir_dtype) else: raise NotImplementedError(f"Unsupported conversion: {src_mlir_dtype} -> {dst_mlir_dtype}") - return op_str, [tile_size, dst_mlir_dtype] - + return op_str, OpResult.from_mlir(tile_size, dst_mlir_dtype) @staticmethod def identity(operand, *args, **kwargs): - operand_info = V.kernel.var_info[operand] - return operand, operand_info + return operand, OpResult.from_var(operand) @staticmethod def to_dtype_bitcast(operand, dtype, *args, **kwargs): - tile_size, current_src_type = V.kernel.var_info[operand] + tile_size, current_src_type = operand.vec_size, operand.mlir_dtype if isinstance(dtype, torch.dtype): dst_mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] @@ -241,16 +233,14 @@ def to_dtype_bitcast(operand, dtype, *args, **kwargs): op_str = f"arith.bitcast %{operand}" shape = f"{src_shape} to {dst_shape}" - return format_mlir_op(op_str, shape, **kwargs), [tile_size, dst_mlir_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, dst_mlir_type) # Binary element wise operations @staticmethod def binary_elementwise_common(operand1, operand2): - V.kernel.var_info = V.kernel.var_info operand1.bounds = operand1.bounds.unknown() operand2.bounds = operand2.bounds.unknown() - op_type1 = V.kernel.var_info[operand1] - op_type2 = V.kernel.var_info[operand2] + op_type1 = (operand1.vec_size, operand1.mlir_dtype) + op_type2 = (operand2.vec_size, operand2.mlir_dtype) # Tile size check if op_type1[0] != op_type2[0]: # Try to broad cast @@ -258,11 +248,10 @@ def binary_elementwise_common(operand1, operand2): rhs_tile_size, rhs_dtype = op_type2 if lhs_tile_size > rhs_tile_size: operand2 = ops.broadcast(operand2, lhs_tile_size) - op_type2 = V.kernel.var_info[operand2] + op_type2 = (operand2.vec_size, operand2.mlir_dtype) elif lhs_tile_size < rhs_tile_size: operand1 = ops.broadcast(operand1, rhs_tile_size) - op_type1 = V.kernel.var_info[operand1] - + op_type1 = (operand1.vec_size, operand1.mlir_dtype) # Data type check if op_type1[1] != op_type2[1]: if op_type1[1] == "index" or op_type2[1] == "index": @@ -271,34 +260,34 @@ def binary_elementwise_common(operand1, operand2): if op_type2[1][0] == "f": operand1 = ops.index_cast(operand1, "i64") operand1 = ops.to_dtype(operand1, op_type2[1]) - op_type1 = V.kernel.var_info[operand1] + op_type1 = (operand1.vec_size, operand1.mlir_dtype) else: # index -> integer: direct casting operand1 = ops.index_cast(operand1, op_type2[1]) - op_type1 = V.kernel.var_info[operand1] + op_type1 = (operand1.vec_size, operand1.mlir_dtype) if op_type2[1] == "index": # index -> target type: 2-step casting if target is float if op_type1[1][0] == "f": operand2 = ops.index_cast(operand2, "i64") operand2 = ops.to_dtype(operand2, op_type1[1]) - op_type2 = V.kernel.var_info[operand2] + op_type2 = (operand2.vec_size, operand2.mlir_dtype) else: # index -> integer: direct casting operand2 = ops.index_cast(operand2, op_type1[1]) - op_type2 = V.kernel.var_info[operand2] + op_type2 = (operand2.vec_size, operand2.mlir_dtype) elif op_type1[1][0] == "i" and op_type2[1][0] == "f": operand1 = ops.to_dtype(operand1, op_type2[1]) - op_type1 = V.kernel.var_info[operand1] + op_type1 = (operand1.vec_size, operand1.mlir_dtype) elif op_type1[1][0] == "f" and op_type2[1][0] == "i": operand2 = ops.to_dtype(operand2, op_type1[1]) - op_type2 = V.kernel.var_info[operand2] + op_type2 = (operand2.vec_size, operand2.mlir_dtype) elif op_type1[1][0] == op_type2[1][0]: if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]: operand2 = ops.ext(operand2, op_type1[1]) - op_type2 = V.kernel.var_info[operand2] + op_type2 = (operand2.vec_size, operand2.mlir_dtype) elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]: operand1 = ops.ext(operand1, op_type2[1]) - op_type1 = V.kernel.var_info[operand1] + op_type1 = (operand1.vec_size, operand1.mlir_dtype) else: raise NotImplementedError("Unsupported type converting") @@ -314,25 +303,24 @@ def abs(operand, *args, **kwargs): @staticmethod def exp(operand, *args, **kwargs): # Check scalar - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.exp(operand) result = ops.extractelement(val, 0) - return result, V.kernel.var_info[result] - op_type = V.kernel.var_info[operand] + return result, OpResult.from_var(result) + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return format_mlir_op(f'math.exp %{operand}', shape, **kwargs), [tile_size, dtype] - + return format_mlir_op(f'math.exp %{operand}', shape, **kwargs), OpResult.from_mlir(tile_size, dtype) @staticmethod def exp2(operand, *args, **kwargs): # Hands-on part: implement exp2 using math.exp2 # V.kernel.var_info = {operand: [tile_size, dtype]} # Ex) V.kernel.var_info[operand] = [8, "f32"] # - # tile_size, dtype = V.kernel.var_info[operand] + # tile_size, dtype = operand.vec_size, operand.mlir_dtype # if tile_size > 1: # shape = f"vector<{tile_size}x{dtype}>" # else: @@ -342,19 +330,16 @@ def exp2(operand, *args, **kwargs): ln2 = math.log(2) coeff = ops.constant(ln2, "f32") operand = ops.mul(operand, coeff) - return ops.exp(operand), V.kernel.var_info[operand] - + return ops.exp(operand), OpResult.from_var(operand) @staticmethod def expm1(operand, *args, **kwargs): coeff = ops.constant(1.0, "f32") operand = ops.exp(operand) operand = ops.sub(operand, coeff) - return operand, V.kernel.var_info[operand] - + return operand, OpResult.from_var(operand) @staticmethod def sqrt(operand, *args, **kwargs): - op_type = V.kernel.var_info[operand] - + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] @@ -363,14 +348,12 @@ def sqrt(operand, *args, **kwargs): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return format_mlir_op(f'math.sqrt %{operand}', shape, **kwargs), [tile_size, dtype] - + return format_mlir_op(f'math.sqrt %{operand}', shape, **kwargs), OpResult.from_mlir(tile_size, dtype) @staticmethod def relu(operand, *args, **kwargs): - src_mlir_dtype = V.kernel.var_info[operand][1] - tile_size = V.kernel.var_info[operand][0] - return ops.maximum(operand, ops.constant(0, src_mlir_dtype)), [tile_size, src_mlir_dtype] - + src_mlir_dtype = operand.mlir_dtype + tile_size = operand.vec_size + return ops.maximum(operand, ops.constant(0, src_mlir_dtype)), OpResult.from_mlir(tile_size, src_mlir_dtype) @staticmethod def minimum(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -380,8 +363,7 @@ def minimum(operand1, operand2, *args, **kwargs): else: opcode = f'arith.minsi' op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def maximum(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -391,20 +373,18 @@ def maximum(operand1, operand2, *args, **kwargs): else: opcode = f'arith.maxsi' op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def cos(operand, *args, **kwargs): - op_type = V.kernel.var_info[operand] - + op_type = (operand.vec_size, operand.mlir_dtype) # Check scalar - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.cos(operand) result = ops.extractelement(val, 0) - return result, V.kernel.var_info[result] - op_type = V.kernel.var_info[operand] + return result, OpResult.from_var(result) + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] @@ -412,20 +392,18 @@ def cos(operand, *args, **kwargs): if dtype.startswith("f"): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return format_mlir_op(f'math.cos %{operand}', shape, **kwargs), [tile_size, dtype] - + return format_mlir_op(f'math.cos %{operand}', shape, **kwargs), OpResult.from_mlir(tile_size, dtype) @staticmethod def sin(operand, *args, **kwargs): - op_type = V.kernel.var_info[operand] - + op_type = (operand.vec_size, operand.mlir_dtype) # Check scalar - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.sin(operand) result = ops.extractelement(val, 0) - return result, V.kernel.var_info[result] - op_type = V.kernel.var_info[operand] + return result, OpResult.from_var(result) + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] @@ -433,15 +411,13 @@ def sin(operand, *args, **kwargs): if dtype.startswith("f"): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return format_mlir_op(f'math.sin %{operand}', shape, **kwargs), [tile_size, dtype] - + return format_mlir_op(f'math.sin %{operand}', shape, **kwargs), OpResult.from_mlir(tile_size, dtype) @staticmethod def tan(operand, *args, **kwargs): sin_res = ops.sin(operand) cos_res = ops.cos(operand) operand = ops.truediv(sin_res, cos_res) - return operand, V.kernel.var_info[operand] - + return operand, OpResult.from_var(operand) @staticmethod def lgamma(operand, *args, **kwargs): raise NotImplementedError @@ -449,18 +425,17 @@ def lgamma(operand, *args, **kwargs): @staticmethod def erf(operand, *args, **kwargs): # Check scalar - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.erf(operand) result = ops.extractelement(val, 0) - return result, V.kernel.var_info[result] - op_type = V.kernel.var_info[operand] + return result, OpResult.from_var(result) + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return format_mlir_op(f'math.erf %{operand}', shape, **kwargs), [tile_size, dtype] - + return format_mlir_op(f'math.erf %{operand}', shape, **kwargs), OpResult.from_mlir(tile_size, dtype) @staticmethod def cosh(operand, *args, **kwargs): raise NotImplementedError @@ -471,16 +446,15 @@ def sinh(operand, *args, **kwargs): @staticmethod def tanh(operand, *args, **kwargs): - op_type = V.kernel.var_info[operand] - + op_type = (operand.vec_size, operand.mlir_dtype) # Check scalar - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.tanh(operand) result = ops.extractelement(val, 0) - return result, V.kernel.var_info[result] - op_type = V.kernel.var_info[operand] + return result, OpResult.from_var(result) + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] @@ -488,8 +462,7 @@ def tanh(operand, *args, **kwargs): if dtype.startswith("f"): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return format_mlir_op(f'math.tanh %{operand}', shape, **kwargs), [tile_size, dtype] - + return format_mlir_op(f'math.tanh %{operand}', shape, **kwargs), OpResult.from_mlir(tile_size, dtype) @staticmethod def acos(operand, *args, **kwargs): raise NotImplementedError @@ -542,28 +515,26 @@ def hypot(operand1, operand2, *args, **kwargs): def log10(operand, *args, **kwargs): val_ln = ops.log(operand) - tile_size, dtype = V.kernel.var_info[val_ln] + tile_size, dtype = val_ln.vec_size, val_ln.mlir_dtype inv_ln10 = 1/math.log(10) const_op = ops.constant(inv_ln10, dtype) # Multiply: ln(x) * (1/ln(10)) result = ops.mul(val_ln, const_op) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def log2(operand, *args, **kwargs): val_ln = ops.log(operand) - tile_size, dtype = V.kernel.var_info[val_ln] + tile_size, dtype = val_ln.vec_size, val_ln.mlir_dtype inv_ln10 = 1/math.log(2) const_op = ops.constant(inv_ln10, dtype) # Multiply: ln(x) * (1/ln(10)) result = ops.mul(val_ln, const_op) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def log(operand, *args, **kwargs): - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] @@ -572,92 +543,82 @@ def log(operand, *args, **kwargs): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return format_mlir_op(f'math.log %{operand}', shape, **kwargs), [tile_size, dtype] - + return format_mlir_op(f'math.log %{operand}', shape, **kwargs), OpResult.from_mlir(tile_size, dtype) @staticmethod def log1p(operand, *args, **kwargs): - tile_size, dtype = V.kernel.var_info[operand] + tile_size, dtype = operand.vec_size, operand.mlir_dtype const_one = ops.constant(1, dtype) val_add = ops.add(operand, const_one) result = ops.log(val_add) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def nextafter(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod def logical_and(operand1, operand2, *args, **kwargs): - if V.kernel.var_info[operand1][1] != "i1": + if operand1.dtype != torch.bool: operand1 = ops.to_bool(operand1) - if V.kernel.var_info[operand2][1] != "i1": + if operand2.dtype != torch.bool: operand2 = ops.to_bool(operand2) result = ops.and_(operand1, operand2) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def logical_or(operand1, operand2, *args, **kwargs): - if V.kernel.var_info[operand1][1] != "i1": + if operand1.dtype != torch.bool: operand1 = ops.to_bool(operand1) - if V.kernel.var_info[operand2][1] != "i1": + if operand2.dtype != torch.bool: operand2 = ops.to_bool(operand2) result = ops.or_(operand1, operand2) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def logical_xor(operand1, operand2, *args, **kwargs): - if V.kernel.var_info[operand1][1] != "i1": + if operand1.dtype != torch.bool: operand1 = ops.to_bool(operand1) - if V.kernel.var_info[operand2][1] != "i1": + if operand2.dtype != torch.bool: operand2 = ops.to_bool(operand2) result = ops.xor(operand1, operand2) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def logical_not(operand, *args, **kwargs): - op_info = V.kernel.var_info[operand] + op_info = (operand.vec_size, operand.mlir_dtype) tile_size = op_info[0] dtype = op_info[1] zero_const = ops.constant(0, dtype) result = ops.eq(operand, zero_const) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def bitwise_and(operand1, operand2, *args, **kwargs): # Float check - if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): + if operand1.mlir_dtype.startswith("f") or operand2.mlir_dtype.startswith("f"): raise ValueError("Bitwise AND not supported for floats") result = ops.and_(operand1, operand2) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def bitwise_not(operand, *args, **kwargs): - tile_size, dtype = V.kernel.var_info[operand] + tile_size, dtype = operand.vec_size, operand.mlir_dtype # Float check - if V.kernel.var_info[operand][1].startswith("f"): + if operand.mlir_dtype.startswith("f"): raise ValueError("Bitwise NOT not supported for floats") neg_one = ops.constant(-1, dtype) result = ops.xor(operand, neg_one) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def bitwise_or(operand1, operand2, *args, **kwargs): # Float check - if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): + if operand1.mlir_dtype.startswith("f") or operand2.mlir_dtype.startswith("f"): raise ValueError("Bitwise AND not supported for floats") result = ops.or_(operand1, operand2) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def bitwise_xor(operand1, operand2, *args, **kwargs): # Float check - if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): + if operand1.mlir_dtype.startswith("f") or operand2.mlir_dtype.startswith("f"): raise ValueError("Bitwise AND not supported for floats") result = ops.xor(operand1, operand2) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def bitwise_left_shift(operand1, operand2, *args, **kwargs): raise NotImplementedError @@ -668,7 +629,7 @@ def bitwise_right_shift(operand1, operand2, *args, **kwargs): @staticmethod def rsqrt(operand, *args, **kwargs): - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] @@ -677,16 +638,14 @@ def rsqrt(operand, *args, **kwargs): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return format_mlir_op(f'math.rsqrt %{operand}', shape, **kwargs), [tile_size, dtype] - + return format_mlir_op(f'math.rsqrt %{operand}', shape, **kwargs), OpResult.from_mlir(tile_size, dtype) @staticmethod def sigmoid(operand, *args, **kwargs): - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] one = ops.constant(1, dtype) - return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, dtype] - + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), OpResult.from_mlir(tile_size, dtype) @staticmethod def fmod(operand1, operand2, *args, **kwargs): raise NotImplementedError @@ -701,56 +660,52 @@ def isnan(operand, *args, **kwargs): @staticmethod def round(operand, *args, **kwargs): - tile_size, dtype = V.kernel.var_info[operand] + tile_size, dtype = operand.vec_size, operand.mlir_dtype shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): op_str = f"math.roundeven %{operand}" - return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, dtype) else: - return operand, [tile_size, dtype] - + return operand, OpResult.from_mlir(tile_size, dtype) @staticmethod def floor(operand, *args, **kwargs): - tile_size, dtype = V.kernel.var_info[operand] + tile_size, dtype = operand.vec_size, operand.mlir_dtype shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): op_str = f"math.floor %{operand}" - return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, dtype) else: - return operand, [tile_size, dtype] - + return operand, OpResult.from_mlir(tile_size, dtype) @staticmethod def sign(operand, *args, **kwargs): raise NotImplementedError @staticmethod def trunc(operand, *args, **kwargs): - tile_size, dtype = V.kernel.var_info[operand] + tile_size, dtype = operand.vec_size, operand.mlir_dtype shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): op_str = f"math.trunc %{operand}" - return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, dtype) else: - return operand, [tile_size, dtype] - + return operand, OpResult.from_mlir(tile_size, dtype) @staticmethod def ceil(operand, *args, **kwargs): - tile_size, dtype = V.kernel.var_info[operand] + tile_size, dtype = operand.vec_size, operand.mlir_dtype shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): op_str = f"math.ceil %{operand}" - return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, dtype) else: - return operand, [tile_size, dtype] - + return operand, OpResult.from_mlir(tile_size, dtype) # Logical operations @staticmethod def neg(operand, *args, **kwargs): - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] @@ -759,19 +714,17 @@ def neg(operand, *args, **kwargs): operand = ops.to_dtype(operand, "f32") op_str = f"arith.negf %{operand}" shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, dtype) @staticmethod def reciprocal(operand, *args, **kwargs): - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) tile_size, dtype = op_type[0], op_type[1] if dtype.startswith("i"): openand = ops.to_dtype(operand, "f32") - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) tile_size, dtype = op_type[0], op_type[1] - return ops.truediv(ops.constant(1.0, dtype), operand), [tile_size, dtype] - + return ops.truediv(ops.constant(1.0, dtype), operand), OpResult.from_mlir(tile_size, dtype) @staticmethod def eq(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -786,8 +739,7 @@ def eq(operand1, operand2, *args, **kwargs): op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] - + return format_mlir_op(op_str, shape, **kwargs), OpResult(vec_size=tile_size, dtype=torch.bool, mlir_dtype="i1") @staticmethod def ne(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -802,8 +754,7 @@ def ne(operand1, operand2, *args, **kwargs): op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] - + return format_mlir_op(op_str, shape, **kwargs), OpResult(vec_size=tile_size, dtype=torch.bool, mlir_dtype="i1") @staticmethod def lt(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -818,8 +769,7 @@ def lt(operand1, operand2, *args, **kwargs): op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] - + return format_mlir_op(op_str, shape, **kwargs), OpResult(vec_size=tile_size, dtype=torch.bool, mlir_dtype="i1") @staticmethod def gt(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -834,8 +784,7 @@ def gt(operand1, operand2, *args, **kwargs): op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] - + return format_mlir_op(op_str, shape, **kwargs), OpResult(vec_size=tile_size, dtype=torch.bool, mlir_dtype="i1") @staticmethod def le(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -850,8 +799,7 @@ def le(operand1, operand2, *args, **kwargs): op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] - + return format_mlir_op(op_str, shape, **kwargs), OpResult(vec_size=tile_size, dtype=torch.bool, mlir_dtype="i1") @staticmethod def ge(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -866,32 +814,28 @@ def ge(operand1, operand2, *args, **kwargs): op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] - + return format_mlir_op(op_str, shape, **kwargs), OpResult(vec_size=tile_size, dtype=torch.bool, mlir_dtype="i1") @staticmethod def add(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.add{ret_type[0]}' op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def sub(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.sub{ret_type[0]}' op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def mul(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.mul{ret_type[0]}' op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def pow(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -905,32 +849,28 @@ def pow(operand1, operand2, *args, **kwargs): shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type op_str = f"math.pow{ret_type[0]} %{operand1}, %{operand2}" - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def and_(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type op_str = f'arith.andi %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def or_(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type op_str = f'arith.ori %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def xor(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type op_str = f'arith.xori %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def lshift(operand1, operand2, *args, **kwargs): raise NotImplementedError @@ -949,8 +889,7 @@ def truncdiv(operand1, operand2, *args, **kwargs): # arith.divsi: Signed Integer Division (Result is truncated) op_str = f'arith.divsi %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def floordiv(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -962,8 +901,7 @@ def floordiv(operand1, operand2, *args, **kwargs): # arith.floordivsi: Floor Division for Signed Integers op_str = f'arith.floordivsi %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def truediv(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -973,8 +911,7 @@ def truediv(operand1, operand2, *args, **kwargs): raise ValueError(f"truediv expects float inputs, but got {ret_type}. Use int_truediv for integers.") op_str = f'arith.divf %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def int_truediv(operand1, operand2, *args, **kwargs): """ @@ -989,8 +926,7 @@ def int_truediv(operand1, operand2, *args, **kwargs): src_type = target_float_type result = ops.truediv(operand1, operand2) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def mod(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -1000,8 +936,7 @@ def mod(operand1, operand2, *args, **kwargs): else: opcode = f'arith.remsi' op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def remainder(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) @@ -1013,39 +948,34 @@ def remainder(operand1, operand2, *args, **kwargs): opcode = 'arith.remsi' # Signed Integer Remainder (LHS sign) op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def square(operand, *args, **kwargs): result = ops.mul(operand, operand) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) @staticmethod def fma(operand1, operand2, operand3, *args, **kwargs): result = ops.mul(operand1, operand2) result = ops.add(result, operand3) - return result, V.kernel.var_info[result] - + return result, OpResult.from_var(result) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # PyTorchSim specific operations @staticmethod def alloc(size, src_type, *args, **kwargs): - return f"memref.alloc() : memref<{size}x{src_type}>", [size, src_type] - + return f"memref.alloc() : memref<{size}x{src_type}>", OpResult.from_mlir(size, src_type) @staticmethod def extractelement(operand, idx, *args, **kwargs): - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) tile_size = op_type[0] dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype op_str = f"vector.extract %{operand}[{idx}]" shape = f"{dtype} from {shape}" - return format_mlir_op(op_str, shape, **kwargs), [1, dtype] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(1, dtype) @staticmethod def ext(operand, dtype, *args, **kwargs): - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else f"{op_type[1]}" target_type = f"vector<{op_type[0]}x{dtype}>" if op_type[0] > 1 else f"{dtype}" if dtype[0] == "f": @@ -1054,44 +984,43 @@ def ext(operand, dtype, *args, **kwargs): opcode = f'arith.extui' op_str = f'{opcode} %{operand}' shape = f"{shape} to {target_type}" - return format_mlir_op(op_str, shape, **kwargs), [op_type[0], dtype] + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(op_type[0], dtype) @staticmethod def to_bool(operand, *args, **kwargs): - tile_size, ret_type = V.kernel.var_info[operand] - if ret_type == "i1": - return operand, [tile_size, ret_type] - + """Convert to torch.bool / MLIR i1 via ``ne 0``. Idempotent.""" + if operand.dtype == torch.bool: + return operand, OpResult(vec_size=operand.vec_size, dtype=torch.bool, mlir_dtype="i1") + tile_size, ret_type = operand.vec_size, operand.mlir_dtype const_zero = ops.constant(0, ret_type) if tile_size > 1: const_zero = ops.broadcast(const_zero, tile_size) ret = ops.ne(operand, const_zero) - return ret, [tile_size, "i1"] + return ret, OpResult(vec_size=tile_size, dtype=torch.bool, mlir_dtype="i1") + @staticmethod def step(size, dtype, *args, **kwargs): index_shape = f"vector<{size}x{dtype}>" op_str = f"vector.step" - return format_mlir_op(op_str, index_shape, **kwargs), [size, dtype] - + return format_mlir_op(op_str, index_shape, **kwargs), OpResult.from_mlir(size, dtype) @staticmethod def index_cast(operand, target_type, *args, **kwargs): - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type op_str = f"arith.index_cast %{operand}" shape = f"{src_shape} to {des_shape}" - return format_mlir_op(op_str, shape, **kwargs), [op_type[0], target_type] + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(op_type[0], target_type) @staticmethod def shape_cast(operand, src_shape, dst_shape, *args, **kwargs): - operand_type = V.kernel.var_info[operand] op_str = f"vector.shape_cast %{operand}" shape = f"{src_shape} to {dst_shape}" - return format_mlir_op(op_str, shape, **kwargs), operand_type + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_var(operand) @staticmethod def extract_strided_slice(operand, target_size, offsets=None, sizes=None, strides=None, *args, **kwargs): - op_type = V.kernel.var_info[operand] + op_type = (operand.vec_size, operand.mlir_dtype) src_size = op_type[0] dtype = op_type[1] @@ -1131,16 +1060,14 @@ def extract_strided_slice(operand, target_size, offsets=None, sizes=None, stride # Pass merged attributes to format_mlir_op updated_kwargs = {**kwargs, 'attributes': merged_attributes} - return format_mlir_op(op_str, shape, **updated_kwargs), [target_size, dtype] - + return format_mlir_op(op_str, shape, **updated_kwargs), OpResult.from_mlir(target_size, dtype) @staticmethod def vlane_offset(operand1, operand2, *args, **kwargs): tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.add{ret_type[0]}' op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] - + return format_mlir_op(op_str, shape, **kwargs), OpResult.from_mlir(tile_size, ret_type) @staticmethod def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_name, *args, **kwargs): if red_size == 1: @@ -1151,14 +1078,13 @@ def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_nam new_vshape= f"vector<{vec_size//red_size}x{red_size}x{type_name}>" value = ops.shape_cast(acc, red_shape, new_vshape) line = reduction_combine_vec(red_type, value, init, axis=0, shape=new_vshape, reduced_shape=final_reduced_shape) - return line, [red_size, type_name] - + return line, OpResult.from_mlir(red_size, type_name) @staticmethod def vector_shuffle(operand, indices, operand2=None, *args, **kwargs): - tile_size1, dtype1 = V.kernel.var_info[operand] + tile_size1, dtype1 = operand.vec_size, operand.mlir_dtype if operand2 is None: operand2 = operand - tile_size2, dtype2 = V.kernel.var_info[operand2] + tile_size2, dtype2 = operand2.vec_size, operand2.mlir_dtype if dtype1 != dtype2: raise ValueError( f"vector_shuffle expects same element type, got {dtype1} and {dtype2}" @@ -1173,14 +1099,12 @@ def vector_shuffle(operand, indices, operand2=None, *args, **kwargs): vt2 = f"vector<{tile_size2}x{dtype1}>" idx_str = ", ".join(str(i) for i in indices) op_str = f"vector.shuffle %{operand}, %{operand2} [{idx_str}]" - return format_mlir_op(op_str, f"{vt1}, {vt2}", **kwargs), [len(indices), dtype1] - + return format_mlir_op(op_str, f"{vt1}, {vt2}", **kwargs), OpResult.from_mlir(len(indices), dtype1) @staticmethod def constant_mask(select_min, N, *args, **kwargs): vals = ", ".join("true" if x else "false" for x in select_min) op_str = f"arith.constant dense<[{vals}]>" - return format_mlir_op(op_str, f"vector<{N}xi1>", **kwargs), [N, "i1"] - + return format_mlir_op(op_str, f"vector<{N}xi1>", **kwargs), OpResult(vec_size=N, dtype=torch.bool, mlir_dtype="i1") @staticmethod def bitonic_sort(operand, descending=False, *args, **kwargs): def _compute_bitonic_stages(N: int, descending: bool): @@ -1217,7 +1141,7 @@ def _compute_bitonic_stages(N: int, descending: bool): size *= 2 return stages - tile_size, _ = V.kernel.var_info[operand] + tile_size, _ = operand.vec_size, operand.mlir_dtype cur = operand for stage in _compute_bitonic_stages(tile_size, descending): mask = ops.constant_mask(stage["select_min"], tile_size) @@ -1225,8 +1149,7 @@ def _compute_bitonic_stages(N: int, descending: bool): vmin = ops.minimum(cur, shuffled) vmax = ops.maximum(cur, shuffled) cur = ops.where(mask, vmin, vmax) - return cur, V.kernel.var_info[cur] - + return cur, OpResult.from_var(cur) @staticmethod def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, **kwargs): if compute_vec_size == 1: @@ -1239,11 +1162,10 @@ def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, ** operation = "affine.vector_load" line = f"{operation} %{buffer}[{indices}]" shape = f"{buffer_shape}, {vshape}" - return format_mlir_op(line, shape, **kwargs), [compute_vec_size, mlir_dtype] - + return format_mlir_op(line, shape, **kwargs), OpResult.from_mlir(compute_vec_size, mlir_dtype) @staticmethod def _store(operand, buffer, indices, buffer_shape, *args, buffer_name=None, **kwargs): - compute_vec_size, mlir_dtype = V.kernel.var_info[operand][0], V.kernel.var_info[operand][1] + compute_vec_size, mlir_dtype = operand.vec_size, operand.mlir_dtype if compute_vec_size == 1: vshape = f"{mlir_dtype}" @@ -1258,10 +1180,9 @@ def _store(operand, buffer, indices, buffer_shape, *args, buffer_name=None, **kw line = format_mlir_op(line, shape, **kwargs) if buffer_name is not None: - return common.DeferredLine(buffer_name, line), [None, None] + return common.DeferredLine(buffer_name, line), None else: - return line, [None, None] - + return line, None @staticmethod def affine_apply(map_var, indices, indirect_dims=None, comment=None, *args, **kwargs): # Format indices arguments @@ -1274,8 +1195,7 @@ def affine_apply(map_var, indices, indirect_dims=None, comment=None, *args, **kw op_str += f"[{indirect_str}] {{indirect_access}}" if comment: op_str += f" // {comment}" - return op_str, [1, "index"] - + return op_str, OpResult(vec_size=1, dtype=mlir_common.INDEX_DTYPE) @staticmethod def affine_map(dim_names, expr_str, symbol_names=None, comment=None, *args, **kwargs): # Handle dim_names as list or string @@ -1297,4 +1217,4 @@ def affine_map(dim_names, expr_str, symbol_names=None, comment=None, *args, **kw if comment: map_str += f" // {comment}" - return map_str, [1, "map"] + return map_str, OpResult.from_mlir(1, "map") \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_sort_template.py b/PyTorchSimFrontend/mlir/mlir_sort_template.py index 24b3a460..e5577c2a 100644 --- a/PyTorchSimFrontend/mlir/mlir_sort_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sort_template.py @@ -283,7 +283,7 @@ def render( bitonic_body = mlir_common.ParallelLoopBuffer(initial_indent=2) bitonic_body.tabwidth = 2 # 1) Local SIMD sort per chunk. - init_cse = common.CSE(kernel.newvar_prefix, kernel.suffix, name_prefix="sort_init") + init_cse = mlir_common.MLIRCSE(kernel.newvar_prefix, kernel.suffix, name_prefix="sort_init") with kernel, kernel.override_buffer_cse(buffer=bitonic_body, cse=init_cse): bitonic_body.writelines(LoopLevel("chunk", num_chunks).lines()) with bitonic_body.indent(attribute="{inner_loop=true}"): @@ -295,10 +295,12 @@ def render( "%t_const0, %elem", x_tile_desc.get_mlir_shape(data_stype), ) - idx_step_index = kernel.register_var_cse("idx_step_index", vector_size, "index") + idx_step_index = kernel.cse.namedvar("idx_step_index", dtype=mlir_common.INDEX_DTYPE, vec_size=vector_size) bitonic_body.writeline(f"%{idx_step_index} = vector.step : vector<{vector_size}xindex>") idx_step = ops.index_cast(idx_step_index, idx_stype) - idx_base = kernel.register_var_cse("idx_base", 1, idx_stype) + idx_base = kernel.cse.namedvar( + "idx_base", vec_size=1, dtype=mlir_common.MLIR_TO_DTYPE.get(idx_stype), + ) bitonic_body.writeline(f"%{idx_base} = arith.index_cast %elem : index to {idx_stype}") idx_base_vec = ops.broadcast(idx_base, vector_size) idx_chunk = ops.add(idx_base_vec, idx_step) @@ -328,7 +330,7 @@ def render( if block_start >= num_chunks: continue asc_dir = is_even_block if not self.descending else (not is_even_block) - stage_cse = common.CSE(kernel.newvar_prefix, kernel.suffix, name_prefix=f"sort_stage_{stage}") + stage_cse = mlir_common.MLIRCSE(kernel.newvar_prefix, kernel.suffix, name_prefix=f"sort_stage_{stage}") with kernel, kernel.override_buffer_cse(buffer=bitonic_body, cse=stage_cse): stage_loops = [ LoopLevel("base", num_chunks, start=block_start, step=2 * k), diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index c8fc036f..92ac6d2c 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -14,7 +14,7 @@ from unittest.mock import patch from PyTorchSimFrontend import extension_config -from torch._inductor.codegen.common import KernelTemplate, CSE, DeferredLine +from torch._inductor.codegen.common import KernelTemplate, DeferredLine from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller, ir_node_to_tensor from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -46,8 +46,8 @@ def __init__(self, kernel: 'MLIRTemplateKernel', prefix=""): self.dma_loads = IndentedBuffer() self.dma_stores = IndentedBuffer() self.spad_buffer = IndentedBuffer() - self.cse = common.CSE("%", "", name_prefix=f"{prefix}") - self.apply_cse = common.CSE("%", "", name_prefix=f"{prefix}apply") + self.cse = mlir_common.MLIRCSE("%", "", name_prefix=f"{prefix}") + self.apply_cse = mlir_common.MLIRCSE("%", "", name_prefix=f"{prefix}apply") # Original buffers will be saved later in the 'with' block self.original_buffers = {} @@ -118,9 +118,9 @@ def __init__(self, self.render_options = dict() self.tile_size = [] self.loop_size = None - self.map_cse = CSE("#", self.suffix, name_prefix="t_map") - self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_const") - self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_alloc") + self.map_cse = mlir_common.MLIRCSE("#", self.suffix, name_prefix="t_map") + self.const_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="t_const") + self.alloc_cse = mlir_common.MLIRCSE(self.newvar_prefix, self.suffix, name_prefix="t_alloc") self.prologue_buffer_group = IndentedBufferGroup(self, prefix="prologue_") self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") self.global_vars = IndentedBuffer() @@ -1057,7 +1057,12 @@ def load_epilogue(self, name: str, index: sympy.Expr): with self.override_buffer_cse(buffer=self.loads): out = ops._load(vsize, mlir_dtype, sram_var, compute_index_var, tile_shape) - self.register_var_info(out, [self.compute_body_loop.step, mlir_dtype]) + # `vsize` is the MLIR emission chunk; the logical (loop-tracking) + # size for downstream is the outer compute step. `dtype` overrides + # the proxy's lossy mlir_dtype->torch derivation for the uint8/int8 + # ambiguity (OpResult.from_mlir defaults "i8" to torch.int8). + out.vec_size = self.compute_body_loop.step + out.dtype = dtype return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): @@ -1088,7 +1093,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): sram_var = self.buffer_names[name] zero_var = self.get_const_cse(0) - _, operand_type = self.var_info[value] + _, operand_type = value.vec_size, value.mlir_dtype if mlir_dtype != operand_type: value = ops.to_dtype(value, mlir_dtype) compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) @@ -1217,8 +1222,8 @@ def store_reduction_epilogue(self, name, index, value): reduction_type = self.reduction_info[value][0] out = ops.multi_reduction(out, init_vec2, partial_vec_size, new_vec_size, partial_vshape, reduction_type, mlir_dtype) - out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") - self.register_var_info(out2, [new_vec_size, mlir_dtype]) + out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}", dtype=dtype) + out2.vec_size = new_vec_size with self.override_buffer_cse(buffer=self.reductions_suffix): out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2) @@ -1270,7 +1275,7 @@ def set_tile_size(self, template_fusion_info, prologue=False): self.r_tile_size = tile_desc.get_tile_size()[-1] self.r_dim_size = template_fusion_info['r_dim_size'] self.reduction_nr_outer_loop = nr_outer_loop - self.reduction_loop_idx = self.register_var_cse("reduce_loop_idx", 1, "index") + self.reduction_loop_idx = self.cse.namedvar("reduce_loop_idx", dtype=mlir_common.INDEX_DTYPE) self.compute_body_loop.size = r_tile_size self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop)