Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Fuser/auto_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def __init__(
ignore_router_config: bool = False,
use_router_cache: bool = True,
no_cusolver: bool = False,
test_timeout_s: int = 30,
) -> None:
self.ka_model = ka_model
self.ka_num_workers = ka_num_workers
Expand All @@ -372,6 +373,7 @@ def __init__(
self.ignore_router_config = ignore_router_config
self.use_router_cache = use_router_cache
self.no_cusolver = no_cusolver
self.test_timeout_s = test_timeout_s

def _solve_with_kernelagent(self, problem_code: str) -> RouteResult:
agent = TritonKernelAgent(
Expand All @@ -381,6 +383,7 @@ def _solve_with_kernelagent(self, problem_code: str) -> RouteResult:
high_reasoning_effort=self.ka_high_reasoning,
target_platform=self.platform_config,
no_cusolver=self.no_cusolver,
test_timeout_s=self.test_timeout_s,
)
try:
# Ensure exceptions in KernelAgent do not abort routing; return a structured failure
Expand Down Expand Up @@ -445,6 +448,7 @@ def _solve_with_fuser(self, problem_path: Path) -> RouteResult:
verify=self.verify,
compose_max_iters=self.compose_max_iters,
target_platform=self.platform_config.name,
test_timeout_s=self.test_timeout_s,
)
except BaseException as exc: # catch SystemExit and others
# Return a structured failure so caller can decide on fallback
Expand Down Expand Up @@ -716,6 +720,7 @@ def main(argv: list[str] | None = None) -> int:
p.add_argument("--ka-workers", type=int, default=4)
p.add_argument("--ka-rounds", type=int, default=10)
p.add_argument("--no-ka-high-reasoning", action="store_true")
p.add_argument("--test-timeout-s", type=int, default=30)
p.add_argument("--router-model", default="gpt-5")
p.add_argument("--no-router-high-reasoning", action="store_true")
p.add_argument("--router-temp", type=float, default=0.2)
Expand Down Expand Up @@ -786,6 +791,7 @@ def main(argv: list[str] | None = None) -> int:
ignore_router_config=args.ignore_router_config,
use_router_cache=(not args.no_router_cache),
no_cusolver=args.no_cusolver,
test_timeout_s=args.test_timeout_s,
)

try:
Expand Down
9 changes: 9 additions & 0 deletions Fuser/dispatch_kernel_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def run(
target_platform: str = "cuda",
max_iters: int = 10,
no_cusolver: bool = False,
test_timeout_s: int = 30,
) -> Path:
"""Dispatch subgraphs to KernelAgent with optional parallelism.

Expand Down Expand Up @@ -371,6 +372,7 @@ def _handle_one(idx_item: tuple[int, dict[str, Any]]) -> tuple[int, dict[str, An
model_name=agent_model,
target_platform=platform,
no_cusolver=no_cusolver,
test_timeout_s=test_timeout_s,
)
try:
result = local_agent.generate_kernel(
Expand Down Expand Up @@ -453,6 +455,12 @@ def main(argv: list[str] | None = None) -> int:
default="2",
help="Max concurrent subgraphs to dispatch (default: 2); use 'auto' to match subgraph count",
)
p.add_argument(
"test-timeout-s",
type=int,
default=30,
help="Timeout for each test (default: 30s)",
)
p.add_argument(
"--target-platform",
default="cuda",
Expand Down Expand Up @@ -493,6 +501,7 @@ def main(argv: list[str] | None = None) -> int:
jobs=jobs_val,
target_platform=args.target_platform,
no_cusolver=args.no_cusolver,
test_timeout_s=args.test_timeout_s,
)
print(str(summary_path))
return 0
Expand Down
4 changes: 4 additions & 0 deletions Fuser/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def run_pipeline(
verify: bool = True,
compose_max_iters: int = 5,
target_platform: str = "cuda",
test_timeout_s: int = 30,
) -> dict:
# Select default KernelAgent model if not provided: prefer GPT-5 for Level 2/3
if dispatch_model is None:
Expand Down Expand Up @@ -112,6 +113,7 @@ def run_pipeline(
jobs=jobs_val,
target_platform=target_platform,
max_iters=max_iters,
test_timeout_s=test_timeout_s,
)

# Step 3: compose end-to-end
Expand Down Expand Up @@ -173,6 +175,7 @@ def main(argv: list[str] | None = None) -> int:
choices=get_platform_choices(),
help="Target platform",
)
p.add_argument("--test-timeout-s", type=int, default=30)
args = p.parse_args(argv)

problem_path = Path(args.problem).resolve()
Expand All @@ -195,6 +198,7 @@ def main(argv: list[str] | None = None) -> int:
verify=args.verify,
compose_max_iters=args.compose_max_iters,
target_platform=args.target_platform,
test_timeout_s=args.test_timeout_s,
)
print(json.dumps(res, indent=2))
return 0
Expand Down
3 changes: 3 additions & 0 deletions triton_kernel_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
preferred_provider: BaseProvider | None = None,
target_platform: PlatformConfig | None = None,
no_cusolver: bool = False,
test_timeout_s: int = 30,
):
"""
Initialize the Triton Kernel Agent.
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
target_platform if target_platform else get_platform("cuda")
)
self.no_cusolver = no_cusolver
self.test_timeout_s = test_timeout_s

