From c774b530fc85955e75620cc4f324ea6c1b41d6fc Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 29 May 2026 13:11:13 -0700 Subject: [PATCH] [inductor] Save producer partials for downstream sums --- ...est_producer_sum_reduction_accumulation.py | 321 ++++++++++ .../codegen/cuda_combined_scheduling.py | 3 + torch/_inductor/codegen/simd.py | 197 +++++- torch/_inductor/config.py | 23 + torch/_inductor/scheduler.py | 603 ++++++++++++++++++ 5 files changed, 1127 insertions(+), 20 deletions(-) create mode 100644 test/inductor/test_producer_sum_reduction_accumulation.py diff --git a/test/inductor/test_producer_sum_reduction_accumulation.py b/test/inductor/test_producer_sum_reduction_accumulation.py new file mode 100644 index 0000000000000..f02d24a286383 --- /dev/null +++ b/test/inductor/test_producer_sum_reduction_accumulation.py @@ -0,0 +1,321 @@ +# Owner(s): ["module: inductor"] + +import unittest + +import torch +import torch._inductor.config as inductor_config +from torch._dynamo.utils import counters +from torch._inductor import metrics +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_device_type import largeTensorTest +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU + + +class SumSumSum6107A2F54029(torch.nn.Module): + def forward( + self, + mm_193, + fmod_2, + primals_20, + mul_10, + div_120, + view_1358, + shape_0, + shape_1, + shape_2, + shape_3, + shape_4, + shape_5, + shape_6, + ): + reshape_default = torch.ops.aten.reshape.default(mm_193, shape_0) + reshape_default_1 = torch.ops.aten.reshape.default(reshape_default, shape_1) + reshape_default_2 = torch.ops.aten.reshape.default(reshape_default_1, shape_2) + permute_default = torch.ops.aten.permute.default( + reshape_default_2, [0, 1, 3, 2, 4, 5] + ) + clone_default = torch.ops.aten.clone.default( + permute_default, memory_format=torch.contiguous_format + ) + reshape_default_3 = torch.ops.aten.reshape.default(clone_default, shape_3) + index_tensor = torch.ops.aten.index.Tensor( + reshape_default_3, [None, None, fmod_2] + ) + index_tensor_1 = torch.ops.aten.index.Tensor(index_tensor, [None, fmod_2]) + mul_tensor = torch.ops.aten.mul.Tensor(index_tensor_1, primals_20) + mul_tensor_1 = torch.ops.aten.mul.Tensor(mul_tensor, 128) + sum_dim_int_list = torch.ops.aten.sum.dim_IntList(mul_tensor, [3], True) + mul_tensor_2 = torch.ops.aten.mul.Tensor(mul_tensor, mul_10) + sum_dim_int_list_1 = torch.ops.aten.sum.dim_IntList(mul_tensor_2, [3], True) + mul_tensor_3 = torch.ops.aten.mul.Tensor(mul_10, sum_dim_int_list_1) + sub_tensor = torch.ops.aten.sub.Tensor(mul_tensor_1, sum_dim_int_list) + sub_tensor_1 = torch.ops.aten.sub.Tensor(sub_tensor, mul_tensor_3) + mul_tensor_4 = torch.ops.aten.mul.Tensor(div_120, sub_tensor_1) + mul_tensor_5 = torch.ops.aten.mul.Tensor(index_tensor_1, mul_10) + sum_dim_int_list_2 = torch.ops.aten.sum.dim_IntList(mul_tensor_5, [0, 1, 2]) + sum_dim_int_list_3 = torch.ops.aten.sum.dim_IntList(index_tensor_1, [0, 1, 2]) + add_tensor = torch.ops.aten.add.Tensor(view_1358, mul_tensor_4) + reshape_default_4 = torch.ops.aten.reshape.default(add_tensor, shape_4) + reshape_default_5 = torch.ops.aten.reshape.default(reshape_default_4, shape_5) + permute_default_1 = torch.ops.aten.permute.default(reshape_default_5, [1, 0]) + sum_dim_int_list_4 = torch.ops.aten.sum.dim_IntList( + reshape_default_5, [0], True + ) + reshape_default_6 = torch.ops.aten.reshape.default(sum_dim_int_list_4, shape_6) + return ( + sum_dim_int_list_2, + sum_dim_int_list_3, + permute_default_1, + reshape_default_6, + ) + + +class SumSumSumDC96C4651516(SumSumSum6107A2F54029): + pass + + +class SumSumSum3213336F4C0B(torch.nn.Module): + def forward( + self, + mm_127, + mm_129, + mm_131, + mul_279, + arg8_1, + arg109_1, + arg297_1, + arg108_1, + shape_0, + shape_1, + shape_2, + shape_3, + shape_4, + ): + view_default = torch.ops.aten.view.default(mm_127, shape_0) + view_default_1 = torch.ops.aten.view.default(mm_129, shape_1) + add_tensor = torch.ops.aten.add.Tensor(view_default, view_default_1) + view_default_2 = torch.ops.aten.view.default(mm_131, shape_2) + add_tensor_1 = torch.ops.aten.add.Tensor(add_tensor, view_default_2) + permute_default = torch.ops.aten.permute.default(add_tensor_1, [1, 0, 2]) + add_tensor_2 = torch.ops.aten.add.Tensor(mul_279, permute_default) + mul_tensor = torch.ops.aten.mul.Tensor(add_tensor_2, arg8_1) + mul_tensor_1 = torch.ops.aten.mul.Tensor(mul_tensor, 768) + sum_dim_int_list = torch.ops.aten.sum.dim_IntList(mul_tensor, [2], True) + mul_tensor_2 = torch.ops.aten.mul.Tensor(mul_tensor, arg109_1) + sum_dim_int_list_1 = torch.ops.aten.sum.dim_IntList( + mul_tensor_2, [2], True + ) + mul_tensor_3 = torch.ops.aten.mul.Tensor(arg109_1, sum_dim_int_list_1) + sub_tensor = torch.ops.aten.sub.Tensor(mul_tensor_1, sum_dim_int_list) + sub_tensor_1 = torch.ops.aten.sub.Tensor(sub_tensor, mul_tensor_3) + mul_tensor_4 = torch.ops.aten.mul.Tensor(arg297_1, sub_tensor_1) + mul_tensor_5 = torch.ops.aten.mul.Tensor(add_tensor_2, arg109_1) + sum_dim_int_list_2 = torch.ops.aten.sum.dim_IntList(mul_tensor_5, [0, 1]) + sum_dim_int_list_3 = torch.ops.aten.sum.dim_IntList(add_tensor_2, [0, 1]) + convert_element_type_default = torch.ops.prims.convert_element_type.default( + arg108_1, torch.float32 + ) + mul_tensor_6 = torch.ops.aten.mul.Tensor(convert_element_type_default, 1.0) + mul_tensor_7 = torch.ops.aten.mul.Tensor(mul_tensor_4, mul_tensor_6) + view_default_3 = torch.ops.aten.view.default(mul_tensor_7, shape_3) + permute_default_1 = torch.ops.aten.permute.default(view_default_3, [1, 0]) + sum_dim_int_list_4 = torch.ops.aten.sum.dim_IntList( + view_default_3, [0], True + ) + view_default_4 = torch.ops.aten.view.default(sum_dim_int_list_4, shape_4) + return ( + sum_dim_int_list_2, + sum_dim_int_list_3, + permute_default_1, + view_default_4, + ) + + +class SumSumSumBaf315Cfc5F0(torch.nn.Module): + def forward( + self, + mm_273, + mm_275, + mm_277, + arg11_1, + arg231_1, + arg537_1, + add_181, + shape_0, + shape_1, + shape_2, + shape_3, + shape_4, + ): + view_default = torch.ops.aten.view.default(mm_273, shape_0) + view_default_1 = torch.ops.aten.view.default(mm_275, shape_1) + add_tensor = torch.ops.aten.add.Tensor(view_default, view_default_1) + view_default_2 = torch.ops.aten.view.default(mm_277, shape_2) + add_tensor_1 = torch.ops.aten.add.Tensor(add_tensor, view_default_2) + mul_tensor = torch.ops.aten.mul.Tensor(add_tensor_1, arg11_1) + mul_tensor_1 = torch.ops.aten.mul.Tensor(mul_tensor, 2048) + sum_dim_int_list = torch.ops.aten.sum.dim_IntList(mul_tensor, [2], True) + mul_tensor_2 = torch.ops.aten.mul.Tensor(mul_tensor, arg231_1) + sum_dim_int_list_1 = torch.ops.aten.sum.dim_IntList( + mul_tensor_2, [2], True + ) + mul_tensor_3 = torch.ops.aten.mul.Tensor(arg231_1, sum_dim_int_list_1) + sub_tensor = torch.ops.aten.sub.Tensor(mul_tensor_1, sum_dim_int_list) + sub_tensor_1 = torch.ops.aten.sub.Tensor(sub_tensor, mul_tensor_3) + mul_tensor_4 = torch.ops.aten.mul.Tensor(arg537_1, sub_tensor_1) + mul_tensor_5 = torch.ops.aten.mul.Tensor(add_tensor_1, arg231_1) + sum_dim_int_list_2 = torch.ops.aten.sum.dim_IntList(mul_tensor_5, [0, 1]) + sum_dim_int_list_3 = torch.ops.aten.sum.dim_IntList(add_tensor_1, [0, 1]) + add_tensor_2 = torch.ops.aten.add.Tensor(add_181, mul_tensor_4) + view_default_3 = torch.ops.aten.view.default(add_tensor_2, shape_3) + permute_default = torch.ops.aten.permute.default(view_default_3, [1, 0]) + sum_dim_int_list_4 = torch.ops.aten.sum.dim_IntList( + view_default_3, [0], True + ) + view_default_4 = torch.ops.aten.view.default(sum_dim_int_list_4, shape_4) + return ( + sum_dim_int_list_2, + sum_dim_int_list_3, + permute_default, + view_default_4, + ) + + +@unittest.skipIf(not HAS_GPU, "requires GPU") +@inductor_config.patch( + { + "producer_sum_reduction_accumulation": True, + "triton.cooperative_reductions": False, + "triton.force_cooperative_reductions": False, + } +) +class ProducerSumReductionAccumulationTest(TestCase): + def setUp(self): + super().setUp() + counters.clear() + metrics.reset() + torch._dynamo.reset() + + def assert_inductor_counters(self, expected): + actual = counters["inductor"] + for name, value in expected.items(): + self.assertEqual(actual[name], value, f"{name}: {dict(actual)}") + + def assert_outputs_close(self, actual, expected): + self.assertEqual(len(actual), len(expected)) + for actual_output, expected_output in zip(actual, expected): + self.assertEqual(actual_output.shape, expected_output.shape) + torch.testing.assert_close( + actual_output, + expected_output, + rtol=1e-3, + atol=1e-1, + ) + + @largeTensorTest("1GB", device=GPU_TYPE, inductor=True) + def test_sum_sum_sum_baf315cfc5f0_rejects_below_min_bytes(self): + args = ( + torch.randn(4096, 2048, device=GPU_TYPE), + torch.randn(4096, 2048, device=GPU_TYPE), + torch.randn(4096, 2048, device=GPU_TYPE), + torch.randn(2048, device=GPU_TYPE), + torch.randn(32, 128, 2048, device=GPU_TYPE), + torch.randn(32, 128, 1, device=GPU_TYPE), + torch.randn(32, 128, 2048, device=GPU_TYPE), + [32, 128, 2048], + [32, 128, 2048], + [32, 128, 2048], + [4096, 2048], + [2048], + ) + + mod = SumSumSumBaf315Cfc5F0() + expected = mod(*args) + actual = torch.compile(mod)(*args) + + self.assert_outputs_close(actual, expected) + self.assert_inductor_counters( + { + "producer_sum_reduction_accumulation_candidates": 0, + "producer_sum_reduction_accumulation_selected": 0, + "producer_sum_reduction_accumulation_codegen": 0, + } + ) + + @largeTensorTest("1GB", device=GPU_TYPE, inductor=True) + def test_sum_sum_sum_3213336f4c0b_rejects_workspace_overhead(self): + args = ( + torch.randn(8192, 768, device=GPU_TYPE), + torch.randn(8192, 768, device=GPU_TYPE), + torch.randn(8192, 768, device=GPU_TYPE), + torch.randn(8, 1024, 768, device=GPU_TYPE), + torch.randn(768, device=GPU_TYPE), + torch.randn(8, 1024, 768, device=GPU_TYPE), + torch.randn(8, 1024, 1, device=GPU_TYPE), + torch.randint(0, 2, (8, 1024, 768), device=GPU_TYPE, dtype=torch.bool), + [1024, 8, 768], + [1024, 8, 768], + [1024, 8, 768], + [8192, 768], + [768], + ) + + mod = SumSumSum3213336F4C0B() + expected = mod(*args) + actual = torch.compile(mod)(*args) + + self.assert_outputs_close(actual, expected) + self.assert_inductor_counters( + { + "producer_sum_reduction_accumulation_candidates": 0, + "producer_sum_reduction_accumulation_selected": 0, + "producer_sum_reduction_accumulation_codegen": 0, + } + ) + + @largeTensorTest("8GB", device=GPU_TYPE, inductor=True) + def check_swin_extra_saved_partials(self, mod): + args = ( + torch.randn(401408, 128, device=GPU_TYPE), + torch.arange(56, device=GPU_TYPE), + torch.randn(128, device=GPU_TYPE), + torch.randn(128, 56, 56, 128, device=GPU_TYPE), + torch.randn(128, 56, 56, 1, device=GPU_TYPE), + torch.randn(128, 56, 56, 128, device=GPU_TYPE), + [8192, 49, 128], + [8192, 7, 7, 128], + [128, 8, 8, 7, 7, 128], + [128, 56, 56, 128], + [128, 3136, 128], + [401408, 128], + [128], + ) + + expected = mod(*args) + actual = torch.compile(mod)(*args) + + self.assert_outputs_close(actual, expected) + self.assertEqual(metrics.generated_kernel_count, 1) + self.assert_inductor_counters( + { + "producer_sum_reduction_accumulation_candidates": 1, + "producer_sum_reduction_accumulation_selected": 1, + "producer_sum_reduction_accumulation_extra_reductions": 1, + "producer_sum_reduction_accumulation_codegen": 1, + "producer_sum_reduction_accumulation_reject_domain_mismatch": 0, + "producer_sum_reduction_accumulation_reject_subnode_domain_mismatch": 0, + "producer_sum_reduction_accumulation_reject_not_mix_order": 0, + } + ) + + @largeTensorTest("8GB", device=GPU_TYPE, inductor=True) + def test_sum_sum_sum_6107a2f54029_selects_extra_saved_partials(self): + self.check_swin_extra_saved_partials(SumSumSum6107A2F54029()) + + @largeTensorTest("8GB", device=GPU_TYPE, inductor=True) + def test_sum_sum_sum_dc96c4651516_selects_extra_saved_partials(self): + self.check_swin_extra_saved_partials(SumSumSumDC96C4651516()) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index 07b2c73b1b8bf..e98f6bc0959c6 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -148,6 +148,9 @@ def codegen_template( def codegen_mix_order_reduction(self, node): return self._triton_scheduling.codegen_mix_order_reduction(node) + def codegen_producer_consumer_partial_reduction(self, node): + return self._triton_scheduling.codegen_producer_consumer_partial_reduction(node) + def codegen_nested_reduction(self, node): return self._triton_scheduling.codegen_nested_reduction(node) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index ce4ae47d41e05..0437b8f1cbbfd 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -2328,28 +2328,28 @@ def benchmark_codegened_module( ) -> tuple[float, str]: raise NotImplementedError + def _pick_mix_order_reduction_split_size( + self, node: BaseSchedulerNode, numel: sympy.Expr + ) -> int: + # The override has highest priority. + if config.triton.mix_order_reduction_split_size is not None: + return config.triton.mix_order_reduction_split_size + + # Heuristic based on number of SMs. + device_prop = DeviceProperties.create(node.get_device()) + num_sm = device_prop.multi_processor_count + estimated_num_splits = num_sm * 8 + + # split_size is decided based on hint. + # optimization_hint is fine here: the result is clamped to [16, 128], + # so any fallback value still produces a valid split size. + numel_hint = V.graph.sizevars.optimization_hint(numel) + split_size = max(last_power_of_2(numel_hint // estimated_num_splits), 16) + return min(split_size, 128) + def _codegen_mix_order_reduction(self, node1, node2): numel, rnumel = scheduler.MixOrderReduction.get_numel_rnumel(node1) - - def _pick_split_size(): - # the overridden has highest priority - if config.triton.mix_order_reduction_split_size is not None: - return config.triton.mix_order_reduction_split_size - - # heuristics based on number of SMs - device_prop = DeviceProperties.create(node1.get_device()) - num_sm = device_prop.multi_processor_count - estimated_num_splits = num_sm * 8 - - # split_size is decided based on hint. - # optimization_hint is fine here: the result is clamped to [16, 128], - # so any fallback value still produces a valid split size. - numel_hint = V.graph.sizevars.optimization_hint(numel) - split_size = max(last_power_of_2(numel_hint // estimated_num_splits), 16) - split_size = min(split_size, 128) - return split_size - - split_size = _pick_split_size() + split_size = self._pick_mix_order_reduction_split_size(node1, numel) # pyrefly: ignore [bad-assignment] metrics.codegen_mix_order_reduction += 1 @@ -2494,6 +2494,163 @@ def _bench(candidate_split_size): self.free_buffers_in_scheduler() + def codegen_producer_consumer_partial_reduction( + self, node: scheduler.FusedProducerConsumerPartialReduction + ) -> None: + """ + Generate the #21 producer/consumer saved-partial prototype. + + This reuses mix-order's saved partial accumulation machinery for an + existing mixed-order producer, but adds downstream producer-body sum + consumers as extra saved partials instead of rereading the full output. + """ + producer, consumer = node.producer, node.consumer + assert isinstance(producer, scheduler.FusedMixOrderReductions) + candidate = node.candidate + numel = sympy.Integer(candidate.elements_per_reduction_output) + rnumel = sympy.Integer(candidate.consumer_output_numel) + split_size = self._pick_mix_order_reduction_split_size(producer.node1, numel) + + counters["inductor"]["producer_sum_reduction_accumulation_codegen"] += 1 + metrics.codegen_mix_order_reduction += 1 + + def extract_saved_partial_pointwise_nodes(reduction_owner): + reductions, epilogue = self._split_mix_order_reduction_epilogue( + reduction_owner + ) + pointwise_nodes = [] + for subnode in reductions: + subnode.cancel_reduction_split() + converted_node = subnode.extract_pw_from_reduction() + converted_node.swap_pw_red_dimension() + pointwise_nodes.append(converted_node) + return reductions, pointwise_nodes, epilogue + + producer_nodes = producer.node1.get_nodes() + saved_partial_pointwise_nodes = [] + original_reduction_groups = [] + delayed_epilogue_nodes = [] + + reductions, pointwise_nodes, epilogue = extract_saved_partial_pointwise_nodes( + producer.node2 + ) + original_reduction_groups.append(reductions) + saved_partial_pointwise_nodes.extend(pointwise_nodes) + delayed_epilogue_nodes.extend(epilogue) + + for extra_consumer in node.extra_consumers: + reductions, pointwise_nodes, epilogue = ( + extract_saved_partial_pointwise_nodes(extra_consumer) + ) + original_reduction_groups.append(reductions) + saved_partial_pointwise_nodes.extend(pointwise_nodes) + delayed_epilogue_nodes.extend(epilogue) + + consumer_reductions, pointwise_nodes, consumer_epilogue = ( + extract_saved_partial_pointwise_nodes(consumer) + ) + original_reduction_groups.append(consumer_reductions) + saved_partial_pointwise_nodes.extend(pointwise_nodes) + delayed_epilogue_nodes.extend(consumer_epilogue) + + node_schedule = self.generate_node_schedule( + producer_nodes + saved_partial_pointwise_nodes, numel, rnumel + ) + kernel_features = SIMDKernelFeatures(node_schedule, numel, rnumel) + + kernel, ws_name, src_code = self._generate_kernel_code_for_mix_order_reduction( + kernel_features, + split_size=split_size, + for_benchmark=False, + ) + + final_buffer_renames = {} + for reductions in original_reduction_groups: + if not reductions or not bool(reductions[0].node._split_size): + continue + # Split reductions first write an internal partial buffer. Reuse + # mix-order's convention of naming the final workspace reduction as + # the downstream user buffer and marking the intermediate removed. + for subnode in reductions: + bufname = subnode.get_outputs()[0].node.get_name() + username = ( + subnode.get_outputs()[0] + .users[0] + .node.get_outputs()[0] + .node.get_name() + ) + final_buffer_renames[bufname] = username + assert self.scheduler + self.scheduler.removed_ops.add( + subnode.get_outputs()[0].users[0].node.get_name() + ) + V.graph.removed_buffers.add(bufname) + + if final_buffer_renames: + for partial_accum in kernel.saved_partial_accumulate: + partial_accum.buffer_name = final_buffer_renames.get( + partial_accum.buffer_name, partial_accum.buffer_name + ) + + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + + with V.set_kernel_handler(kernel): + for subnode in kernel_features.scheduler_nodes(): + if subnode.get_outputs()[0].node.get_name() not in final_buffer_renames: + subnode.mark_run() + + V.graph.wrapper_code.make_comment( + "# Call producer/consumer partial reduction kernel" + ) + self.codegen_comment(node_schedule, None) + kernel.call_kernel(kernel.kernel_name, deallocate_ws=False) + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + + assert len(saved_partial_pointwise_nodes) == len( + kernel.saved_partial_accumulate + ) + nsplit = V.graph.wrapper_code.codegen_python_sizevar( + (numel + split_size - 1) // split_size + ) + for idx, partial_accum in enumerate(kernel.saved_partial_accumulate): + buffer_name = partial_accum.buffer_name + stride_str = f"({nsplit}) * ({rnumel})" + start = f"{idx} * {stride_str}" + end = f"({idx} + 1) * {stride_str}" + reduction_type2op = { + "min": "amin", + "max": "amax", + } + opname = reduction_type2op.get( + partial_accum.reduction_type, partial_accum.reduction_type + ) + final_reduce = ( + f"{buffer_name} = {ws_name}[{start} : {end}]" + f".view({nsplit}, {rnumel}).{opname}(dim=0)" + ) + buffer = V.graph.get_buffer(buffer_name) + if buffer is not None: + final_shape = [ + V.graph.wrapper_code.codegen_python_sizevar(s) + for s in buffer.get_layout().size + ] + final_shape_str = f"[{', '.join(final_shape)}]" + final_reduce += f".view({final_shape_str})" + if (buffer_dtype := V.graph.get_dtype(buffer_name)) != torch.float: + final_reduce += f".to({buffer_dtype})" + V.graph.wrapper_code.writeline(final_reduce) + V.graph.wrapper_code.allocated.add(buffer_name) + + kernel.deallocate_workspaces() + + if delayed_epilogue_nodes: + self._codegen_nodes(delayed_epilogue_nodes) + + self.free_buffers_in_scheduler() + def codegen_nested_reduction(self, node): """ Generate a single kernel with an outer reduction, a group diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c4942ef74c061..4f4385d54a975 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -917,6 +917,29 @@ def use_autoheuristic(name: str) -> bool: debug_fusion: bool = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" benchmark_fusion: bool = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "") + +# Default-off issue #21 prototype. This recognizes large mixed-order producer +# kernels whose fresh full output is returned and also summed downstream. The +# lowering saves producer-body sum partials into a compact workspace, followed +# by a final reduction of that workspace. +producer_sum_reduction_accumulation: bool = ( + os.environ.get("TORCHINDUCTOR_PRODUCER_SUM_REDUCTION_ACCUMULATION", "0") == "1" +) +producer_sum_reduction_accumulation_min_bytes: int = int( + os.environ.get( + "TORCHINDUCTOR_PRODUCER_SUM_REDUCTION_ACCUMULATION_MIN_BYTES", + str(64 * 1024 * 1024), + ) +) +producer_sum_reduction_accumulation_mblock: int = int( + os.environ.get("TORCHINDUCTOR_PRODUCER_SUM_REDUCTION_ACCUMULATION_MBLOCK", "16") +) +producer_sum_reduction_accumulation_min_elements_per_output: int = int( + os.environ.get( + "TORCHINDUCTOR_PRODUCER_SUM_REDUCTION_ACCUMULATION_MIN_ELEMENTS_PER_OUTPUT", + "1024", + ) +) loop_ordering_after_fusion: bool = ( os.environ.get( "TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0" if is_fbcode() else "1" diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index a295bd96a1367..10ce2ad0089f5 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -154,6 +154,34 @@ def get_fusion_nodes(self) -> tuple[BaseSchedulerNode, BaseSchedulerNode]: return (self.node1, self.node2) +@dataclasses.dataclass(frozen=True) +class ProducerConsumerSumReductionCandidate: + """ + Scheduler-level facts for the #21 saved-partial prototype. + + The producer still has to materialize its full output, but exactly one + downstream sum reduction also reads that output. These fields describe the + output/consumer pair and the profitability estimate used before creating + the fused scheduler carrier. + """ + + producer_name: str + output_name: str + consumer_name: str + alias_names: tuple[str, ...] + user_names: tuple[str, ...] + producer_reduction_ancestor_names: tuple[str, ...] + output_numel: int + consumer_output_numel: int + elements_per_reduction_output: int + full_output_bytes: int + reduction_read_numel: int + reduction_read_bytes: int + partial_workspace_mblock: int + estimated_partial_workspace_bytes: int + atomic_contention_estimate: int + + class _LocalEntry(NamedTuple): """One row of the post-rewrite slice the gate builds. @@ -3041,6 +3069,62 @@ def fuse_with(self, other: BaseSchedulerNode) -> FusedNestedReductions: return FusedNestedReductions(self.node1, new_node2) +class FusedProducerConsumerPartialReduction(FusedSchedulerNode): + """ + Issue #21 carrier for a mixed-order producer whose fresh full output is also + consumed by a downstream sum reduction. + + This wraps an existing FusedMixOrderReductions producer and adds compatible + producer-body sum consumers as extra saved partials. The lowering is: + + 1. Codegen the producer/full-output kernel as today. + 2. While each producer tile computes the full output value, also write a + compact partial row: + workspace[row_block, consumer_output_index] = + sum_{producer rows in row_block}(producer_value) + 3. In the wrapper, reduce the compact workspace: + final = workspace.view(workspace_rows, consumer_output_numel) + .sum(dim=0) + .view(final_shape) + .to(final_dtype) + + That reuses the mix-order saved-partial lifetime shape + (KernelArgs.workspace + wrapper final reduction), while keeping the + producer/consumer removal explicit at the scheduler level. + """ + + def __init__( + self, + producer: BaseSchedulerNode, + consumer: BaseSchedulerNode, + candidate: ProducerConsumerSumReductionCandidate, + extra_consumers: tuple[BaseSchedulerNode, ...] = (), + ) -> None: + self.producer = producer + self.consumer = consumer + self.extra_consumers = extra_consumers + self.candidate = candidate + super().__init__( + producer.scheduler, + list(producer.get_nodes()) + + list(consumer.get_nodes()) + + list( + itertools.chain.from_iterable( + extra_consumer.get_nodes() for extra_consumer in extra_consumers + ) + ), + ) + # Hide the internal producer -> consumer edge from later scheduling + # checks if this carrier is enabled. + self.ancestors -= self.get_operation_names() + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + raise NotImplementedError( + "FusedProducerConsumerPartialReduction should be lowered by " + "BaseScheduling.codegen_producer_consumer_partial_reduction()." + ) + + class FusedExternTritonKernelSchedulerNode(FusedSchedulerNode): def __init__( self, @@ -3357,6 +3441,14 @@ def combinable_nodes( "ComboKernels: %d FusedNestedReductions nodes are filtered", len(nested_reductions), ) + producer_consumer_partials = [ + x for x in nodes if isinstance(x, FusedProducerConsumerPartialReduction) + ] + if producer_consumer_partials: + log.debug( + "ComboKernels: %d FusedProducerConsumerPartialReduction nodes are filtered", + len(producer_consumer_partials), + ) filtered_nodes = [ x @@ -3369,6 +3461,7 @@ def combinable_nodes( GroupedSchedulerNode, FusedMixOrderReductions, FusedNestedReductions, + FusedProducerConsumerPartialReduction, ), ) ] @@ -4059,6 +4152,13 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.name_to_fused_node = {n.get_name(): n for n in self.nodes} self.compute_ancestors() self.compute_input_distances() + self.producer_sum_reduction_accumulation_candidates: list[ + ProducerConsumerSumReductionCandidate + ] = [] + if config.producer_sum_reduction_accumulation: + self.producer_sum_reduction_accumulation_candidates = ( + self.find_producer_sum_reduction_accumulation_candidates() + ) # pyrefly: ignore [bad-assignment] metrics.ir_nodes_pre_fusion += len(self.nodes) @@ -4088,6 +4188,8 @@ def _init(self, nodes: list[ir.Operation]) -> None: self._populate_stream_assignments() self.nodes = self.fuse_nodes(self.nodes) + if self.producer_sum_reduction_accumulation_candidates: + self.nodes = self._fuse_producer_sum_reduction_accumulations(self.nodes) if config._post_fusion_custom_pass is not None: self.nodes = config._post_fusion_custom_pass(self.nodes) @@ -4647,6 +4749,497 @@ def add_user( compute_dependencies_log.debug("BUFFER USER LIST\n") compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str) + def _producer_sum_accumulation_alias_names(self, buf_name: str) -> OrderedSet[str]: + alias_names = OrderedSet([buf_name]) + changed = True + while changed: + changed = False + for candidate_name, candidate_buf in self.name_to_buf.items(): + if candidate_name in alias_names: + continue + if any(alias in alias_names for alias in candidate_buf.get_aliases()): + alias_names.add(candidate_name) + changed = True + return alias_names + + def _buffer_numel_hint(self, buf: SchedulerBuffer) -> int: + try: + return V.graph.sizevars.optimization_hint(buf.node.get_numel(), fallback=0) + except Exception: + return 0 + + def _buffer_bytes_hint(self, buf: SchedulerBuffer) -> int: + try: + return self._buffer_numel_hint(buf) * get_dtype_size(buf.node.get_dtype()) + except Exception: + return 0 + + def _node_output_numel_hint(self, node: BaseSchedulerNode) -> int: + return sum(self._buffer_numel_hint(buf) for buf in node.get_outputs()) + + def _single_non_output_user(self, buf: SchedulerBuffer) -> BaseSchedulerNode | None: + users = [ + user.node + for user in buf.users + if not user.is_weak and not isinstance(user.node, OutputNode) + ] + if len(users) != 1 or not isinstance(users[0], BaseSchedulerNode): + return None + return users[0] + + def _terminal_sum_reduction_consumer( + self, consumer: BaseSchedulerNode + ) -> BaseSchedulerNode: + """ + Follow a single-use sum-reduction chain to the true final output. + + Large #21 repros often split the final reduction into two scheduler + nodes. The immediate consumer's output can be a large partial buffer; + profitability needs the terminal reduction output size instead. + """ + terminal = consumer + seen: OrderedSet[BaseSchedulerNode] = OrderedSet() + while terminal not in seen: + seen.add(terminal) + outputs = terminal.get_outputs() + if len(outputs) != 1: + break + next_node = self._single_non_output_user(outputs[0]) + if next_node is None or not next_node.is_reduction(): + break + reduction_types = self._reduction_types(next_node) + if not reduction_types or any(rt != "sum" for rt in reduction_types): + break + terminal = next_node + return terminal + + @staticmethod + def _reduction_types(node: BaseSchedulerNode) -> list[str]: + reduction_types: list[str] = [] + for subnode in node.get_nodes(): + if not ( + isinstance(subnode, SchedulerNode) + and isinstance(subnode.node, ir.ComputedBuffer) + ): + continue + reduction_type = subnode.node.get_reduction_type() + if reduction_type is not None: + reduction_types.append(reduction_type) + return reduction_types + + @staticmethod + def _read_names(reads: OrderedSet[Dep]) -> OrderedSet[str]: + return OrderedSet(dep.name for dep in reads if not isinstance(dep, WeakDep)) + + @staticmethod + def _node_outputs_all_dtype(node: BaseSchedulerNode, dtype: torch.dtype) -> bool: + return all(buf.node.get_dtype() == dtype for buf in node.get_outputs()) + + def _consumer_reads_exactly_alias( + self, + consumer: BaseSchedulerNode, + alias_names: OrderedSet[str], + ) -> bool: + read_names = self._read_names(consumer.read_writes.reads) + return bool(read_names & alias_names) and read_names.issubset(alias_names) + + def _producer_sum_accumulation_candidate_for_output( + self, + producer: BaseSchedulerNode, + output: SchedulerBuffer, + ) -> ProducerConsumerSumReductionCandidate | None: + if output.get_aliases() or output.get_mutations(): + return None + + output_name = output.get_name() + alias_names = self._producer_sum_accumulation_alias_names(output_name) + alias_defining_nodes = OrderedSet[BaseSchedulerNode]() + for alias_name in alias_names: + if alias_name == output_name: + continue + alias_buf = self.name_to_buf.get(alias_name) + if alias_buf is not None and alias_buf.defining_op is not None: + alias_defining_nodes.add(alias_buf.defining_op) + + has_output_user = False + consumer_nodes = OrderedSet[BaseSchedulerNode]() + user_names: list[str] = [] + for alias_name in alias_names: + alias_buf = self.name_to_buf.get(alias_name) + if alias_buf is None: + continue + for user in alias_buf.users: + if user.is_weak: + continue + user_names.append(user.get_name()) + if isinstance(user.node, OutputNode): + has_output_user = True + elif isinstance(user.node, BaseSchedulerNode): + if user.node in alias_defining_nodes: + continue + consumer_nodes.add(user.node) + + # The intended path only pays off when the full producer output must + # still exist independently of the sum reduction. + if not has_output_user or len(consumer_nodes) != 1: + return None + + consumer = next(iter(consumer_nodes)) + if not consumer.is_gpu() or not consumer.is_reduction(): + return None + + reduction_types = self._reduction_types(consumer) + if not reduction_types or any(rt != "sum" for rt in reduction_types): + return None + + # The saved-partial workspace path currently allocates torch.float + # workspace rows. Keep the prototype on fp32 reductions until workspace + # dtype is carried through PartialAccumulate/codegen. + if output.node.get_dtype() != torch.float: + return None + + if not self._consumer_reads_exactly_alias(consumer, alias_names): + return None + + output_numel = self._buffer_numel_hint(output) + terminal_consumer = self._terminal_sum_reduction_consumer(consumer) + if not self._node_outputs_all_dtype(terminal_consumer, torch.float): + return None + consumer_output_numel = self._node_output_numel_hint(terminal_consumer) + full_output_bytes = self._buffer_bytes_hint(output) + if ( + output_numel <= 0 + or consumer_output_numel <= 0 + or output_numel <= consumer_output_numel + or full_output_bytes < config.producer_sum_reduction_accumulation_min_bytes + ): + return None + + reduction_read_bytes = sum( + self.dep_size_hint(dep) + for dep in consumer.read_writes.reads + if dep.name in alias_names and not isinstance(dep, WeakDep) + ) + if reduction_read_bytes * 4 < full_output_bytes * 3: + return None + + elements_per_reduction_output = math.ceil(output_numel / consumer_output_numel) + if ( + elements_per_reduction_output + < config.producer_sum_reduction_accumulation_min_elements_per_output + ): + return None + + mblock = max(1, config.producer_sum_reduction_accumulation_mblock) + partial_workspace_rows = math.ceil(elements_per_reduction_output / mblock) + estimated_partial_workspace_bytes = ( + partial_workspace_rows + * consumer_output_numel + * get_dtype_size(output.node.get_dtype()) + ) + if estimated_partial_workspace_bytes * 4 >= reduction_read_bytes * 3: + return None + + reduction_read_numel = sum( + self.dep_size_hint(dep, count_bytes=False) + for dep in consumer.read_writes.reads + if dep.name in alias_names and not isinstance(dep, WeakDep) + ) + producer_reduction_ancestor_names = tuple( + sorted( + name + for name in producer.ancestors + if ( + self.name_to_node.get(name) is not None + and self.name_to_node[name].is_reduction() + ) + ) + ) + if not producer_reduction_ancestor_names: + return None + + return ProducerConsumerSumReductionCandidate( + producer_name=producer.get_name(), + output_name=output_name, + consumer_name=consumer.get_name(), + alias_names=tuple(alias_names), + user_names=tuple(sorted(set(user_names))), + producer_reduction_ancestor_names=producer_reduction_ancestor_names, + output_numel=output_numel, + consumer_output_numel=consumer_output_numel, + elements_per_reduction_output=elements_per_reduction_output, + full_output_bytes=full_output_bytes, + reduction_read_numel=reduction_read_numel, + reduction_read_bytes=reduction_read_bytes, + partial_workspace_mblock=mblock, + estimated_partial_workspace_bytes=estimated_partial_workspace_bytes, + atomic_contention_estimate=elements_per_reduction_output, + ) + + def find_producer_sum_reduction_accumulation_candidates( + self, + ) -> list[ProducerConsumerSumReductionCandidate]: + if V.graph.cpp_wrapper or config.deterministic: + return [] + + candidates: list[ProducerConsumerSumReductionCandidate] = [] + for producer in self.nodes: + device = producer.get_device() + if ( + device is None + or device.type != "cuda" + or not producer.is_gpu() + or producer.has_side_effects() + or producer.has_aliasing_or_mutation() + ): + continue + for output in producer.get_outputs(): + candidate = self._producer_sum_accumulation_candidate_for_output( + producer, output + ) + if candidate is None: + continue + candidates.append(candidate) + + counters["inductor"]["producer_sum_reduction_accumulation_candidates"] += len( + candidates + ) + for candidate in candidates: + log.info( + "producer_sum_reduction_accumulation candidate: " + "producer=%s output=%s consumer=%s output_bytes=%d " + "read_numel=%d read_bytes=%d reduction_outputs=%d " + "elems_per_output=%d " + "partial_workspace_mblock=%d partial_workspace_bytes=%d " + "atomic_contention_estimate=%d aliases=%s users=%s " + "producer_reduction_ancestors=%s", + candidate.producer_name, + candidate.output_name, + candidate.consumer_name, + candidate.full_output_bytes, + candidate.reduction_read_numel, + candidate.reduction_read_bytes, + candidate.consumer_output_numel, + candidate.elements_per_reduction_output, + candidate.partial_workspace_mblock, + candidate.estimated_partial_workspace_bytes, + candidate.atomic_contention_estimate, + candidate.alias_names, + candidate.user_names, + candidate.producer_reduction_ancestor_names, + ) + return candidates + + def _fuse_producer_sum_reduction_accumulations( + self, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: + if not self.producer_sum_reduction_accumulation_candidates: + return nodes + + def domain_matches_saved_partial_shape( + group: tuple[sympy.Expr, sympy.Expr], + expected_numel: int, + expected_rnumel: int, + ) -> bool: + return V.graph.sizevars.statically_known_equals( + group[0], expected_numel + ) and V.graph.sizevars.statically_known_equals(group[1], expected_rnumel) + + def flat_pointwise_domain_matches_saved_partial_shape( + group: tuple[sympy.Expr, sympy.Expr], + expected_numel: int, + expected_rnumel: int, + ) -> bool: + return V.graph.sizevars.statically_known_equals( + group[0], expected_numel * expected_rnumel + ) and V.graph.sizevars.statically_known_equals(group[1], 1) + + def node_logical_reduction_group( + node: BaseSchedulerNode, + ) -> tuple[sympy.Expr, sympy.Expr]: + if node.is_reduction(): + return MixOrderReduction.get_numel_rnumel(node) + return typing.cast(tuple[sympy.Expr, sympy.Expr], node.group[1]) + + def subnode_domain_matches_saved_partial_shape( + subnode: BaseSchedulerNode, + expected_numel: int, + expected_rnumel: int, + *, + allow_transposed_reduction: bool, + ) -> bool: + group = node_logical_reduction_group(subnode) + if domain_matches_saved_partial_shape( + group, expected_numel, expected_rnumel + ) or ( + not subnode.is_reduction() + and flat_pointwise_domain_matches_saved_partial_shape( + group, expected_numel, expected_rnumel + ) + ): + return True + if not allow_transposed_reduction or not subnode.is_reduction(): + return False + reduction_types = self._reduction_types(subnode) + return ( + bool(reduction_types) + and all(rt == "sum" for rt in reduction_types) + and domain_matches_saved_partial_shape( + group, expected_rnumel, expected_numel + ) + ) + + def compatible_extra_mix_order_sum_consumers( + producer: BaseSchedulerNode, + consumer: BaseSchedulerNode, + candidate: ProducerConsumerSumReductionCandidate, + removed: OrderedSet[BaseSchedulerNode], + ) -> tuple[BaseSchedulerNode, ...]: + if not isinstance(producer, FusedMixOrderReductions): + return () + + expected_numel = candidate.elements_per_reduction_output + expected_rnumel = candidate.consumer_output_numel + producer_ops = producer.get_operation_names() + producer_read_names = self._read_names(producer.read_writes.reads) + alias_names = OrderedSet(candidate.alias_names) + extras: list[BaseSchedulerNode] = [] + for node in nodes: + if node in removed or node is producer or node is consumer: + continue + if ( + not node.is_gpu() + or not node.is_reduction() + or node.has_side_effects() + or node.has_aliasing_or_mutation() + ): + continue + node_ops = node.get_operation_names() + if ( + node_ops & producer_ops + or node.ancestors & producer_ops + or producer.ancestors & node_ops + ): + continue + reduction_types = self._reduction_types(node) + if not reduction_types or any(rt != "sum" for rt in reduction_types): + continue + terminal = self._terminal_sum_reduction_consumer(node) + if not self._node_outputs_all_dtype(terminal, torch.float): + continue + if self._node_output_numel_hint(terminal) != expected_rnumel: + continue + if not subnode_domain_matches_saved_partial_shape( + node, + expected_numel, + expected_rnumel, + allow_transposed_reduction=True, + ): + continue + read_names = self._read_names(node.read_writes.reads) + if read_names & alias_names: + continue + if not read_names.issubset(producer_read_names): + continue + extras.append(node) + + return tuple(extras) + + name_to_owner: dict[str, BaseSchedulerNode] = {} + for node in nodes: + for name in node.get_operation_names(): + name_to_owner[name] = node + + replacements: dict[ + BaseSchedulerNode, FusedProducerConsumerPartialReduction + ] = {} + removed: OrderedSet[BaseSchedulerNode] = OrderedSet() + for candidate in self.producer_sum_reduction_accumulation_candidates: + producer = name_to_owner.get(candidate.producer_name) + consumer = name_to_owner.get(candidate.consumer_name) + if ( + producer is None + or consumer is None + or producer is consumer + or producer in removed + or consumer in removed + ): + continue + if not isinstance(producer, FusedMixOrderReductions): + counters["inductor"][ + "producer_sum_reduction_accumulation_reject_not_mix_order" + ] += 1 + continue + producer_group = typing.cast( + tuple[sympy.Expr, sympy.Expr], producer.group[1] + ) + # Only create the carrier if the fused producer already has the + # same logical [output-index, reduction-index] domain that the + # saved-partial workspace expects. + if not domain_matches_saved_partial_shape( + producer_group, + candidate.elements_per_reduction_output, + candidate.consumer_output_numel, + ): + counters["inductor"][ + "producer_sum_reduction_accumulation_reject_domain_mismatch" + ] += 1 + continue + expected_numel = candidate.elements_per_reduction_output + expected_rnumel = candidate.consumer_output_numel + mix_order_node2_ops = producer.node2.get_operation_names() + # Every subnode in the producer carrier must use the same workspace + # domain. For a FusedMixOrderReductions producer, only node2's + # existing mixed-order reduction may be transposed; extra consumers + # are added separately below when their sum domain is available in + # the producer body. + if any( + not subnode_domain_matches_saved_partial_shape( + subnode, + expected_numel, + expected_rnumel, + allow_transposed_reduction=bool( + subnode.get_operation_names() & mix_order_node2_ops + ), + ) + for subnode in producer.get_nodes() + ): + counters["inductor"][ + "producer_sum_reduction_accumulation_reject_subnode_domain_mismatch" + ] += 1 + continue + extra_consumers = compatible_extra_mix_order_sum_consumers( + producer, consumer, candidate, removed + ) + fused = FusedProducerConsumerPartialReduction( + producer, consumer, candidate, extra_consumers + ) + replacements[producer] = fused + removed.add(consumer) + for extra_consumer in extra_consumers: + removed.add(extra_consumer) + counters["inductor"]["producer_sum_reduction_accumulation_selected"] += 1 + counters["inductor"][ + "producer_sum_reduction_accumulation_extra_reductions" + ] += len(extra_consumers) + + if not replacements: + return nodes + + new_nodes: list[BaseSchedulerNode] = [] + for node in nodes: + if node in removed: + continue + new_nodes.append(replacements.get(node, node)) + + self.name_to_node = {n.get_name(): n for n in new_nodes} + self.name_to_fused_node = {n.get_name(): n for n in new_nodes} + for node in new_nodes: + for name in node.get_operation_names(): + self.name_to_fused_node[name] = node + self.name_to_node[name] = node + return self.topological_sort_schedule(new_nodes) + def insert_memory_check_nodes(self) -> None: from .memory import ( assign_memory_planning_info_for_scheduler_buffers, @@ -9364,6 +9957,11 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: elif isinstance(node, FusedMixOrderReductions): # pyrefly: ignore [unbound-name] self.get_backend(device).codegen_mix_order_reduction(node) + elif isinstance(node, FusedProducerConsumerPartialReduction): + # pyrefly: ignore [unbound-name] + self.get_backend(device).codegen_producer_consumer_partial_reduction( + node + ) elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): # pyrefly: ignore [unbound-name] self.get_backend(device).codegen_node(node) @@ -9703,6 +10301,11 @@ def codegen_node(self, node: FusedSchedulerNode | SchedulerNode) -> None: def codegen_mix_order_reduction(self, node: FusedMixOrderReductions) -> None: raise NotImplementedError + def codegen_producer_consumer_partial_reduction( + self, node: FusedProducerConsumerPartialReduction + ) -> None: + raise NotImplementedError + def codegen_nested_reduction(self, node: FusedNestedReductions) -> None: raise NotImplementedError