diff --git a/triton_kernel_agent/templates/test_generation.j2 b/triton_kernel_agent/templates/test_generation.j2 index f26b5be..61c9b2e 100644 --- a/triton_kernel_agent/templates/test_generation.j2 +++ b/triton_kernel_agent/templates/test_generation.j2 @@ -77,6 +77,7 @@ CRITICAL TEST DATA REQUIREMENTS: - Only create variations if the problem description is vague about the input specifications - If original problem description specifies FP32, convert to BF16 problems instead of FP32 problems - AVOID FP32 for inputs and outputs - use BF16 instead when users specify FP32 problems + - EXCEPTION: If the user casts example inputs explicitly to FP32, test with both FP32 and BF16 - GroupNorm: compute mean/var in fp32 over (C/G)*H*W per (N,group), eps=1e-5; apply affine in fp32, cast at end; no extra bf16 quantization before pool/norm. @@ -127,10 +128,10 @@ def test_kernel(): # Create test data using EXACT specifications from problem description # If problem specifies shape/dtype, use those exact values # input_tensor = torch.randn(EXACT_SHAPE, dtype=EXACT_DTYPE, device=device) - + # Call kernel_function as a normal Python function # result = kernel_function(input_tensor, other_args...) 
- + # Example device check (avoid comparing to literal 'cuda') # if isinstance(result, torch.Tensor) and result.device != input_tensor.device: # return False @@ -154,7 +155,7 @@ def test_kernel(): # print(f"Max absolute difference: {torch.max(torch.abs(result - expected))}") # print(f"Relative error: {torch.max(torch.abs((result - expected) / (expected + 1e-8)))}") # return False - + return True # if successful except Exception as e: # Surface undefined helper issues from kernel.py clearly @@ -170,4 +171,4 @@ if __name__ == "__main__": sys.exit(0 if success else 1) ``` -Generate a complete test implementation: +Generate a complete test implementation: diff --git a/utils/providers/available_models.py b/utils/providers/available_models.py index 22b65ab..6936617 100644 --- a/utils/providers/available_models.py +++ b/utils/providers/available_models.py @@ -64,4 +64,9 @@ provider_classes=[RelayProvider], description="Claude 4.5 Opus (Released Nov 2025)", ), + ModelConfig( + name="gpt-5-2", + provider_classes=[RelayProvider], + description="GPT-5.2 flagship model (Dec 2025) - Note the name is different from the OpenAI model", + ), ]