diff --git a/iron/common/test_utils.py b/iron/common/test_utils.py index 1671491a..219fef9a 100644 --- a/iron/common/test_utils.py +++ b/iron/common/test_utils.py @@ -19,6 +19,17 @@ "i32": torch.int32, } +# Numpy equivalent of torch_dtype_map, for use in AIERuntimeArgSpec and other +# numpy-based interfaces. Derived from torch_dtype_map to stay in sync. +np_dtype_map = { + k: ( + np.dtype(bfloat16) + if v == torch.bfloat16 + else torch.tensor([], dtype=v).numpy().dtype + ) + for k, v in torch_dtype_map.items() +} + # TODO: Consider upstreaming generic buffer utilities to mlir-aie once operator abstractions stabilize. diff --git a/iron/operators/gemm/design.py b/iron/operators/gemm/design.py index a8ed8ad3..ba31e8dd 100644 --- a/iron/operators/gemm/design.py +++ b/iron/operators/gemm/design.py @@ -26,9 +26,11 @@ microkernel_mac_dim_map = { "npu1": { "bf16": (4, 8, 4), + "i8": (8, 8, 8), }, "npu1": { "bf16": (4, 8, 4), + "i8": (8, 8, 8), }, "npu2": { "bf16": { @@ -36,6 +38,7 @@ True: (8, 8, 8), False: (4, 8, 8), }, + "i8": (8, 8, 8), }, } @@ -68,11 +71,13 @@ def main(): default=None, help="Name of the archive file for the AIE kernels", ) - argparser.add_argument("--dtype_in", type=str, choices=["bf16"], default="bf16") + argparser.add_argument( + "--dtype_in", type=str, choices=["bf16", "i8"], default="bf16" + ) argparser.add_argument( "--dtype_out", type=str, - choices=["bf16", "f32"], + choices=["bf16", "f32", "i8", "i16", "i32"], default="bf16", ) argparser.add_argument("--trace_size", type=int, default=0) diff --git a/iron/operators/gemm/op.py b/iron/operators/gemm/op.py index ac391cc3..287575f1 100644 --- a/iron/operators/gemm/op.py +++ b/iron/operators/gemm/op.py @@ -15,6 +15,7 @@ DesignGenerator, ) from iron.common.device_utils import get_kernel_dir +from iron.common.test_utils import np_dtype_map import aie.utils as aie_utils @@ -61,7 +62,9 @@ def __post_init__(self): if self.N % min_N != 0: raise ValueError(f"N ({self.N}) must be a multiple of {min_N}") - if 
self.emulate_bf16_mmul_with_bfp16: + if self.dtype_in == "i8": + min_tile_m, min_tile_k, min_tile_n = 16, 8, 16 + elif self.emulate_bf16_mmul_with_bfp16: min_tile_m, min_tile_k, min_tile_n = 8, 8, 8 else: min_tile_m, min_tile_k, min_tile_n = 4, 8, 8 @@ -74,9 +77,21 @@ def __post_init__(self): MLIROperator.__init__(self, context=self.context) + @property + def name(self) -> str: + """Include dtype in operator name when not the default bf16, to avoid + xclbin filename collisions between bf16 and int8 GEMM variants with + identical dimensions.""" + base = super().name + if self.dtype_in != "bf16": + base += f"_{self.dtype_in}_{self.dtype_out}" + return base + @property def _kernel_flags_suffix(self): """Suffix encoding compile-time flags that affect the kernel binary.""" + if self.dtype_in == "i8": + return f"_{self.dtype_in}_{self.dtype_out}" return f"_{int(self.prio_accuracy)}_{int(self.emulate_bf16_mmul_with_bfp16)}_{int(self.round_conv_even)}" def get_mlir_artifact(self): @@ -117,14 +132,17 @@ def get_kernel_artifacts(self): f"-DDIM_K={self.tile_k}", f"-DDIM_N={self.tile_n}", ] - if self.prio_accuracy: - kernel_flags.append("-Dbf16_f32_ONLY") + if self.dtype_in == "i8": + kernel_flags.append(f"-D{self.dtype_in}_{self.dtype_out}_ONLY") else: - kernel_flags.append("-Dbf16_bf16_ONLY") - if self.round_conv_even: - kernel_flags.append("-DROUND_CONV_EVEN") - if self.emulate_bf16_mmul_with_bfp16: - kernel_flags.append("-DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16") + if self.prio_accuracy: + kernel_flags.append("-Dbf16_f32_ONLY") + else: + kernel_flags.append("-Dbf16_bf16_ONLY") + if self.round_conv_even: + kernel_flags.append("-DROUND_CONV_EVEN") + if self.emulate_bf16_mmul_with_bfp16: + kernel_flags.append("-DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16") if self.b_col_maj: kernel_flags.append("-DB_COL_MAJ") if self.c_col_maj: @@ -150,13 +168,19 @@ def get_kernel_artifacts(self): ] def get_arg_spec(self): + dtype_in_np = np_dtype_map[self.dtype_in] + dtype_out_np = 
np_dtype_map[self.dtype_out] return [ - AIERuntimeArgSpec("in", (self.M, self.K)), # input A + AIERuntimeArgSpec("in", (self.M, self.K), dtype=dtype_in_np), # input A AIERuntimeArgSpec( - "in", (self.K, self.N) if not self.b_col_maj else (self.N, self.K) + "in", + (self.K, self.N) if not self.b_col_maj else (self.N, self.K), + dtype=dtype_in_np, ), # input B (weights) AIERuntimeArgSpec( - "out", (self.M, self.N) if not self.c_col_maj else (self.N, self.M) + "out", + (self.M, self.N) if not self.c_col_maj else (self.N, self.M), + dtype=dtype_out_np, ), # output C ] diff --git a/iron/operators/gemm/reference.py b/iron/operators/gemm/reference.py index baa183af..73ed2899 100644 --- a/iron/operators/gemm/reference.py +++ b/iron/operators/gemm/reference.py @@ -10,6 +10,7 @@ def generate_golden_reference( K: int, N: int, dtype="bf16", + dtype_out=None, seed=42, b_col_maj=False, c_col_maj=False, @@ -18,9 +19,21 @@ def generate_golden_reference( torch.manual_seed(seed) val_range = 4 dtype_torch = torch_dtype_map[dtype] - input_a = torch.randn(M, K, dtype=dtype_torch) * val_range - input_b_full = torch.rand(K, N, dtype=dtype_torch) * val_range - output_full = torch.matmul(input_a, input_b_full) + if dtype in ("i8",): + input_a = torch.randint(-val_range, val_range + 1, (M, K), dtype=dtype_torch) + input_b_full = torch.randint( + -val_range, val_range + 1, (K, N), dtype=dtype_torch + ) + output_full = torch.matmul( + input_a.to(torch.int32), input_b_full.to(torch.int32) + ) + if dtype_out is not None and dtype_out != "i32": + dtype_out_torch = torch_dtype_map[dtype_out] + output_full = output_full.to(dtype_out_torch) + else: + input_a = torch.randn(M, K, dtype=dtype_torch) * val_range + input_b_full = torch.rand(K, N, dtype=dtype_torch) * val_range + output_full = torch.matmul(input_a, input_b_full) if False: # The following inputs are useful for debugging; # the A matrix becomes a matrix where each element encodes its row and column index, diff --git 
a/iron/operators/gemm/test.py b/iron/operators/gemm/test.py index bbd41b00..cb4015f0 100755 --- a/iron/operators/gemm/test.py +++ b/iron/operators/gemm/test.py @@ -21,34 +21,41 @@ def get_params(): max_aie_columns = dev.cols device_type = dev.resolve().name # fmt: off - # M, K, N, num_aie_columns, b_col_maj, c_col_maj, m, k, n, trace_size, partition_N + # M, K, N, num_aie_columns, b_col_maj, c_col_maj, m, k, n, trace_size, partition_N, dtype_in, dtype_out regular_params = [ - (2048, 2048, 2048, 1, False, False, 64, 64, 64, 0, 1), - (2048, 2048, 2048, 2, True, False, 64, 64, 64, 0, 1), - (2048, 2048, 2048, 8, True, True, 64, 64, 64, 0, 1), - ( 384, 1536, 1792, 4, True, False, 32, 48, 64, 0, 1), - (1792, 896, 1152, 8, False, True, 64, 32, 48, 0, 1), - ( 896, 1792, 640, 8, False, True, 32, 64, 80, 0, 1), - ( 192, 384, 64, 4, False, False, 48, 96, 16, 0, 1), - ( 192, 384, 64, 4, True, True, 48, 96, 16, 0, 1), - ( 64, 512, 256, 4, True, False, 16, 64, 64, 0, 4), + (2048, 2048, 2048, 1, False, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 2048, 2048, 2, True, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 2048, 2048, 8, True, True, 64, 64, 64, 0, 1, "bf16", "bf16"), + ( 384, 1536, 1792, 4, True, False, 32, 48, 64, 0, 1, "bf16", "bf16"), + (1792, 896, 1152, 8, False, True, 64, 32, 48, 0, 1, "bf16", "bf16"), + ( 896, 1792, 640, 8, False, True, 32, 64, 80, 0, 1, "bf16", "bf16"), + ( 192, 384, 64, 4, False, False, 48, 96, 16, 0, 1, "bf16", "bf16"), + ( 192, 384, 64, 4, True, True, 48, 96, 16, 0, 1, "bf16", "bf16"), + ( 64, 512, 256, 4, True, False, 16, 64, 64, 0, 4, "bf16", "bf16"), + ] + int8_params = [ + (2048, 2048, 2048, 4, False, False, 64, 64, 64, 0, 1, "i8", "i32"), + ( 512, 512, 512, 4, False, False, 32, 32, 32, 0, 1, "i8", "i32"), + (2048, 2048, 2048, 8, False, False, 64, 64, 64, 0, 1, "i8", "i32"), + (1024, 1024, 1024, 8, False, False, 64, 64, 64, 0, 1, "i8", "i8"), + (1024, 1024, 1024, 8, True, False, 64, 64, 64, 0, 1, "i8", "i16"), ] extensive_params = [ 
- (2048, 2048, 2048, 8, False, False, 32, 32, 128, 0, 1), - (2048, 2048, 8192, 2, False, False, 64, 64, 64, 0, 1), - (2048, 8192, 2048, 2, False, False, 64, 64, 64, 0, 1), - (2048, 64, 2048, 2, False, False, 64, 64, 64, 0, 1), - (2048, 64, 8192, 2, False, False, 64, 64, 64, 0, 1), - (2048, 2048, 2048, 8, True, False, 128, 32, 32, 0, 1), - (2048, 2048, 8192, 2, True, False, 64, 64, 64, 0, 1), - (2048, 8192, 2048, 2, True, False, 64, 64, 64, 0, 1), - (2048, 64, 2048, 2, True, False, 64, 64, 64, 0, 1), - (2048, 64, 8192, 2, True, False, 64, 64, 64, 0, 1), - (2048, 2048, 2048, 2, False, True, 8, 16, 32, 0, 1), - (2048, 2048, 8192, 2, False, True, 64, 64, 64, 0, 1), - (2048, 8192, 2048, 2, False, True, 64, 64, 64, 0, 1), - (2048, 64, 2048, 2, False, True, 64, 64, 64, 0, 1), - (2048, 64, 8192, 2, False, True, 64, 64, 64, 0, 1), + (2048, 2048, 2048, 8, False, False, 32, 32, 128, 0, 1, "bf16", "bf16"), + (2048, 2048, 8192, 2, False, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 8192, 2048, 2, False, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 64, 2048, 2, False, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 64, 8192, 2, False, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 2048, 2048, 8, True, False, 128, 32, 32, 0, 1, "bf16", "bf16"), + (2048, 2048, 8192, 2, True, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 8192, 2048, 2, True, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 64, 2048, 2, True, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 64, 8192, 2, True, False, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 2048, 2048, 2, False, True, 8, 16, 32, 0, 1, "bf16", "bf16"), + (2048, 2048, 8192, 2, False, True, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 8192, 2048, 2, False, True, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 64, 2048, 2, False, True, 64, 64, 64, 0, 1, "bf16", "bf16"), + (2048, 64, 8192, 2, False, True, 64, 64, 64, 0, 1, "bf16", "bf16"), ] # fmt: on @@ -69,6 +76,8 @@ def add_params(param_list, is_extensive): n, trace_size, 
partition_N, + dtype_in, + dtype_out, ) = p # Skip tests that require more columns than available on the device @@ -84,6 +93,7 @@ def add_params(param_list, is_extensive): params.append(pytest.param(*p, marks=marks)) add_params(regular_params, is_extensive=False) + add_params(int8_params, is_extensive=False) add_params(extensive_params, is_extensive=True) return params @@ -95,7 +105,7 @@ def add_params(param_list, is_extensive): Throughput=r"Throughput: (?P<throughput>[\d\.e\+-]+) GFLOP/s", ) @pytest.mark.parametrize( - "M,K,N,num_aie_columns,b_col_maj,c_col_maj,m,k,n,trace_size,partition_N", + "M,K,N,num_aie_columns,b_col_maj,c_col_maj,m,k,n,trace_size,partition_N,dtype_in,dtype_out", get_params(), ) def test_gemm( @@ -110,6 +120,8 @@ def test_gemm( n, trace_size, partition_N, + dtype_in, + dtype_out, aie_context, ): total_N = N * partition_N @@ -118,6 +130,8 @@ def test_gemm( M=M, K=K, N=total_N, + dtype=dtype_in, + dtype_out=dtype_out, b_col_maj=b_col_maj, c_col_maj=c_col_maj, ) @@ -130,13 +144,20 @@ def test_gemm( tile_k=k, tile_n=n, num_aie_columns=num_aie_columns, - prio_accuracy=True, + prio_accuracy=dtype_in == "bf16", emulate_bf16_mmul_with_bfp16=False, b_col_maj=b_col_maj, c_col_maj=c_col_maj, + dtype_in=dtype_in, + dtype_out=dtype_out, context=aie_context, ) + if dtype_in == "i8": + rel_tol, abs_tol = 1e-10, 1e-10 + else: + rel_tol, abs_tol = 0.005, 0.005 + if partition_N == 1: input_buffers = { "A": golden_ref["input"].flatten(), @@ -146,7 +167,7 @@ def test_gemm( "C": golden_ref["output"][0].flatten(), } errors, latency_us, bandwidth_gbps = run_test( - operator, input_buffers, output_buffers, rel_tol=0.005, abs_tol=0.005 + operator, input_buffers, output_buffers, rel_tol=rel_tol, abs_tol=abs_tol ) else: compilable = operator.compile() @@ -200,7 +221,7 @@ def test_gemm( # Compare concatenated output to full reference C_expected = golden_ref["output"][0] buf_errors = verify_buffer( - C_concat, "C", 
C_expected, rel_tol=rel_tol, abs_tol=abs_tol ) errors = {"C": buf_errors} if buf_errors else {}