# Setup main logger
self._setup_logging()
Expand All @@ -107,6 +109,7 @@ def __init__(
high_reasoning_effort=self.high_reasoning_effort,
target_platform=self._platform_config.name,
no_cusolver=self.no_cusolver,
test_timeout_s=self.test_timeout_s,
)

def _setup_logging(self):
Expand Down
6 changes: 6 additions & 0 deletions triton_kernel_agent/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
high_reasoning_effort: bool = True,
target_platform: str = "cuda",
no_cusolver: bool = False,
test_timeout_s: int = 30,
):
"""
Initialize the worker manager.
Expand All @@ -53,6 +54,7 @@ def __init__(
high_reasoning_effort: Whether to use high reasoning effort for OpenAI models
target_platform: Target platform ('cuda' or 'xpu')
no_cusolver: If True, disables cuSolver library usage
test_timeout_s: Timeout in seconds for test execution
"""
self.num_workers = num_workers
self.max_rounds = max_rounds
Expand All @@ -62,6 +64,7 @@ def __init__(
self.high_reasoning_effort = high_reasoning_effort
self.target_platform = target_platform
self.no_cusolver = no_cusolver
self.test_timeout_s = test_timeout_s

# Setup logging
if log_dir is None:
Expand Down Expand Up @@ -168,6 +171,7 @@ def run_verification(
self.high_reasoning_effort,
self.target_platform,
self.no_cusolver,
self.test_timeout_s,
)

process = mp.Process(target=worker_process, args=args)
Expand Down Expand Up @@ -233,6 +237,7 @@ def worker_process(
high_reasoning_effort: bool,
target_platform: str,
no_cusolver: bool = False,
test_timeout_s: int = 30,
):
"""
Worker process for kernel verification and refinement.
Expand All @@ -253,6 +258,7 @@ def worker_process(
high_reasoning_effort=high_reasoning_effort,
target_platform=target_platform,
no_cusolver=no_cusolver,
test_timeout_s=test_timeout_s,
)

result = worker.run(
Expand Down
11 changes: 9 additions & 2 deletions triton_kernel_agent/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
high_reasoning_effort: bool = True,
target_platform: str = "cuda",
no_cusolver: bool = False,
test_timeout_s: int = 30,
):
"""
Initialize a verification worker.
Expand All @@ -148,6 +149,7 @@ def __init__(
high_reasoning_effort: Whether to use high reasoning effort for OpenAI models
target_platform: Target platform default: cuda
no_cusolver: If True, disables cuSolver library usage
test_timeout_s: Timeout in seconds for test execution
"""
self.worker_id = worker_id
self.workdir = Path(workdir)
Expand All @@ -158,6 +160,7 @@ def __init__(
self.high_reasoning_effort = high_reasoning_effort
self._platform_config = get_platform(target_platform)
self.no_cusolver = no_cusolver
self.test_timeout_s = test_timeout_s

# Setup files
self.kernel_file = self.workdir / "kernel.py"
Expand Down Expand Up @@ -288,7 +291,7 @@ def _run_test(self) -> tuple[bool, str, str]:
cwd=str(self.workdir),
capture_output=True,
text=True,
timeout=30, # 30 second timeout
timeout=self.test_timeout_s,
)

success = result.returncode == 0
Expand All @@ -305,7 +308,11 @@ def _run_test(self) -> tuple[bool, str, str]:

except subprocess.TimeoutExpired:
self.logger.error("Test timed out")
return False, "", "Test execution timed out after 30 seconds"
return (
False,
"",
f"Test execution timed out after {self.test_timeout_s} seconds",
)
except Exception as e:
self.logger.error(f"Test execution error: {e}")
return False, "", str(e)
Expand Down