diff --git a/pufferlib/extensions/cuda/kernels.cu b/pufferlib/extensions/cuda/kernels.cu
index 910a895fe..fef22e1f2 100644
--- a/pufferlib/extensions/cuda/kernels.cu
+++ b/pufferlib/extensions/cuda/kernels.cu
@@ -12,6 +12,10 @@
 #include
 #include
+#define WARP_SIZE 32
+#define PPO_THREADS 256
+#define FULL_MASK 0xffffffff
+
 #define SEQ_SIZE 256
 #define BLOCK_SIZE 256
 inline int grid_size(int N) {
@@ -1138,6 +1142,114 @@ void launch_logcumsumexp_backward(
     fprintf(stderr, "Backward kernel error: %s\n", cudaGetErrorString(err));
 }
+
+template <typename T>
+__global__ void ppo_loss_forward_kernel_optimized(
+    float* __restrict__ loss,
+    double* __restrict__ saved_for_backward,  // not written: the optimized backward recomputes instead
+    const T* __restrict__ logits,
+    const T* __restrict__ values_pred,
+    const int64_t* __restrict__ actions,
+    const T* __restrict__ old_logprobs,
+    const T* __restrict__ advantages,
+    const T* __restrict__ prio,
+    const T* __restrict__ values,
+    const T* __restrict__ returns,
+    const float* __restrict__ adv_mean,
+    const float* __restrict__ adv_std,
+    float clip_coef,
+    float vf_clip_coef,
+    float vf_coef,
+    float ent_coef,
+    int T_seq,
+    int A,
+    int N
+) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total_elements = N * T_seq;
+
+    __shared__ float block_loss[PPO_THREADS];
+
+    // Out-of-range threads must not return early: every thread in the block
+    // has to reach the __syncthreads() below, and every block_loss slot must
+    // be initialized before the reduction. They contribute zero instead.
+    float thread_loss = 0.0f;
+
+    if (idx < total_elements) {
+        int n = idx / T_seq;
+        int t = idx % T_seq;
+        int nt = n * T_seq + t;
+        int logits_offset = n * T_seq * A + t * A;
+        int act = actions[nt];
+
+        // single-pass (online) logsumexp over the action logits
+        float max_logit = -INFINITY;
+        float sum = 0.0f;
+        float act_logit = 0.0f;
+
+        for (int a = 0; a < A; a++) {
+            float l = float(logits[logits_offset + a]);
+
+            // cache the action's logit
+            if (a == act) {
+                act_logit = l;
+            }
+
+            // rescale the running sum whenever a new maximum appears
+            if (l > max_logit) {
+                sum *= __expf(max_logit - l);
+                max_logit = l;
+            }
+            sum += __expf(l - max_logit);
+        }
+
+        float logsumexp = max_logit + __logf(sum);
+
+        float entropy = 0.0f;
+        for (int a = 0; a < A; a++) {
+            float l = float(logits[logits_offset + a]);
+            float logp = l - logsumexp;
+            float p = __expf(logp);
+            entropy -= p * logp;
+        }
+
+        float new_logp = act_logit - logsumexp;
+        float old_logp = float(old_logprobs[nt]);
+        float adv = float(advantages[nt]);
+        float w = float(prio[n]);
+        float adv_normalized = (adv - adv_mean[0]) / (adv_std[0] + 1e-8f);
+
+        float logratio = new_logp - old_logp;
+        float ratio = __expf(logratio);
+
+        float ratio_clipped = fmaxf(1.0f - clip_coef, fminf(1.0f + clip_coef, ratio));
+        float wa = -w * adv_normalized;
+        float pg_loss1 = wa * ratio;
+        float pg_loss2 = wa * ratio_clipped;
+        float pg_loss = fmaxf(pg_loss1, pg_loss2);
+
+        float val = float(values[nt]);
+        float ret = float(returns[nt]);
+        float val_pred = float(values_pred[nt]);
+
+        float v_error = val_pred - val;
+        float v_clipped = val + fmaxf(-vf_clip_coef, fminf(vf_clip_coef, v_error));
+        float v_loss_unclipped = (val_pred - ret) * (val_pred - ret);
+        float v_loss_clipped = (v_clipped - ret) * (v_clipped - ret);
+        float v_loss = 0.5f * fmaxf(v_loss_unclipped, v_loss_clipped);
+
+        thread_loss = (pg_loss + vf_coef * v_loss - ent_coef * entropy) / float(total_elements);
+    }
+
+    int tid = threadIdx.x;
+    block_loss[tid] = thread_loss;
+    __syncthreads();
+
+    for (int stride = PPO_THREADS / 2; stride > 0; stride >>= 1) {
+        if (tid < stride) {
+            block_loss[tid] += block_loss[tid + stride];
+        }
+        __syncthreads();
+    }
+
+    if (tid == 0) {
+        atomicAdd(loss, block_loss[0]);
+    }
+}
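+
+// Reduction layout above: each thread computes one (n, t) element's loss, a
+// shared-memory tree reduction collapses the block, and a single atomicAdd
+// per block accumulates into the scalar output. WARP_SIZE and FULL_MASK are
+// not used yet; they would only matter for a warp-shuffle variant of this
+// reduction (e.g. __shfl_down_sync(FULL_MASK, v, offset)).
+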
 template <typename T>
 __global__ void ppo_loss_forward_kernel(
     float* __restrict__ loss,
@@ -1258,6 +1370,126 @@ __global__ void ppo_loss_forward_kernel(
     }
 }
+template <typename T>
+__global__ void ppo_loss_backward_kernel_optimized(
+    T* __restrict__ grad_logits,
+    T* __restrict__ grad_values_pred,
+    const float* __restrict__ grad_loss,
+    const T* __restrict__ logits,
+    const T* __restrict__ values_pred,
+    const int64_t* __restrict__ actions,
+    const T* __restrict__ old_logprobs,
+    const T* __restrict__ advantages,
+    const T* __restrict__ prio,
+    const T* __restrict__ values,
+    const T* __restrict__ returns,
+    const float* __restrict__ adv_mean,
+    const float* __restrict__ adv_std,
+    float clip_coef,
+    float vf_clip_coef,
+    float vf_coef,
+    float ent_coef,
+    int T_seq,
+    int A,
+    int N
+) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total_elements = N * T_seq;
+    if (idx >= total_elements) return;
+
+    float inv_NT = 1.0f / float(total_elements);
+    int n = idx / T_seq;
+    int t = idx % T_seq;
+    int nt = n * T_seq + t;
+    int logits_offset = n * T_seq * A + t * A;
+    int act = actions[nt];
+
+    float old_logp = float(old_logprobs[nt]);
+    float adv = float(advantages[nt]);
+    float w = float(prio[n]);
+    float val = float(values[nt]);
+    float ret = float(returns[nt]);
+    float val_pred = float(values_pred[nt]);
+
+    float max_logit = -INFINITY;
+    float sum = 0.0f;
+    float act_logit = 0.0f;
+
+    for (int a = 0; a < A; a++) {
+        float l = float(logits[logits_offset + a]);
+        if (a == act) act_logit = l;
+
+        if (l > max_logit) {
+            sum *= __expf(max_logit - l);
+            max_logit = l;
+        }
+        sum += __expf(l - max_logit);
+    }
+    float logsumexp = max_logit + __logf(sum);
+
+    float entropy = 0.0f;
+    for (int a = 0; a < A; a++) {
+        float l = float(logits[logits_offset + a]);
+        float logp = l - logsumexp;
+        float p = __expf(logp);
+        entropy -= p * logp;
+    }
+
+    // recompute the values that were saved in the forward pass
+    float new_logp = act_logit - logsumexp;
+    float ratio = __expf(new_logp - old_logp);
+    float v_error = val_pred - val;
+    float v_clipped = val + fmaxf(-vf_clip_coef, fminf(vf_clip_coef, v_error));
+
+    // normalize advantage
+    float adv_normalized = (adv - adv_mean[0]) / (adv_std[0] + 1e-8f);
+
+    // loss gradient scaling
+    float dL = grad_loss[0] * inv_NT;
+    float d_pg_loss = dL;
+    float d_entropy_term = dL * (-ent_coef);
+
+    // gradient wrt the value function prediction
+    float v_loss_unclipped = (val_pred - ret) * (val_pred - ret);
+    float v_loss_clipped = (v_clipped - ret) * (v_clipped - ret);
+    bool use_clipped_vf = (v_loss_clipped > v_loss_unclipped);
+
+    float d_val_pred = 0.0f;
+    if (use_clipped_vf) {
+        if (v_error >= -vf_clip_coef && v_error <= vf_clip_coef) {
+            d_val_pred = v_clipped - ret;
+        }
+    } else {
+        d_val_pred = val_pred - ret;
+    }
+    grad_values_pred[nt] = T(dL * vf_coef * d_val_pred);
+
+    // policy loss gradient
+    float ratio_clipped = fmaxf(1.0f - clip_coef, fminf(1.0f + clip_coef, ratio));
+    float pg_loss1 = -w * adv_normalized * ratio;
+    float pg_loss2 = -w * adv_normalized * ratio_clipped;
+
+    float d_ratio = -w * adv_normalized * d_pg_loss;
+    if (pg_loss2 > pg_loss1) {
+        if (ratio <= (1.0f - clip_coef) || ratio >= (1.0f + clip_coef)) {
+            d_ratio = 0.0f;
+        }
+    }
+    float d_new_logp = d_ratio * ratio;
+
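+    // Softmax Jacobians used below: d(new_logp)/d(logit_a) = onehot(act) - p_a,
+    // and d(entropy)/d(logit_a) = -p_a * (logp_a + entropy). Both are folded
+    // into a single pass over the logits.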
+    for (int a = 0; a < A; a++) {
+        float l = float(logits[logits_offset + a]);
+        float logp = l - logsumexp;
+        float p = __expf(logp);
+
+        float d_logit = (a == act) ? d_new_logp : 0.0f;
+        d_logit -= p * d_new_logp;
+
+        // entropy term: d(entropy)/d(logit_a) = -p * (logp + entropy). (The
+        // gradient must vanish for a uniform policy, where logp == -entropy.)
+        d_logit -= d_entropy_term * p * (logp + entropy);
+        grad_logits[logits_offset + a] = T(d_logit);
+    }
+}
+
 template <typename T>
 __global__ void ppo_loss_backward_kernel(
     T* __restrict__ grad_logits,
@@ -1405,6 +1637,59 @@ __global__ void ppo_loss_backward_kernel(
     }
 }
+template <typename T>
+inline void launch_ppo_loss_forward_optimized(
+    float* loss_output,
+    double* saved_for_backward,
+    const T* logits,
+    const T* values_pred,
+    const int64_t* actions,
+    const T* old_logprobs,
+    const T* advantages,
+    const T* prio,
+    const T* values,
+    const T* returns,
+    const float* adv_mean,
+    const float* adv_std,
+    float clip_coef,
+    float vf_clip_coef,
+    float vf_coef,
+    float ent_coef,
+    int T_seq,
+    int A,
+    int N,
+    cudaStream_t stream
+) {
+    int total = N * T_seq;
+    int grid = (total + PPO_THREADS - 1) / PPO_THREADS;
+    ppo_loss_forward_kernel_optimized<<<grid, PPO_THREADS, 0, stream>>>(
+        loss_output,
+        saved_for_backward,
+        logits,
+        values_pred,
+        actions,
+        old_logprobs,
+        advantages,
+        prio,
+        values,
+        returns,
+        adv_mean,
+        adv_std,
+        clip_coef,
+        vf_clip_coef,
+        vf_coef,
+        ent_coef,
+        T_seq,
+        A,
+        N
+    );
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        fprintf(stderr, "PPO forward optimized kernel error: %s\n", cudaGetErrorString(err));
+    }
+}
+
 template <typename T>
 inline void launch_ppo_loss_forward(
     float* loss_output,
@@ -1459,6 +1744,62 @@ inline void launch_ppo_loss_forward(
     }
 }
+template <typename T>
+void launch_ppo_loss_backward_optimized(
+    T* grad_logits,
+    T* grad_values_pred,
+    const float* grad_loss,
+    const T* logits,
+    const T* values_pred, // added: need to read val_pred directly
+    const int64_t* actions,
+    const T* old_logprobs,
+    const T* advantages,
+    const T* prio,
+    const T* values,
+    const T* returns,
+    const float* adv_mean,
+    const float* adv_std,
+    float clip_coef,
+    float vf_clip_coef,
+    float vf_coef,
+    float ent_coef,
+    int T_seq,
+    int A,
+    int N,
+    cudaStream_t stream
+) {
+    int total = N * T_seq;
+    int grid = (total + PPO_THREADS - 1) / PPO_THREADS;
+
+    ppo_loss_backward_kernel_optimized<<<grid, PPO_THREADS, 0, stream>>>(
+        grad_logits,
+        grad_values_pred,
+        grad_loss,
+        logits,
+        values_pred,
+        actions,
+        old_logprobs,
+        advantages,
+        prio,
+        values,
+        returns,
+        adv_mean,
+        adv_std,
+        clip_coef,
+        vf_clip_coef,
+        vf_coef,
+        ent_coef,
+        T_seq,
+        A,
+        N
+    );
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        fprintf(stderr, "PPO backward optimized kernel error: %s\n", cudaGetErrorString(err));
+    }
+}
+
 template <typename T>
 void launch_ppo_loss_backward(
     T* grad_logits,
@@ -1649,3 +1990,148 @@ void launch_sample_logits(
     fprintf(stderr, "sample_logits kernel error: %s\n", cudaGetErrorString(err));
   }
 }
+
+// ============================================================================
+// C-linkage wrappers for ctypes testing
+// ============================================================================
+
+extern "C" {
+
+void sync_device() {
+    cudaDeviceSynchronize();
+}
+
+const char* get_last_error() {
+    cudaError_t err = cudaGetLastError();
+    return cudaGetErrorString(err);
+}
+
+// PPO Forward - Original
+void launch_ppo_loss_forward_original_f32(
+    float* loss_output,
+    double* saved_for_backward,
+    const float* logits,
+    const float* values_pred,
+    const int64_t* actions,
+    const float* old_logprobs,
+    const float* advantages,
+    const float* prio,
+    const float* values,
+    const float* returns,
+    const float* adv_mean,
+    const float* adv_std,
+    double clip_coef,
+    double vf_clip_coef,
+    double vf_coef,
+    double ent_coef,
+    int T_seq,
+    int A,
+    int N
+) {
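+    // The coefficients arrive from ctypes as doubles and are forwarded
+    // unchanged; the launch template's parameter types (not shown in this
+    // diff) determine any narrowing. The optimized wrappers below cast to
+    // float explicitly.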
launch_ppo_loss_forward( + loss_output, saved_for_backward, + logits, values_pred, actions, old_logprobs, advantages, + prio, values, returns, adv_mean, adv_std, + clip_coef, vf_clip_coef, vf_coef, ent_coef, + T_seq, A, N, nullptr + ); +} + +// PPO Forward - Optimized +void launch_ppo_loss_forward_optimized_f32( + float* loss_output, + double* saved_for_backward, + const float* logits, + const float* values_pred, + const int64_t* actions, + const float* old_logprobs, + const float* advantages, + const float* prio, + const float* values, + const float* returns, + const float* adv_mean, + const float* adv_std, + double clip_coef, + double vf_clip_coef, + double vf_coef, + double ent_coef, + int T_seq, + int A, + int N +) { + launch_ppo_loss_forward_optimized( + loss_output, saved_for_backward, + logits, values_pred, actions, old_logprobs, advantages, + prio, values, returns, adv_mean, adv_std, + float(clip_coef), float(vf_clip_coef), float(vf_coef), float(ent_coef), + T_seq, A, N, nullptr + ); +} + +// PPO Backward - Original +void launch_ppo_loss_backward_original_f32( + float* grad_logits, + float* grad_values_pred, + const float* grad_loss, + const float* logits, + const int64_t* actions, + const float* old_logprobs, + const float* advantages, + const float* prio, + const float* values, + const float* returns, + const double* saved_for_backward, + const float* adv_mean, + const float* adv_std, + double clip_coef, + double vf_clip_coef, + double vf_coef, + double ent_coef, + int T_seq, + int A, + int N +) { + launch_ppo_loss_backward( + grad_logits, grad_values_pred, grad_loss, + logits, actions, old_logprobs, advantages, + prio, values, returns, saved_for_backward, + adv_mean, adv_std, + clip_coef, vf_clip_coef, vf_coef, ent_coef, + T_seq, A, N, nullptr + ); +} + +// PPO Backward - Optimized +void launch_ppo_loss_backward_optimized_f32( + float* grad_logits, + float* grad_values_pred, + const float* grad_loss, + const float* logits, + const float* values_pred, // added: need to read val_pred directly + const int64_t* actions, + const float* old_logprobs, + const float* advantages, + const float* prio, + const float* values, + const float* returns, + const float* adv_mean, + const float* adv_std, + double clip_coef, + double vf_clip_coef, + double vf_coef, + double ent_coef, + int T_seq, + int A, + int N +) { + launch_ppo_loss_backward_optimized( + grad_logits, grad_values_pred, grad_loss, + logits, values_pred, actions, old_logprobs, advantages, + prio, values, returns, + adv_mean, adv_std, + float(clip_coef), float(vf_clip_coef), float(vf_coef), float(ent_coef), + T_seq, A, N, nullptr + ); +} + +} // extern "C" diff --git a/pufferlib/extensions/cuda/test_ppo_kernel.py b/pufferlib/extensions/cuda/test_ppo_kernel.py new file mode 100644 index 000000000..cb703b050 --- /dev/null +++ b/pufferlib/extensions/cuda/test_ppo_kernel.py @@ -0,0 +1,754 @@ +#!/usr/bin/env python3 +""" +Test suite for PPO loss kernel optimization. 
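+
+The suite loads kernels.so through ctypes, checks the optimized forward and
+backward PPO kernels against the original implementations across a range of
+(N, T, A) shapes, and then benchmarks both versions.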
+ +Compile: + cd pufferlib/extensions/cuda + nvcc -O3 -arch=sm_86 -shared -Xcompiler -fPIC kernels.cu -o kernels.so + +Usage: + python test_ppo_kernel.py +""" + +import ctypes +import gc +import time +from pathlib import Path + +import torch + + +def load_extension(): + """Load the precompiled CUDA .so via ctypes.""" + so_file = Path(__file__).parent / "kernels.so" + + if not so_file.exists(): + raise FileNotFoundError( + f"Compiled library not found: {so_file}\n" + f"Compile it first with:\n" + f" cd {so_file.parent}\n" + f" nvcc -O3 -arch=sm_86 -shared -Xcompiler -fPIC kernels.cu -o kernels.so" + ) + + print(f"Loading {so_file}...") + lib = ctypes.CDLL(str(so_file)) + + # Forward original: launch_ppo_loss_forward_original_f32 + lib.launch_ppo_loss_forward_original_f32.argtypes = [ + ctypes.c_void_p, # loss_output (float*) + ctypes.c_void_p, # saved_for_backward (double*) + ctypes.c_void_p, # logits (float*) + ctypes.c_void_p, # values_pred (float*) + ctypes.c_void_p, # actions (int64_t*) + ctypes.c_void_p, # old_logprobs (float*) + ctypes.c_void_p, # advantages (float*) + ctypes.c_void_p, # prio (float*) + ctypes.c_void_p, # values (float*) + ctypes.c_void_p, # returns (float*) + ctypes.c_void_p, # adv_mean (float*) + ctypes.c_void_p, # adv_std (float*) + ctypes.c_double, # clip_coef + ctypes.c_double, # vf_clip_coef + ctypes.c_double, # vf_coef + ctypes.c_double, # ent_coef + ctypes.c_int, # T_seq + ctypes.c_int, # A + ctypes.c_int, # N + ] + lib.launch_ppo_loss_forward_original_f32.restype = None + + # Forward optimized: launch_ppo_loss_forward_optimized_f32 + lib.launch_ppo_loss_forward_optimized_f32.argtypes = [ + ctypes.c_void_p, # loss_output (float*) + ctypes.c_void_p, # saved_for_backward (double*) + ctypes.c_void_p, # logits (float*) + ctypes.c_void_p, # values_pred (float*) + ctypes.c_void_p, # actions (int64_t*) + ctypes.c_void_p, # old_logprobs (float*) + ctypes.c_void_p, # advantages (float*) + ctypes.c_void_p, # prio (float*) + ctypes.c_void_p, # values (float*) + ctypes.c_void_p, # returns (float*) + ctypes.c_void_p, # adv_mean (float*) + ctypes.c_void_p, # adv_std (float*) + ctypes.c_double, # clip_coef + ctypes.c_double, # vf_clip_coef + ctypes.c_double, # vf_coef + ctypes.c_double, # ent_coef + ctypes.c_int, # T_seq + ctypes.c_int, # A + ctypes.c_int, # N + ] + lib.launch_ppo_loss_forward_optimized_f32.restype = None + + # Backward original: launch_ppo_loss_backward_original_f32 + lib.launch_ppo_loss_backward_original_f32.argtypes = [ + ctypes.c_void_p, # grad_logits (float*) + ctypes.c_void_p, # grad_values_pred (float*) + ctypes.c_void_p, # grad_loss (float*) + ctypes.c_void_p, # logits (float*) + ctypes.c_void_p, # actions (int64_t*) + ctypes.c_void_p, # old_logprobs (float*) + ctypes.c_void_p, # advantages (float*) + ctypes.c_void_p, # prio (float*) + ctypes.c_void_p, # values (float*) + ctypes.c_void_p, # returns (float*) + ctypes.c_void_p, # saved_for_backward (double*) + ctypes.c_void_p, # adv_mean (float*) + ctypes.c_void_p, # adv_std (float*) + ctypes.c_double, # clip_coef + ctypes.c_double, # vf_clip_coef + ctypes.c_double, # vf_coef + ctypes.c_double, # ent_coef + ctypes.c_int, # T_seq + ctypes.c_int, # A + ctypes.c_int, # N + ] + lib.launch_ppo_loss_backward_original_f32.restype = None + + # Backward optimized: launch_ppo_loss_backward_optimized_f32 + # Note: new signature - values_pred added, saved_for_backward removed + lib.launch_ppo_loss_backward_optimized_f32.argtypes = [ + ctypes.c_void_p, # grad_logits (float*) + ctypes.c_void_p, # grad_values_pred 
(float*) + ctypes.c_void_p, # grad_loss (float*) + ctypes.c_void_p, # logits (float*) + ctypes.c_void_p, # values_pred (float*) + ctypes.c_void_p, # actions (int64_t*) + ctypes.c_void_p, # old_logprobs (float*) + ctypes.c_void_p, # advantages (float*) + ctypes.c_void_p, # prio (float*) + ctypes.c_void_p, # values (float*) + ctypes.c_void_p, # returns (float*) + ctypes.c_void_p, # adv_mean (float*) + ctypes.c_void_p, # adv_std (float*) + ctypes.c_double, # clip_coef + ctypes.c_double, # vf_clip_coef + ctypes.c_double, # vf_coef + ctypes.c_double, # ent_coef + ctypes.c_int, # T_seq + ctypes.c_int, # A + ctypes.c_int, # N + ] + lib.launch_ppo_loss_backward_optimized_f32.restype = None + + lib.sync_device.argtypes = [] + lib.sync_device.restype = None + + lib.get_last_error.argtypes = [] + lib.get_last_error.restype = ctypes.c_char_p + + print("Loaded successfully!\n") + return lib + + +# Test configurations based on production RL workloads +# Format: (N, T, A) where: +# N = batch size (number of sequences) +# T = sequence length (timesteps per sequence) +# A = action space size +TEST_CONFIGS = { + "tiny": [ + (32, 64, 4), # simple control + (32, 64, 6), # lunar lander style + (64, 32, 4), + ], + "small": [ + (128, 64, 6), # atari minimal + (128, 64, 18), # atari full + (256, 32, 15), # procgen + ], + "medium": [ + (512, 64, 4), + (512, 64, 18), + (1024, 32, 18), + ], + "large": [ + (1024, 64, 18), + (2048, 32, 18), + (2048, 64, 6), + ], + "extreme": [ + (512, 64, 64), # large action space + (256, 64, 128), # very large (nethack-ish) + ], +} + + +def create_test_tensors(N: int, T: int, A: int, device: str = "cuda"): + """ + Create test tensors matching actual PPO kernel inputs. + + Args: + N: Batch size (number of sequences) + T: Sequence length (timesteps per sequence) + A: Action space size + device: Device to create tensors on + + Returns: + dict of tensors + """ + # logits: (N, T, A) - raw policy outputs + logits = torch.randn(N, T, A, device=device, dtype=torch.float32) + + # values_pred: (N, T) - predicted values from critic + values_pred = torch.randn(N, T, device=device, dtype=torch.float32) + + # actions: (N, T) - actions taken (indices into A) + actions = torch.randint(0, A, (N, T), device=device, dtype=torch.int64) + + # old_logprobs: (N, T) - log probs from behavior policy + old_logprobs = torch.randn(N, T, device=device, dtype=torch.float32) - 2.0 # make negative + + # advantages: (N, T) - GAE advantages + advantages = torch.randn(N, T, device=device, dtype=torch.float32) + + # prio: (N,) - importance weights per sequence + prio = torch.ones(N, device=device, dtype=torch.float32) + + # values: (N, T) - old value predictions (for clipping) + values = torch.randn(N, T, device=device, dtype=torch.float32) + + # returns: (N, T) - computed returns + returns = torch.randn(N, T, device=device, dtype=torch.float32) + + # adv_mean, adv_std: scalars for normalization + adv_mean = torch.tensor([advantages.mean().item()], device=device, dtype=torch.float32) + adv_std = torch.tensor([advantages.std().item()], device=device, dtype=torch.float32) + + return { + "logits": logits, + "values_pred": values_pred, + "actions": actions, + "old_logprobs": old_logprobs, + "advantages": advantages, + "prio": prio, + "values": values, + "returns": returns, + "adv_mean": adv_mean, + "adv_std": adv_std, + } + + +def compare_outputs( + outputs_orig: list[torch.Tensor], + outputs_new: list[torch.Tensor], + names: list[str], + rtol: float = 1e-4, + atol: float = 1e-4, +) -> tuple[bool, dict]: + """ + Compare 
two sets of outputs for numerical equivalence. + + Returns: + (all_passed, details_dict) + """ + assert len(outputs_orig) == len(outputs_new) == len(names) + + results = {} + all_passed = True + + for orig, new, name in zip(outputs_orig, outputs_new, names): + # Check shapes match + if orig.shape != new.shape: + results[name] = { + "passed": False, + "error": f"Shape mismatch: {orig.shape} vs {new.shape}", + } + all_passed = False + continue + + # Check values + try: + torch.testing.assert_close(new, orig, rtol=rtol, atol=atol) + max_diff = (orig - new).abs().max().item() + mean_diff = (orig - new).abs().mean().item() + results[name] = { + "passed": True, + "max_diff": max_diff, + "mean_diff": mean_diff, + } + except AssertionError as e: + max_diff = (orig - new).abs().max().item() + mean_diff = (orig - new).abs().mean().item() + + results[name] = { + "passed": False, + "max_diff": max_diff, + "mean_diff": mean_diff, + "error": str(e)[:200], + } + all_passed = False + + return all_passed, results + + +def cleanup_gpu(): + """Force cleanup of GPU memory.""" + gc.collect() + torch.cuda.empty_cache() + + +def run_forward_kernel(lib, kernel_func, tensors, N, T, A, + clip_coef=0.2, vf_clip_coef=0.2, vf_coef=0.5, ent_coef=0.01): + """ + Run a PPO forward kernel and return output tensors. + + Returns: + (loss, saved_for_backward) + """ + device = tensors["logits"].device + + # Allocate output tensors + loss = torch.zeros(1, device=device, dtype=torch.float32) + saved_for_backward = torch.empty(N * T, 5, device=device, dtype=torch.float64) + + # Call kernel + kernel_func( + loss.data_ptr(), + saved_for_backward.data_ptr(), + tensors["logits"].data_ptr(), + tensors["values_pred"].data_ptr(), + tensors["actions"].data_ptr(), + tensors["old_logprobs"].data_ptr(), + tensors["advantages"].data_ptr(), + tensors["prio"].data_ptr(), + tensors["values"].data_ptr(), + tensors["returns"].data_ptr(), + tensors["adv_mean"].data_ptr(), + tensors["adv_std"].data_ptr(), + clip_coef, + vf_clip_coef, + vf_coef, + ent_coef, + T, A, N + ) + + # Sync and check for errors + lib.sync_device() + err = lib.get_last_error() + if err and err != b"no error": + raise RuntimeError(f"CUDA error: {err.decode()}") + + return loss, saved_for_backward + + +def run_backward_kernel_original(lib, tensors, saved_for_backward, N, T, A, + clip_coef=0.2, vf_clip_coef=0.2, vf_coef=0.5, ent_coef=0.01): + """ + Run the original PPO backward kernel and return gradient tensors. 
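+
+    Unlike the optimized path, this kernel consumes the saved_for_backward
+    buffer produced by the forward pass, so a forward call must come first.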
+ + Returns: + (grad_logits, grad_values_pred) + """ + device = tensors["logits"].device + + # Allocate output tensors + grad_logits = torch.empty(N, T, A, device=device, dtype=torch.float32) + grad_values_pred = torch.empty(N, T, device=device, dtype=torch.float32) + + # grad_loss is typically 1.0 for .backward() + grad_loss = torch.ones(1, device=device, dtype=torch.float32) + + # Call original kernel (uses saved_for_backward) + lib.launch_ppo_loss_backward_original_f32( + grad_logits.data_ptr(), + grad_values_pred.data_ptr(), + grad_loss.data_ptr(), + tensors["logits"].data_ptr(), + tensors["actions"].data_ptr(), + tensors["old_logprobs"].data_ptr(), + tensors["advantages"].data_ptr(), + tensors["prio"].data_ptr(), + tensors["values"].data_ptr(), + tensors["returns"].data_ptr(), + saved_for_backward.data_ptr(), + tensors["adv_mean"].data_ptr(), + tensors["adv_std"].data_ptr(), + clip_coef, + vf_clip_coef, + vf_coef, + ent_coef, + T, A, N + ) + + # Sync and check for errors + lib.sync_device() + err = lib.get_last_error() + if err and err != b"no error": + raise RuntimeError(f"CUDA error: {err.decode()}") + + return grad_logits, grad_values_pred + + +def run_backward_kernel_optimized(lib, tensors, N, T, A, + clip_coef=0.2, vf_clip_coef=0.2, vf_coef=0.5, ent_coef=0.01): + """ + Run the optimized PPO backward kernel and return gradient tensors. + + Note: The optimized kernel recomputes all values instead of reading saved_for_backward, + so it takes values_pred directly instead of saved_for_backward. + + Returns: + (grad_logits, grad_values_pred) + """ + device = tensors["logits"].device + + # Allocate output tensors + grad_logits = torch.empty(N, T, A, device=device, dtype=torch.float32) + grad_values_pred = torch.empty(N, T, device=device, dtype=torch.float32) + + # grad_loss is typically 1.0 for .backward() + grad_loss = torch.ones(1, device=device, dtype=torch.float32) + + # Call optimized kernel (recomputes everything, takes values_pred directly) + lib.launch_ppo_loss_backward_optimized_f32( + grad_logits.data_ptr(), + grad_values_pred.data_ptr(), + grad_loss.data_ptr(), + tensors["logits"].data_ptr(), + tensors["values_pred"].data_ptr(), # takes values_pred directly + tensors["actions"].data_ptr(), + tensors["old_logprobs"].data_ptr(), + tensors["advantages"].data_ptr(), + tensors["prio"].data_ptr(), + tensors["values"].data_ptr(), + tensors["returns"].data_ptr(), + tensors["adv_mean"].data_ptr(), + tensors["adv_std"].data_ptr(), + clip_coef, + vf_clip_coef, + vf_coef, + ent_coef, + T, A, N + ) + + # Sync and check for errors + lib.sync_device() + err = lib.get_last_error() + if err and err != b"no error": + raise RuntimeError(f"CUDA error: {err.decode()}") + + return grad_logits, grad_values_pred + + +def run_forward_correctness_test( + lib, + N: int, + T: int, + A: int, + verbose: bool = False, +) -> tuple[bool, dict]: + """ + Run correctness test for forward pass. 
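+
+    Only the scalar loss is compared: the optimized forward kernel does not
+    write saved_for_backward.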
+ + Returns: + (passed, details) + """ + try: + tensors = create_test_tensors(N, T, A) + + # Run original + loss_orig, saved_orig = run_forward_kernel( + lib, lib.launch_ppo_loss_forward_original_f32, tensors, N, T, A + ) + + # Run optimized + loss_opt, saved_opt = run_forward_kernel( + lib, lib.launch_ppo_loss_forward_optimized_f32, tensors, N, T, A + ) + + # Compare outputs - only compare loss (optimized kernel doesn't use saved_for_backward) + output_names = ["loss"] + passed, details = compare_outputs( + [loss_orig], + [loss_opt], + output_names, + rtol=1e-3, # looser tolerance for fast math (__expf, __logf) + atol=1e-3, + ) + + return passed, details + finally: + cleanup_gpu() + + +def run_backward_correctness_test( + lib, + N: int, + T: int, + A: int, + verbose: bool = False, +) -> tuple[bool, dict]: + """ + Run correctness test for backward pass. + + Returns: + (passed, details) + """ + try: + tensors = create_test_tensors(N, T, A) + + # Run forward to get saved_for_backward (use original for consistency) + loss, saved_for_backward = run_forward_kernel( + lib, lib.launch_ppo_loss_forward_original_f32, tensors, N, T, A + ) + + # Run original backward (uses saved_for_backward) + grad_logits_orig, grad_values_pred_orig = run_backward_kernel_original( + lib, tensors, saved_for_backward, N, T, A + ) + + # Run optimized backward (recomputes everything, doesn't need saved_for_backward) + grad_logits_opt, grad_values_pred_opt = run_backward_kernel_optimized( + lib, tensors, N, T, A + ) + + # Compare gradients + output_names = ["grad_logits", "grad_values_pred"] + passed, details = compare_outputs( + [grad_logits_orig, grad_values_pred_orig], + [grad_logits_opt, grad_values_pred_opt], + output_names + ) + + return passed, details + finally: + cleanup_gpu() + + +def run_forward_benchmark( + lib, + N: int, + T: int, + A: int, + warmup_iters: int = 10, + bench_iters: int = 100, +) -> dict: + """ + Benchmark both forward kernels and return timing results. + """ + try: + tensors = create_test_tensors(N, T, A) + + # Warmup + for _ in range(warmup_iters): + _ = run_forward_kernel(lib, lib.launch_ppo_loss_forward_original_f32, tensors, N, T, A) + _ = run_forward_kernel(lib, lib.launch_ppo_loss_forward_optimized_f32, tensors, N, T, A) + + lib.sync_device() + + # Benchmark original + start = time.perf_counter() + for _ in range(bench_iters): + _ = run_forward_kernel(lib, lib.launch_ppo_loss_forward_original_f32, tensors, N, T, A) + lib.sync_device() + orig_time = (time.perf_counter() - start) / bench_iters * 1000 # ms + + # Benchmark optimized + start = time.perf_counter() + for _ in range(bench_iters): + _ = run_forward_kernel(lib, lib.launch_ppo_loss_forward_optimized_f32, tensors, N, T, A) + lib.sync_device() + opt_time = (time.perf_counter() - start) / bench_iters * 1000 # ms + + speedup = orig_time / opt_time if opt_time > 0 else float('inf') + + return { + "original_ms": orig_time, + "optimized_ms": opt_time, + "speedup": speedup, + } + finally: + cleanup_gpu() + + +def run_backward_benchmark( + lib, + N: int, + T: int, + A: int, + warmup_iters: int = 10, + bench_iters: int = 100, +) -> dict: + """ + Benchmark both backward kernels and return timing results. 
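+
+    Note that the kernel helpers allocate their output tensors on every call,
+    so these figures are end-to-end launch costs rather than pure kernel times.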
+ """ + try: + tensors = create_test_tensors(N, T, A) + + # Run forward to get saved_for_backward (needed by original kernel) + loss, saved_for_backward = run_forward_kernel( + lib, lib.launch_ppo_loss_forward_original_f32, tensors, N, T, A + ) + + # Warmup + for _ in range(warmup_iters): + _ = run_backward_kernel_original( + lib, tensors, saved_for_backward, N, T, A + ) + _ = run_backward_kernel_optimized( + lib, tensors, N, T, A + ) + + lib.sync_device() + + # Benchmark original backward + start = time.perf_counter() + for _ in range(bench_iters): + _ = run_backward_kernel_original( + lib, tensors, saved_for_backward, N, T, A + ) + lib.sync_device() + orig_time = (time.perf_counter() - start) / bench_iters * 1000 + + # Benchmark optimized backward + start = time.perf_counter() + for _ in range(bench_iters): + _ = run_backward_kernel_optimized( + lib, tensors, N, T, A + ) + lib.sync_device() + opt_time = (time.perf_counter() - start) / bench_iters * 1000 + + return { + "original_ms": orig_time, + "optimized_ms": opt_time, + "speedup": orig_time / opt_time if opt_time > 0 else float('inf'), + } + finally: + cleanup_gpu() + + +def main(): + if not torch.cuda.is_available(): + print("ERROR: CUDA not available") + return 1 + + print(f"CUDA device: {torch.cuda.get_device_name()}") + print(f"PyTorch version: {torch.__version__}") + print() + + try: + lib = load_extension() + except Exception as e: + print(f"ERROR: Failed to load extension: {e}") + return 1 + + configs = [] + for size_name, size_configs in TEST_CONFIGS.items(): + for cfg in size_configs: + configs.append((size_name, cfg)) + + print("=" * 70) + print("FORWARD CORRECTNESS TESTS") + print("=" * 70) + + all_passed = True + for size_name, (N, T, A) in configs: + config_str = f"N={N:4d}, T={T:3d}, A={A:3d}" + + try: + passed, details = run_forward_correctness_test(lib, N, T, A) + except Exception as e: + print(f"[{size_name:12s}] {config_str} EXCEPTION: {e}") + all_passed = False + continue + + if passed: + status = "PASS" + else: + status = "FAIL" + failed = [name for name, d in details.items() if not d["passed"]] + status += f" (failed: {', '.join(failed)})" + max_diffs = [f"{name}:{d['max_diff']:.2e}" for name, d in details.items()] + status += f"\n max_diffs: {', '.join(max_diffs)}" + all_passed = False + + print(f"[{size_name:12s}] {config_str} {status}") + + print() + + print("=" * 70) + print("FORWARD BENCHMARKS") + print("=" * 70) + print(f"{'Config':<30s} {'Original':>12s} {'Optimized':>12s} {'Speedup':>10s}") + print("-" * 70) + + for size_name, (N, T, A) in configs: + config_str = f"N={N}, T={T}, A={A}" + + try: + bench = run_forward_benchmark(lib, N, T, A) + print( + f"{config_str:<30s} " + f"{bench['original_ms']:>10.3f}ms " + f"{bench['optimized_ms']:>10.3f}ms " + f"{bench['speedup']:>9.2f}x" + ) + except Exception as e: + print(f"{config_str:<30s} ERROR: {e}") + + print() + + print("=" * 70) + print("BACKWARD CORRECTNESS TESTS") + print("=" * 70) + + for size_name, (N, T, A) in configs: + config_str = f"N={N:4d}, T={T:3d}, A={A:3d}" + + try: + passed, details = run_backward_correctness_test(lib, N, T, A) + except Exception as e: + print(f"[{size_name:12s}] {config_str} EXCEPTION: {e}") + all_passed = False + continue + + if passed: + status = "PASS" + else: + status = "FAIL" + failed = [name for name, d in details.items() if not d["passed"]] + status += f" (failed: {', '.join(failed)})" + max_diffs = [f"{name}:{d['max_diff']:.2e}" for name, d in details.items()] + status += f"\n max_diffs: {', '.join(max_diffs)}" + 
all_passed = False + + print(f"[{size_name:12s}] {config_str} {status}") + + print() + + print("=" * 70) + print("BACKWARD BENCHMARKS") + print("=" * 70) + print(f"{'Config':<30s} {'Original':>12s} {'Optimized':>12s} {'Speedup':>10s}") + print("-" * 70) + + for size_name, (N, T, A) in configs: + config_str = f"N={N}, T={T}, A={A}" + + try: + bench = run_backward_benchmark(lib, N, T, A) + print( + f"{config_str:<30s} " + f"{bench['original_ms']:>10.3f}ms " + f"{bench['optimized_ms']:>10.3f}ms " + f"{bench['speedup']:>9.2f}x" + ) + except Exception as e: + print(f"{config_str:<30s} ERROR: {e}") + + print() + + print("=" * 70) + if all_passed: + print("ALL TESTS PASSED!") + return 0 + else: + print("SOME TESTS FAILED") + return 1 + + +if __name__ == "__main__": + exit(main())
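+
+
+# ----------------------------------------------------------------------------
+# Optional autograd cross-check (a minimal sketch, assuming the loss math
+# described by the kernels above; ppo_loss_reference and
+# cross_check_against_autograd are illustrative helpers added here, not
+# existing PufferLib APIs). These definitions sit below the __main__ guard,
+# so import this module and call cross_check_against_autograd(lib) manually.
+# ----------------------------------------------------------------------------
+
+def ppo_loss_reference(tensors, clip_coef=0.2, vf_clip_coef=0.2,
+                       vf_coef=0.5, ent_coef=0.01):
+    """Recompute the PPO loss in pure PyTorch, mirroring the kernel math."""
+    logits = tensors["logits"]            # (N, T, A)
+    values_pred = tensors["values_pred"]  # (N, T)
+
+    logp_all = torch.log_softmax(logits, dim=-1)
+    new_logp = logp_all.gather(-1, tensors["actions"].unsqueeze(-1)).squeeze(-1)
+    entropy = -(logp_all.exp() * logp_all).sum(-1)
+
+    adv = (tensors["advantages"] - tensors["adv_mean"]) / (tensors["adv_std"] + 1e-8)
+    w = tensors["prio"].unsqueeze(-1)  # per-sequence weight, broadcast over T
+
+    ratio = (new_logp - tensors["old_logprobs"]).exp()
+    pg_loss = torch.maximum(
+        -w * adv * ratio,
+        -w * adv * ratio.clamp(1.0 - clip_coef, 1.0 + clip_coef),
+    )
+
+    v_clipped = tensors["values"] + (values_pred - tensors["values"]).clamp(
+        -vf_clip_coef, vf_clip_coef
+    )
+    v_loss = 0.5 * torch.maximum(
+        (values_pred - tensors["returns"]) ** 2,
+        (v_clipped - tensors["returns"]) ** 2,
+    )
+
+    return (pg_loss + vf_coef * v_loss - ent_coef * entropy).mean()
+
+
+def cross_check_against_autograd(lib, N=64, T=32, A=6):
+    """Check the optimized kernels against autograd (loose tolerances: the
+    kernels use __expf/__logf fast math)."""
+    tensors = create_test_tensors(N, T, A)
+    tensors["logits"].requires_grad_(True)
+    tensors["values_pred"].requires_grad_(True)
+
+    loss_ref = ppo_loss_reference(tensors)
+    loss_ref.backward()
+
+    loss_kern, _ = run_forward_kernel(
+        lib, lib.launch_ppo_loss_forward_optimized_f32, tensors, N, T, A
+    )
+    torch.testing.assert_close(loss_kern[0], loss_ref.detach(), rtol=1e-3, atol=1e-3)
+
+    grad_logits, grad_values_pred = run_backward_kernel_optimized(lib, tensors, N, T, A)
+    torch.testing.assert_close(grad_logits, tensors["logits"].grad, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(
+        grad_values_pred, tensors["values_pred"].grad, rtol=1e-3, atol=1e-3
+    )
+    print(f"autograd cross-check passed for N={N}, T={T}, A={A}")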