From 89d0d982c6ee1865e2362e18ece7309d881709ad Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Fri, 6 Feb 2026 12:49:31 -0800
Subject: [PATCH 1/3] Add arg for KA test timeout

---
 Fuser/auto_agent.py            | 6 ++++++
 Fuser/dispatch_kernel_agent.py | 9 +++++++++
 Fuser/pipeline.py              | 4 ++++
 triton_kernel_agent/agent.py   | 5 ++++-
 triton_kernel_agent/manager.py | 6 ++++++
 triton_kernel_agent/worker.py  | 7 +++++--
 6 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py
index eacc290..5fa2c7c 100644
--- a/Fuser/auto_agent.py
+++ b/Fuser/auto_agent.py
@@ -347,6 +347,7 @@ def __init__(
         ignore_router_config: bool = False,
         use_router_cache: bool = True,
         no_cusolver: bool = False,
+        test_timeout_s: int = 30,
     ) -> None:
         self.ka_model = ka_model
         self.ka_num_workers = ka_num_workers
@@ -372,6 +373,7 @@ def __init__(
         self.ignore_router_config = ignore_router_config
         self.use_router_cache = use_router_cache
         self.no_cusolver = no_cusolver
+        self.test_timeout_s = test_timeout_s

     def _solve_with_kernelagent(self, problem_code: str) -> RouteResult:
         agent = TritonKernelAgent(
@@ -381,6 +383,7 @@ def _solve_with_kernelagent(self, problem_code: str) -> RouteResult:
             high_reasoning_effort=self.ka_high_reasoning,
             target_platform=self.platform_config,
             no_cusolver=self.no_cusolver,
+            test_timeout_s=self.test_timeout_s,
         )
         try:
             # Ensure exceptions in KernelAgent do not abort routing; return a structured failure
@@ -445,6 +448,7 @@ def _solve_with_fuser(self, problem_path: Path) -> RouteResult:
                 verify=self.verify,
                 compose_max_iters=self.compose_max_iters,
                 target_platform=self.platform_config.name,
+                test_timeout_s=self.test_timeout_s,
             )
         except BaseException as exc:  # catch SystemExit and others
             # Return a structured failure so caller can decide on fallback
@@ -716,6 +720,7 @@ def main(argv: list[str] | None = None) -> int:
     p.add_argument("--ka-workers", type=int, default=4)
     p.add_argument("--ka-rounds", type=int, default=10)
     p.add_argument("--no-ka-high-reasoning", action="store_true")
+    p.add_argument("--test-timeout-s", type=int, default=30)
     p.add_argument("--router-model", default="gpt-5")
     p.add_argument("--no-router-high-reasoning", action="store_true")
     p.add_argument("--router-temp", type=float, default=0.2)
@@ -786,6 +791,7 @@ def main(argv: list[str] | None = None) -> int:
         ignore_router_config=args.ignore_router_config,
         use_router_cache=(not args.no_router_cache),
         no_cusolver=args.no_cusolver,
+        test_timeout_s=args.test_timeout_s,
     )

     try:
diff --git a/Fuser/dispatch_kernel_agent.py b/Fuser/dispatch_kernel_agent.py
index f3bccff..c041b64 100644
--- a/Fuser/dispatch_kernel_agent.py
+++ b/Fuser/dispatch_kernel_agent.py
@@ -334,6 +334,7 @@ def run(
     target_platform: str = "cuda",
     max_iters: int = 10,
     no_cusolver: bool = False,
+    test_timeout_s: int = 30,
 ) -> Path:
     """Dispatch subgraphs to KernelAgent with optional parallelism.

@@ -371,6 +372,7 @@ def _handle_one(idx_item: tuple[int, dict[str, Any]]) -> tuple[int, dict[str, An
             model_name=agent_model,
             target_platform=platform,
             no_cusolver=no_cusolver,
+            test_timeout_s=test_timeout_s,
         )
         try:
             result = local_agent.generate_kernel(
@@ -453,6 +455,12 @@ def main(argv: list[str] | None = None) -> int:
         default="2",
         help="Max concurrent subgraphs to dispatch (default: 2); use 'auto' to match subgraph count",
     )
+    p.add_argument(
+        "--test-timeout-s",
+        type=int,
+        default=30,
+        help="Timeout in seconds for each test (default: 30)",
+    )
     p.add_argument(
         "--target-platform",
         default="cuda",
@@ -493,6 +501,7 @@ def main(argv: list[str] | None = None) -> int:
         jobs=jobs_val,
         target_platform=args.target_platform,
         no_cusolver=args.no_cusolver,
+        test_timeout_s=args.test_timeout_s,
     )
     print(str(summary_path))
     return 0
diff --git a/Fuser/pipeline.py b/Fuser/pipeline.py
index 23d2c73..f444473 100644
--- a/Fuser/pipeline.py
+++ b/Fuser/pipeline.py
@@ -57,6 +57,7 @@ def run_pipeline(
     verify: bool = True,
     compose_max_iters: int = 5,
     target_platform: str = "cuda",
+    test_timeout_s: int = 30,
 ) -> dict:
     # Select default KernelAgent model if not provided: prefer GPT-5 for Level 2/3
     if dispatch_model is None:
@@ -112,6 +113,7 @@ def run_pipeline(
         jobs=jobs_val,
         target_platform=target_platform,
         max_iters=max_iters,
+        test_timeout_s=test_timeout_s,
     )

     # Step 3: compose end-to-end
@@ -173,6 +175,7 @@ def main(argv: list[str] | None = None) -> int:
         choices=get_platform_choices(),
         help="Target platform",
     )
+    p.add_argument("--test-timeout-s", type=int, default=30)
     args = p.parse_args(argv)

     problem_path = Path(args.problem).resolve()
@@ -195,6 +198,7 @@ def main(argv: list[str] | None = None) -> int:
         verify=args.verify,
         compose_max_iters=args.compose_max_iters,
         target_platform=args.target_platform,
+        test_timeout_s=args.test_timeout_s,
     )
     print(json.dumps(res, indent=2))
     return 0
diff --git a/triton_kernel_agent/agent.py b/triton_kernel_agent/agent.py
index 376070a..626f23d 100644
--- a/triton_kernel_agent/agent.py
+++ b/triton_kernel_agent/agent.py
@@ -42,6 +42,7 @@ def __init__(
         preferred_provider: BaseProvider | None = None,
         target_platform: PlatformConfig | None = None,
         no_cusolver: bool = False,
+        test_timeout_s: int = 30,
     ):
         """
         Initialize the Triton Kernel Agent.
@@ -90,6 +91,7 @@ def __init__(
             target_platform if target_platform else get_platform("cuda")
         )
         self.no_cusolver = no_cusolver
+        self.test_timeout_s = test_timeout_s

         # Setup main logger
         self._setup_logging()
@@ -107,6 +109,7 @@ def __init__(
             high_reasoning_effort=self.high_reasoning_effort,
             target_platform=self._platform_config.name,
             no_cusolver=self.no_cusolver,
+            test_timeout_s=self.test_timeout_s,
         )

     def _setup_logging(self):
@@ -352,7 +355,7 @@ def _generate_kernel_seeds(
         messages = [{"role": "user", "content": prompt}]

         # Use provider's multiple response capability
-        max_completion_tokens = 20000
+        max_completion_tokens = 40000

         if self.provider.supports_multiple_completions():
             # Provider supports native multiple completions
diff --git a/triton_kernel_agent/manager.py b/triton_kernel_agent/manager.py
index 2e67622..412eca5 100644
--- a/triton_kernel_agent/manager.py
+++ b/triton_kernel_agent/manager.py
@@ -39,6 +39,7 @@ def __init__(
         high_reasoning_effort: bool = True,
         target_platform: str = "cuda",
         no_cusolver: bool = False,
+        test_timeout_s: int = 30,
     ):
         """
         Initialize the worker manager.
@@ -53,6 +54,7 @@ def __init__(
             high_reasoning_effort: Whether to use high reasoning effort for OpenAI models
             target_platform: Target platform ('cuda' or 'xpu')
             no_cusolver: If True, disables cuSolver library usage
+            test_timeout_s: Timeout in seconds for test execution
         """
         self.num_workers = num_workers
         self.max_rounds = max_rounds
@@ -62,6 +64,7 @@ def __init__(
         self.high_reasoning_effort = high_reasoning_effort
         self.target_platform = target_platform
         self.no_cusolver = no_cusolver
+        self.test_timeout_s = test_timeout_s

         # Setup logging
         if log_dir is None:
@@ -168,6 +171,7 @@ def run_verification(
                 self.high_reasoning_effort,
                 self.target_platform,
                 self.no_cusolver,
+                self.test_timeout_s,
             )

             process = mp.Process(target=worker_process, args=args)
@@ -233,6 +237,7 @@ def worker_process(
     high_reasoning_effort: bool,
     target_platform: str,
     no_cusolver: bool = False,
+    test_timeout_s: int = 30,
 ):
     """
     Worker process for kernel verification and refinement.
@@ -253,6 +258,7 @@ def worker_process(
         high_reasoning_effort=high_reasoning_effort,
         target_platform=target_platform,
         no_cusolver=no_cusolver,
+        test_timeout_s=test_timeout_s,
     )

     result = worker.run(
diff --git a/triton_kernel_agent/worker.py b/triton_kernel_agent/worker.py
index 3a7e711..c864c89 100644
--- a/triton_kernel_agent/worker.py
+++ b/triton_kernel_agent/worker.py
@@ -133,6 +133,7 @@ def __init__(
         high_reasoning_effort: bool = True,
         target_platform: str = "cuda",
         no_cusolver: bool = False,
+        test_timeout_s: int = 30,
     ):
         """
         Initialize a verification worker.
@@ -148,6 +149,7 @@ def __init__(
             high_reasoning_effort: Whether to use high reasoning effort for OpenAI models
             target_platform: Target platform default: cuda
             no_cusolver: If True, disables cuSolver library usage
+            test_timeout_s: Timeout in seconds for test execution
         """
         self.worker_id = worker_id
         self.workdir = Path(workdir)
@@ -158,6 +160,7 @@ def __init__(
         self.high_reasoning_effort = high_reasoning_effort
         self._platform_config = get_platform(target_platform)
         self.no_cusolver = no_cusolver
+        self.test_timeout_s = test_timeout_s

         # Setup files
         self.kernel_file = self.workdir / "kernel.py"
@@ -288,7 +291,7 @@ def _run_test(self) -> tuple[bool, str, str]:
                 cwd=str(self.workdir),
                 capture_output=True,
                 text=True,
-                timeout=30,  # 30 second timeout
+                timeout=self.test_timeout_s,
             )

             success = result.returncode == 0
@@ -305,7 +308,7 @@ def _run_test(self) -> tuple[bool, str, str]:

         except subprocess.TimeoutExpired:
             self.logger.error("Test timed out")
-            return False, "", "Test execution timed out after 30 seconds"
+            return False, "", f"Test execution timed out after {self.test_timeout_s} seconds"
         except Exception as e:
             self.logger.error(f"Test execution error: {e}")
             return False, "", str(e)

From 8e3aa2e8e0860b2baa8df65f9ee3c99f4ffb0e8a Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Fri, 6 Feb 2026 12:53:56 -0800
Subject: [PATCH 2/3] remove stray line

---
 triton_kernel_agent/agent.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/triton_kernel_agent/agent.py b/triton_kernel_agent/agent.py
index 626f23d..c90e103 100644
--- a/triton_kernel_agent/agent.py
+++ b/triton_kernel_agent/agent.py
@@ -355,7 +355,7 @@ def _generate_kernel_seeds(
         messages = [{"role": "user", "content": prompt}]

         # Use provider's multiple response capability
-        max_completion_tokens = 40000
+        max_completion_tokens = 20000

         if self.provider.supports_multiple_completions():
             # Provider supports native multiple completions

From 2d9eeb8a478483db5eb04a55cb19030832f3f28f Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Fri, 6 Feb 2026 13:07:51 -0800
Subject: [PATCH 3/3] lint

---
 triton_kernel_agent/worker.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/triton_kernel_agent/worker.py b/triton_kernel_agent/worker.py
index c864c89..39185b2 100644
--- a/triton_kernel_agent/worker.py
+++ b/triton_kernel_agent/worker.py
@@ -308,7 +308,11 @@ def _run_test(self) -> tuple[bool, str, str]:

         except subprocess.TimeoutExpired:
             self.logger.error("Test timed out")
-            return False, "", f"Test execution timed out after {self.test_timeout_s} seconds"
+            return (
+                False,
+                "",
+                f"Test execution timed out after {self.test_timeout_s} seconds",
+            )
         except Exception as e:
             self.logger.error(f"Test execution error: {e}")
             return False, "", str(e)
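
For context, the pattern this series threads through the stack is sketched below: the per-test timeout is accepted as a constructor or CLI argument and forwarded to subprocess.run instead of being hard-coded to 30 seconds. This is a minimal standalone sketch, not the repository's real worker code; the function name run_test, its test_cmd and workdir parameters, and the demo command are illustrative only, while test_timeout_s mirrors the argument added by the patches.

import subprocess


def run_test(test_cmd: list[str], workdir: str, test_timeout_s: int = 30) -> tuple[bool, str, str]:
    """Run a test command, reporting a clean failure if it exceeds test_timeout_s."""
    try:
        result = subprocess.run(
            test_cmd,
            cwd=workdir,
            capture_output=True,
            text=True,
            timeout=test_timeout_s,  # configurable, replacing the old hard-coded 30
        )
        return result.returncode == 0, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        # Mirror the worker's behavior: no exception escapes, just a failed result
        return (
            False,
            "",
            f"Test execution timed out after {test_timeout_s} seconds",
        )


if __name__ == "__main__":
    # Self-contained demo command; a real caller would pass the kernel test script
    ok, out, err = run_test(["python", "-c", "print('ok')"], workdir=".", test_timeout_s=60)
    print(ok, err or out)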