diff --git a/benchmarks/benchmark_rl_kernels.py b/benchmarks/benchmark_rl_kernels.py index 2147adb..103f838 100644 --- a/benchmarks/benchmark_rl_kernels.py +++ b/benchmarks/benchmark_rl_kernels.py @@ -137,6 +137,8 @@ def _selected_logprob_row(config: BenchmarkConfig) -> dict[str, Any]: "online_fp32": "FusedLogp.online_fp32", "online_indexed_out": "FusedLogp.online_indexed_out", "online_indexed_fp32": "FusedLogp.online_indexed_fp32", + "deterministic_fp32": "DeterministicLogp.apply_fp32", + "deterministic_indexed_fp32": "DeterministicLogp.indexed_fp32", } candidate_name = candidate_names[config.candidate] @@ -200,7 +202,21 @@ def _selected_logprob_row(config: BenchmarkConfig) -> dict[str, Any]: indexed_op = kernel_registry.get_op("logp_indexed") online_op = kernel_registry.get_op("logp_online") online_indexed_op = kernel_registry.get_op("logp_online_indexed") - candidate_backend_name = "FusedLogpGenericOp" + deterministic_op = kernel_registry.get_op("logp_deterministic") + deterministic_indexed_op = kernel_registry.get_op("logp_deterministic_indexed") + candidate_backend_names = { + "apply": "FusedLogpGenericOp", + "out": "FusedLogpGenericOp", + "fp32": "FusedLogpGenericOp", + "indexed_out": "FusedLogpGenericOp", + "indexed_fp32": "FusedLogpGenericOp", + "online_out": "FusedLogpGenericOp", + "online_fp32": "FusedLogpGenericOp", + "online_indexed_out": "FusedLogpGenericOp", + "online_indexed_fp32": "FusedLogpGenericOp", + "deterministic_fp32": "DeterministicLogpCUDAOp", + "deterministic_indexed_fp32": "DeterministicLogpCUDAOp", + } required_backends = { "apply": dense_op, "out": dense_op, @@ -211,8 +227,11 @@ def _selected_logprob_row(config: BenchmarkConfig) -> dict[str, Any]: "online_fp32": online_op, "online_indexed_out": online_indexed_op, "online_indexed_fp32": online_indexed_op, + "deterministic_fp32": deterministic_op, + "deterministic_indexed_fp32": deterministic_indexed_op, } selected_backend = required_backends[config.candidate] + candidate_backend_name = candidate_backend_names[config.candidate] if selected_backend.__class__.__name__ != candidate_backend_name: raise RuntimeError(f"{candidate_backend_name} backend is unavailable") @@ -276,6 +295,17 @@ def _selected_logprob_row(config: BenchmarkConfig) -> dict[str, Any]: batch.valid_indices, ) notes = "online log-sum-exp valid-index float32 output" + elif config.candidate == "deterministic_fp32": + run_candidate = partial(deterministic_op.apply_fp32, logits, batch.token_ids) + notes = "batch-invariant deterministic float32 output" + elif config.candidate == "deterministic_indexed_fp32": + run_candidate = partial( + deterministic_indexed_op.indexed_fp32, + logits, + batch.token_ids, + batch.valid_indices, + ) + notes = "batch-invariant deterministic valid-index float32 output" else: raise ValueError(f"unsupported candidate: {config.candidate}") @@ -378,6 +408,8 @@ def build_arg_parser() -> argparse.ArgumentParser: "online_fp32", "online_indexed_out", "online_indexed_fp32", + "deterministic_fp32", + "deterministic_indexed_fp32", ], ) parser.add_argument("--smoke", action="store_true", help="Run a small local-development shape") diff --git a/csrc/deterministic_logp_kernel.cu b/csrc/deterministic_logp_kernel.cu new file mode 100644 index 0000000..48688d5 --- /dev/null +++ b/csrc/deterministic_logp_kernel.cu @@ -0,0 +1,348 @@ +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr int kDeterministicLogpSmallBlockSize = 128; +constexpr int kDeterministicLogpMediumBlockSize = 256; +constexpr int kDeterministicLogpLargeBlockSize = 512; +constexpr int kDeterministicLogpSmallVocabLimit = 128; +constexpr int kDeterministicLogpMediumVocabLimit = 4096; +constexpr int kDeterministicLogpWarpSize = 32; +constexpr float kDeterministicLogpNegInf = -3.4028234663852886e38F; + +template +struct DeterministicLogpBlockTraits { + static_assert( + BlockSize == kDeterministicLogpSmallBlockSize || + BlockSize == kDeterministicLogpMediumBlockSize || + BlockSize == kDeterministicLogpLargeBlockSize, + "deterministic logp reduction topology requires a supported fixed block size"); + static_assert(BlockSize % kDeterministicLogpWarpSize == 0, "block size must be warp-aligned"); + static constexpr int WarpCount = BlockSize / kDeterministicLogpWarpSize; +}; + +template +__device__ __forceinline__ float deterministicBlockReduceMax(float val) { + constexpr int WarpCount = DeterministicLogpBlockTraits::WarpCount; + __shared__ float shared[WarpCount]; + + int lane = threadIdx.x & (kDeterministicLogpWarpSize - 1); + int wid = threadIdx.x / kDeterministicLogpWarpSize; + +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); + } + + if (lane == 0) { + shared[wid] = val; + } + __syncthreads(); + + const bool has_warp_value = threadIdx.x < WarpCount; + const int shared_idx = has_warp_value ? threadIdx.x : 0; + val = has_warp_value ? shared[shared_idx] : kDeterministicLogpNegInf; + if (wid == 0) { +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); + } + } + return val; +} + +template +__device__ __forceinline__ float deterministicBlockReduceSum(float val) { + constexpr int WarpCount = DeterministicLogpBlockTraits::WarpCount; + __shared__ float shared[WarpCount]; + + int lane = threadIdx.x & (kDeterministicLogpWarpSize - 1); + int wid = threadIdx.x / kDeterministicLogpWarpSize; + +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + + if (lane == 0) { + shared[wid] = val; + } + __syncthreads(); + + const bool has_warp_value = threadIdx.x < WarpCount; + const int shared_idx = has_warp_value ? threadIdx.x : 0; + val = has_warp_value ? shared[shared_idx] : 0.0f; + if (wid == 0) { +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + } + return val; +} + +template +__global__ void __launch_bounds__(BlockSize) deterministic_logp_forward_kernel( + const input_t* __restrict__ logits, + const int64_t* __restrict__ token_ids, + output_t* __restrict__ output, + const int64_t* __restrict__ row_indices, + int64_t total_rows, + int vocab_size) { + int64_t row = row_indices == nullptr ? blockIdx.x : row_indices[blockIdx.x]; + if (row < 0 || row >= total_rows) { + return; + } + + const input_t* row_logits = logits + row * vocab_size; + + float local_max = kDeterministicLogpNegInf; + for (int col = threadIdx.x; col < vocab_size; col += BlockSize) { + local_max = fmaxf(local_max, static_cast(row_logits[col])); + } + + float max_val = deterministicBlockReduceMax(local_max); + + __shared__ float row_max; + if (threadIdx.x == 0) { + row_max = max_val; + } + __syncthreads(); + + float local_sum = 0.0f; + for (int col = threadIdx.x; col < vocab_size; col += BlockSize) { + local_sum += expf(static_cast(row_logits[col]) - row_max); + } + + float sum_val = deterministicBlockReduceSum(local_sum); + + __shared__ float row_sum; + if (threadIdx.x == 0) { + row_sum = sum_val; + } + __syncthreads(); + + // Indexed mode may launch duplicate row ids. The writes are idempotent: + // every duplicate writer computes and stores the same deterministic value. + if (threadIdx.x == 0) { + int64_t target_id = token_ids[row]; + if (target_id >= 0 && target_id < vocab_size) { + float target_logit = static_cast(row_logits[target_id]); + output[row] = static_cast(target_logit - row_max - logf(row_sum)); + } else { + output[row] = static_cast(0.0f); + } + } +} + +void check_deterministic_logp_inputs( + const torch::Tensor& logits, + const torch::Tensor& token_ids, + const torch::Tensor& output) { + TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor"); + TORCH_CHECK(token_ids.is_cuda(), "token_ids must be a CUDA tensor"); + TORCH_CHECK(output.is_cuda(), "output must be a CUDA tensor"); + TORCH_CHECK( + logits.device() == token_ids.device(), + "logits and token_ids must be on the same CUDA device"); + TORCH_CHECK( + logits.device() == output.device(), + "logits and output must be on the same CUDA device"); + TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor"); + TORCH_CHECK(token_ids.dim() == 1, "token_ids must be a 1D tensor"); + TORCH_CHECK(output.dim() == 1, "output must be a 1D tensor"); + TORCH_CHECK(token_ids.scalar_type() == at::ScalarType::Long, "token_ids must be int64"); + TORCH_CHECK( + token_ids.numel() == logits.size(0), + "token_ids length must match logits rows"); + TORCH_CHECK(output.numel() == logits.size(0), "output length must match logits rows"); + TORCH_CHECK(output.is_contiguous(), "output must be contiguous"); + TORCH_CHECK(logits.size(1) > 0, "logits vocab dimension must be non-empty"); + TORCH_CHECK( + logits.size(0) <= std::numeric_limits::max(), + "logits row count exceeds CUDA grid-x limit"); + TORCH_CHECK( + logits.size(1) <= std::numeric_limits::max(), + "logits vocab dimension exceeds int32 kernel limit"); + TORCH_CHECK( + output.scalar_type() == at::ScalarType::Float || + output.scalar_type() == at::ScalarType::Double || + output.scalar_type() == at::ScalarType::Half || + output.scalar_type() == at::ScalarType::BFloat16, + "output dtype must be float64, float32, float16, or bfloat16"); +} + +void check_deterministic_logp_indices( + const torch::Tensor& logits, + const torch::Tensor& row_indices) { + TORCH_CHECK(row_indices.is_cuda(), "row_indices must be a CUDA tensor"); + TORCH_CHECK( + logits.device() == row_indices.device(), + "logits and row_indices must be on the same CUDA device"); + TORCH_CHECK(row_indices.dim() == 1, "row_indices must be a 1D tensor"); + TORCH_CHECK(row_indices.scalar_type() == at::ScalarType::Long, "row_indices must be int64"); + TORCH_CHECK( + row_indices.numel() <= std::numeric_limits::max(), + "row_indices length exceeds CUDA grid-x limit"); +} + +void launch_deterministic_logp_kernel( + const torch::Tensor& logits, + const torch::Tensor& token_ids, + const torch::Tensor& output, + const int64_t* row_indices_ptr, + int64_t launch_rows, + int64_t total_rows, + int64_t vocab_size) { + if (launch_rows == 0) { + return; + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + logits.scalar_type(), + "deterministic_logp_kernel", + ([&] { + using input_t = scalar_t; + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + output.scalar_type(), + "deterministic_logp_output_kernel", + ([&] { + using output_t = scalar_t; + const int vocab_size_i32 = static_cast(vocab_size); + const int launch_rows_i32 = static_cast(launch_rows); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (vocab_size <= kDeterministicLogpSmallVocabLimit) { + deterministic_logp_forward_kernel< + input_t, + output_t, + kDeterministicLogpSmallBlockSize><<< + launch_rows_i32, + kDeterministicLogpSmallBlockSize, + 0, + stream>>>( + logits.data_ptr(), + token_ids.data_ptr(), + output.data_ptr(), + row_indices_ptr, + total_rows, + vocab_size_i32); + } else if (vocab_size <= kDeterministicLogpMediumVocabLimit) { + deterministic_logp_forward_kernel< + input_t, + output_t, + kDeterministicLogpMediumBlockSize><<< + launch_rows_i32, + kDeterministicLogpMediumBlockSize, + 0, + stream>>>( + logits.data_ptr(), + token_ids.data_ptr(), + output.data_ptr(), + row_indices_ptr, + total_rows, + vocab_size_i32); + } else { + deterministic_logp_forward_kernel< + input_t, + output_t, + kDeterministicLogpLargeBlockSize><<< + launch_rows_i32, + kDeterministicLogpLargeBlockSize, + 0, + stream>>>( + logits.data_ptr(), + token_ids.data_ptr(), + output.data_ptr(), + row_indices_ptr, + total_rows, + vocab_size_i32); + } + })); + })); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace + +torch::Tensor deterministic_logp_forward_out( + torch::Tensor logits, + torch::Tensor token_ids, + torch::Tensor output) { + check_deterministic_logp_inputs(logits, token_ids, output); + + auto logits_contig = logits.contiguous(); + auto token_ids_contig = token_ids.contiguous(); + + int64_t total_rows = logits_contig.size(0); + int64_t vocab_size = logits_contig.size(1); + launch_deterministic_logp_kernel( + logits_contig, + token_ids_contig, + output, + nullptr, + total_rows, + total_rows, + vocab_size); + + return output; +} + +torch::Tensor deterministic_logp_forward_indexed_out( + torch::Tensor logits, + torch::Tensor token_ids, + torch::Tensor row_indices, + torch::Tensor output) { + check_deterministic_logp_inputs(logits, token_ids, output); + check_deterministic_logp_indices(logits, row_indices); + + auto logits_contig = logits.contiguous(); + auto token_ids_contig = token_ids.contiguous(); + auto row_indices_contig = row_indices.contiguous(); + + int64_t total_rows = logits_contig.size(0); + int64_t vocab_size = logits_contig.size(1); + int64_t valid_rows = row_indices_contig.numel(); + + launch_deterministic_logp_kernel( + logits_contig, + token_ids_contig, + output, + row_indices_contig.data_ptr(), + valid_rows, + total_rows, + vocab_size); + + return output; +} + +torch::Tensor deterministic_logp_forward(torch::Tensor logits, torch::Tensor token_ids) { + TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor"); + auto output = torch::empty({logits.size(0)}, logits.options()); + return deterministic_logp_forward_out(logits, token_ids, output); +} + +torch::Tensor deterministic_logp_forward_fp32(torch::Tensor logits, torch::Tensor token_ids) { + TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor"); + auto output = torch::empty({logits.size(0)}, logits.options().dtype(at::ScalarType::Float)); + return deterministic_logp_forward_out(logits, token_ids, output); +} + +torch::Tensor deterministic_logp_forward_indexed_fp32( + torch::Tensor logits, + torch::Tensor token_ids, + torch::Tensor row_indices) { + TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor"); + auto output = torch::zeros({logits.size(0)}, logits.options().dtype(at::ScalarType::Float)); + return deterministic_logp_forward_indexed_out(logits, token_ids, row_indices, output); +} diff --git a/csrc/ops.cpp b/csrc/ops.cpp index c241bf6..897f580 100644 --- a/csrc/ops.cpp +++ b/csrc/ops.cpp @@ -20,6 +20,11 @@ torch::Tensor fused_logp_forward_online_out(torch::Tensor logits, torch::Tensor torch::Tensor fused_logp_forward_online_fp32(torch::Tensor logits, torch::Tensor token_ids); torch::Tensor fused_logp_forward_online_indexed_out(torch::Tensor logits, torch::Tensor token_ids, torch::Tensor row_indices, torch::Tensor output); torch::Tensor fused_logp_forward_online_indexed_fp32(torch::Tensor logits, torch::Tensor token_ids, torch::Tensor row_indices); +torch::Tensor deterministic_logp_forward(torch::Tensor logits, torch::Tensor token_ids); +torch::Tensor deterministic_logp_forward_out(torch::Tensor logits, torch::Tensor token_ids, torch::Tensor output); +torch::Tensor deterministic_logp_forward_fp32(torch::Tensor logits, torch::Tensor token_ids); +torch::Tensor deterministic_logp_forward_indexed_out(torch::Tensor logits, torch::Tensor token_ids, torch::Tensor row_indices, torch::Tensor output); +torch::Tensor deterministic_logp_forward_indexed_fp32(torch::Tensor logits, torch::Tensor token_ids, torch::Tensor row_indices); // Prefix-Shared Attention Declarations & Wrappers @@ -86,6 +91,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_logp_forward_online_fp32", &fused_logp_forward_online_fp32, "Fused logp online fp32"); m.def("fused_logp_forward_online_indexed_out", &fused_logp_forward_online_indexed_out, "Fused logp online indexed out"); m.def("fused_logp_forward_online_indexed_fp32", &fused_logp_forward_online_indexed_fp32, "Fused logp online indexed fp32"); + m.def("deterministic_logp", &deterministic_logp_forward, "Batch-invariant deterministic logp"); + m.def("deterministic_logp_forward_out", &deterministic_logp_forward_out, "Batch-invariant deterministic logp out"); + m.def("deterministic_logp_forward_fp32", &deterministic_logp_forward_fp32, "Batch-invariant deterministic logp fp32"); + m.def("deterministic_logp_forward_indexed_out", &deterministic_logp_forward_indexed_out, "Batch-invariant deterministic logp indexed out"); + m.def("deterministic_logp_forward_indexed_fp32", &deterministic_logp_forward_indexed_fp32, "Batch-invariant deterministic logp indexed fp32"); // registry Prefix-Shared Attention m.def("prefix_shared_attention", &prefix_shared_attention, "Prefix-Shared Fused Attention for GRPO"); diff --git a/envs.py b/envs.py new file mode 100644 index 0000000..abd12b1 --- /dev/null +++ b/envs.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Environment variable parsing helpers for build scripts. + +This module is intentionally import-safe for setup.py: keep it free of torch or +other heavy runtime imports. +""" + +import os + + +def env_flag(name: str, default: bool = False) -> bool: + value = os.environ.get(name) + if value is None: + return default + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + raise ValueError(f"{name} must be a boolean flag, got {value!r}") + + +KERNEL_ALIGN_USE_FAST_MATH = "KERNEL_ALIGN_USE_FAST_MATH" +KERNEL_ALIGN_NCU_LINEINFO = "KERNEL_ALIGN_NCU_LINEINFO" +KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC = "KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC" +KERNEL_ALIGN_FORCE_SM90 = "KERNEL_ALIGN_FORCE_SM90" diff --git a/examples/grpo_single_gpu.py b/examples/grpo_single_gpu.py index d263389..0929478 100644 --- a/examples/grpo_single_gpu.py +++ b/examples/grpo_single_gpu.py @@ -21,6 +21,7 @@ if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) +from rl_engine.kernels.registry import resolve_logp_op_type # noqa: E402 from rl_engine.testing import ( # noqa: E402 active_token_count, compute_policy_ratio, @@ -71,6 +72,19 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--clip-eps", type=float, default=0.2) parser.add_argument("--beta", type=float, default=0.01) parser.add_argument("--seed", type=int, default=1234) + parser.add_argument( + "--logp-backend", + default="auto", + help=( + "Selected-token logp backend. Use 'deterministic' or " + "'batch_invariant' for the CUDA deterministic path." + ), + ) + parser.add_argument( + "--require-batch-invariant-logp", + action="store_true", + help="Require a batch-invariant deterministic logp backend.", + ) parser.add_argument( "--require-fused-logp", action="store_true", @@ -87,7 +101,15 @@ def select_device(requested: str) -> torch.device: return torch.device(requested) -def resolve_logp_op(device: torch.device) -> Any: +def resolve_logp_op( + device: torch.device, + logp_backend: str, + require_batch_invariant_logp: bool, +) -> Any: + op_type = resolve_logp_op_type( + logp_backend, + require_batch_invariant=require_batch_invariant_logp, + ) if device.type == "cpu": from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp @@ -95,11 +117,11 @@ def resolve_logp_op(device: torch.device) -> Any: from rl_engine.kernels.registry import kernel_registry - return kernel_registry.get_op("logp") + return kernel_registry.get_op(op_type) -def is_fused_logp_backend(backend_name: str) -> bool: - return backend_name.startswith("FusedLogp") +def is_fused_logp_backend(logp_op: Any) -> bool: + return bool(getattr(logp_op, "is_fused_logp", False)) def make_group_advantages( @@ -166,9 +188,13 @@ def run_training(args: argparse.Namespace) -> list[StepMetrics]: torch.manual_seed(args.seed) device = select_device(args.device) - logp_op = resolve_logp_op(device) + logp_op = resolve_logp_op( + device, + args.logp_backend, + args.require_batch_invariant_logp, + ) backend_name = logp_op.__class__.__name__ - if args.require_fused_logp and not is_fused_logp_backend(backend_name): + if args.require_fused_logp and not is_fused_logp_backend(logp_op): raise RuntimeError( "--require-fused-logp was set, but kernel dispatch selected " f"{backend_name}. Build the CUDA extension with `pip install -e .` " diff --git a/rl_engine/_C.pyi b/rl_engine/_C.pyi index fdd6982..b52295a 100644 --- a/rl_engine/_C.pyi +++ b/rl_engine/_C.pyi @@ -41,3 +41,24 @@ def fused_logp_forward_online_indexed_fp32( token_ids: torch.Tensor, row_indices: torch.Tensor, ) -> torch.Tensor: ... +def deterministic_logp(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: ... +def deterministic_logp_forward_out( + logits: torch.Tensor, + token_ids: torch.Tensor, + output: torch.Tensor, +) -> torch.Tensor: ... +def deterministic_logp_forward_fp32( + logits: torch.Tensor, + token_ids: torch.Tensor, +) -> torch.Tensor: ... +def deterministic_logp_forward_indexed_out( + logits: torch.Tensor, + token_ids: torch.Tensor, + row_indices: torch.Tensor, + output: torch.Tensor, +) -> torch.Tensor: ... +def deterministic_logp_forward_indexed_fp32( + logits: torch.Tensor, + token_ids: torch.Tensor, + row_indices: torch.Tensor, +) -> torch.Tensor: ... diff --git a/rl_engine/executors/deepspeed_trainer.py b/rl_engine/executors/deepspeed_trainer.py index 44a1134..72916aa 100644 --- a/rl_engine/executors/deepspeed_trainer.py +++ b/rl_engine/executors/deepspeed_trainer.py @@ -48,6 +48,7 @@ class DeepSpeedTrainingConfig(TorchRLTrainingConfig): initialize_kwargs: Mapping[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: + super().__post_init__() if self.zero_stage < 0: raise ValueError("zero_stage must be >= 0") diff --git a/rl_engine/executors/rollout.py b/rl_engine/executors/rollout.py index 0e12e7f..1832ae1 100644 --- a/rl_engine/executors/rollout.py +++ b/rl_engine/executors/rollout.py @@ -15,7 +15,7 @@ make_weight_bridge, ) from rl_engine.executors.vllm_sampler import VLLMSamplerConfig, VLLMSharedPrefixSampler -from rl_engine.kernels.registry import kernel_registry +from rl_engine.kernels.registry import kernel_registry, resolve_logp_op_type from rl_engine.utils.logger import logger @@ -40,6 +40,10 @@ def __init__( self.weight_install_adapter = weight_install_adapter self.active_weight_version: Optional[int] = None self.active_weight_update_id: Optional[str] = None + self.logp_op_type = resolve_logp_op_type( + self.config.get("logp_backend"), + require_batch_invariant=bool(self.config.get("require_batch_invariant_logp", False)), + ) self.logp_op = None self.attn_op = None self.sampler_config: Optional[VLLMSamplerConfig] = None @@ -125,11 +129,11 @@ def _prepare_kernels(self): """ if not self.logp_op: # Retrieves the best implementation based on hardware. - self.logp_op = kernel_registry.get_op("logp") + self.logp_op = kernel_registry.get_op(self.logp_op_type) self.attn_op = kernel_registry.get_op("attn") logger.info( - f"Active Kernels -> Logp: {type(self.logp_op).__name__}," + f"Active Kernels -> Logp({self.logp_op_type}): {type(self.logp_op).__name__}," f" Attn: {type(self.attn_op).__name__}" ) diff --git a/rl_engine/executors/training_contract.py b/rl_engine/executors/training_contract.py index 14019c9..5f6309a 100644 --- a/rl_engine/executors/training_contract.py +++ b/rl_engine/executors/training_contract.py @@ -9,6 +9,7 @@ import torch +from rl_engine.kernels.registry import resolve_logp_op_type from rl_engine.testing import SyntheticRLKernelBatch, make_synthetic_rl_kernel_batch @@ -64,6 +65,14 @@ class TorchRLTrainingConfig: dtype: torch.dtype = torch.float32 seed: int = 0 min_completion_len: int = 1 + logp_backend: str = "auto" + require_batch_invariant_logp: bool = False + + def __post_init__(self) -> None: + resolve_logp_op_type( + self.logp_backend, + require_batch_invariant=self.require_batch_invariant_logp, + ) class RolloutBatchMixin: diff --git a/rl_engine/kernels/ops/cuda/loss/logp.py b/rl_engine/kernels/ops/cuda/loss/logp.py index 1742b9d..24a2832 100644 --- a/rl_engine/kernels/ops/cuda/loss/logp.py +++ b/rl_engine/kernels/ops/cuda/loss/logp.py @@ -10,6 +10,8 @@ class FusedLogpSM90Op: """TMA-accelerated Fused LogP for SM90+ cards.""" + is_fused_logp = True + def __init__(self): if not _EXT_AVAILABLE or not hasattr(_C, "fused_logp_sm90"): raise RuntimeError( @@ -29,6 +31,8 @@ def __call__(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: class FusedLogpGenericOp: """Generic custom CUDA fallback Fused LogP with RL variants.""" + is_fused_logp = True + def __init__(self): if not _EXT_AVAILABLE or not hasattr(_C, "fused_logp"): raise RuntimeError("Base custom kernel 'fused_logp' is unavailable.") @@ -45,8 +49,8 @@ def _prepare_inputs( token_ids: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Size]: orig_shape = logits.shape[:-1] - logits_2d = logits.view(-1, logits.size(-1)) - token_ids_1d = token_ids.view(-1).to(device=logits.device, dtype=torch.long).contiguous() + logits_2d = logits.reshape(-1, logits.size(-1)) + token_ids_1d = token_ids.reshape(-1).to(device=logits.device, dtype=torch.long).contiguous() return logits_2d, token_ids_1d, orig_shape def _prepare_output(self, output: torch.Tensor, orig_shape: torch.Size) -> torch.Tensor: @@ -58,7 +62,7 @@ def _prepare_output(self, output: torch.Tensor, orig_shape: torch.Size) -> torch return output.view(-1) def _prepare_indices(self, row_indices: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: - return row_indices.view(-1).to(device=logits.device, dtype=torch.long).contiguous() + return row_indices.reshape(-1).to(device=logits.device, dtype=torch.long).contiguous() def apply(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: logits_2d, token_ids_1d, orig_shape = self._prepare_inputs(logits, token_ids) @@ -140,3 +144,86 @@ def online_indexed_fp32( logits_2d, token_ids_1d, row_indices_1d ) return results.view(orig_shape) + + +class DeterministicLogpCUDAOp(FusedLogpGenericOp): + """Batch-invariant deterministic CUDA LogP. + + The default call path returns float32 output so tests and downstream KL + code observe the fixed reduction result before any lower-precision cast. + """ + + is_batch_invariant = True + + def __init__(self): + if not _EXT_AVAILABLE or not hasattr(_C, "deterministic_logp_forward_fp32"): + raise RuntimeError("Deterministic CUDA logp kernel is unavailable.") + self._backend = _C + self.op = self._backend.deterministic_logp_forward_fp32 + logger.info("Successfully linked to precompiled _C.deterministic_logp kernel.") + + def __call__(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + return self.apply_fp32(logits, token_ids) + + def apply(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + return self.apply_fp32(logits, token_ids) + + def apply_fp32(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + logits_2d, token_ids_1d, orig_shape = self._prepare_inputs(logits, token_ids) + results = self._backend.deterministic_logp_forward_fp32(logits_2d, token_ids_1d) + return results.view(orig_shape) + + def out( + self, logits: torch.Tensor, token_ids: torch.Tensor, output: torch.Tensor + ) -> torch.Tensor: + logits_2d, token_ids_1d, orig_shape = self._prepare_inputs(logits, token_ids) + output_1d = self._prepare_output(output, orig_shape) + results = self._backend.deterministic_logp_forward_out(logits_2d, token_ids_1d, output_1d) + return results.view(orig_shape) + + def indexed_out( + self, + logits: torch.Tensor, + token_ids: torch.Tensor, + row_indices: torch.Tensor, + output: torch.Tensor, + ) -> torch.Tensor: + logits_2d, token_ids_1d, orig_shape = self._prepare_inputs(logits, token_ids) + row_indices_1d = self._prepare_indices(row_indices, logits) + output_1d = self._prepare_output(output, orig_shape) + results = self._backend.deterministic_logp_forward_indexed_out( + logits_2d, token_ids_1d, row_indices_1d, output_1d + ) + return results.view(orig_shape) + + def indexed_fp32( + self, logits: torch.Tensor, token_ids: torch.Tensor, row_indices: torch.Tensor + ) -> torch.Tensor: + logits_2d, token_ids_1d, orig_shape = self._prepare_inputs(logits, token_ids) + row_indices_1d = self._prepare_indices(row_indices, logits) + results = self._backend.deterministic_logp_forward_indexed_fp32( + logits_2d, token_ids_1d, row_indices_1d + ) + return results.view(orig_shape) + + def online_out( + self, logits: torch.Tensor, token_ids: torch.Tensor, output: torch.Tensor + ) -> torch.Tensor: + return self.out(logits, token_ids, output) + + def online_fp32(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + return self.apply_fp32(logits, token_ids) + + def online_indexed_out( + self, + logits: torch.Tensor, + token_ids: torch.Tensor, + row_indices: torch.Tensor, + output: torch.Tensor, + ) -> torch.Tensor: + return self.indexed_out(logits, token_ids, row_indices, output) + + def online_indexed_fp32( + self, logits: torch.Tensor, token_ids: torch.Tensor, row_indices: torch.Tensor + ) -> torch.Tensor: + return self.indexed_fp32(logits, token_ids, row_indices) diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 7d85d26..3e21cde 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -28,6 +28,7 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): # TMA-accelerated LogP for SM90+ (Warp Specialization) CUDA_FUSED_LOGP_SM90 = "rl_engine.kernels.ops.cuda.loss.logp.FusedLogpSM90Op" CUDA_FUSED_LOGP_GENERIC = "rl_engine.kernels.ops.cuda.loss.logp.FusedLogpGenericOp" + CUDA_DETERMINISTIC_LOGP = "rl_engine.kernels.ops.cuda.loss.logp.DeterministicLogpCUDAOp" # AMD ROCm optimized stack ROCM_AITER = "rl_engine.kernels.ops.rocm.aiter.AiterOp" @@ -42,6 +43,53 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): PYTORCH_NATIVE = "rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp" +def resolve_logp_op_type( + logp_backend: Optional[str] = None, + *, + require_batch_invariant: bool = False, +) -> str: + """Normalize user-facing logp backend names to KernelRegistry op types.""" + + normalized = (logp_backend or "auto").strip().lower().replace("-", "_") + aliases = { + "auto": "logp", + "default": "logp", + "fused": "logp", + "fused_logp": "logp", + "generic": "logp", + "logp": "logp", + "indexed": "logp_indexed", + "logp_indexed": "logp_indexed", + "online": "logp_online", + "logp_online": "logp_online", + "online_indexed": "logp_online_indexed", + "logp_online_indexed": "logp_online_indexed", + "deterministic": "logp_deterministic", + "deterministic_cuda": "logp_deterministic", + "batch_invariant": "logp_deterministic", + "batch_invariant_deterministic": "logp_deterministic", + "logp_deterministic": "logp_deterministic", + "deterministic_indexed": "logp_deterministic_indexed", + "deterministic_cuda_indexed": "logp_deterministic_indexed", + "batch_invariant_indexed": "logp_deterministic_indexed", + "logp_deterministic_indexed": "logp_deterministic_indexed", + } + if normalized not in aliases: + valid = ", ".join(sorted(aliases)) + raise ValueError(f"unsupported logp backend {logp_backend!r}; valid values: {valid}") + + op_type = aliases[normalized] + if require_batch_invariant: + if normalized in {"auto", "default", "logp"}: + return "logp_deterministic" + if not op_type.startswith("logp_deterministic"): + raise ValueError( + "require_batch_invariant_logp=True requires a deterministic logp backend; " + f"got {logp_backend!r}" + ) + return op_type + + class KernelRegistry: """ Central dispatcher for high-performance kernels. @@ -72,17 +120,29 @@ def __init__(self): OpBackend.CUDA_FUSED_LOGP_GENERIC, OpBackend.PYTORCH_NATIVE, ], + "logp_deterministic": [ + OpBackend.CUDA_DETERMINISTIC_LOGP, + OpBackend.PYTORCH_NATIVE, + ], + "logp_deterministic_indexed": [ + OpBackend.CUDA_DETERMINISTIC_LOGP, + OpBackend.PYTORCH_NATIVE, + ], "attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], # Default dispatch logic for new operators }, "rocm": { "logp": [OpBackend.ROCM_AITER, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], + "logp_deterministic": [OpBackend.PYTORCH_NATIVE], + "logp_deterministic_indexed": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], }, "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], + "logp_deterministic": [OpBackend.PYTORCH_NATIVE], + "logp_deterministic_indexed": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], }, diff --git a/setup.py b/setup.py index 8e60813..650536e 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2026 RL-Kernel Contributors +import importlib.util import os +from pathlib import Path from setuptools import find_packages, setup +def _load_envs_module(): + envs_path = Path(__file__).with_name("envs.py") + spec = importlib.util.spec_from_file_location("_rl_kernel_envs", envs_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"failed to load environment helpers from {envs_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +envs = _load_envs_module() + + def _load_torch_extension_tools(): try: import torch @@ -57,11 +72,14 @@ def get_extensions(): cuda_sources = [ "csrc/ops.cpp", "csrc/fused_logp_kernel.cu", + "csrc/deterministic_logp_kernel.cu", "csrc/cuda/attention/prefix_shared_attention.cu", ] cc_major, cc_minor = torch.cuda.get_device_capability() - nvcc_flags = ["-O3", "--use_fast_math", "-Xfatbin", "-compress-all"] + nvcc_flags = ["-O3", "-Xfatbin", "-compress-all"] + if envs.env_flag(envs.KERNEL_ALIGN_USE_FAST_MATH): + nvcc_flags.append("--use_fast_math") nvcc_flags.extend( _cuda_define_from_env( "FUSED_LOGP_TWOPASS_BLOCK_SIZE", @@ -104,14 +122,17 @@ def get_extensions(): "FUSED_LOGP_ONLINE_MIN_BLOCKS_PER_SM", ) ) - if os.environ.get("KERNEL_ALIGN_NCU_LINEINFO") == "1": + if envs.env_flag(envs.KERNEL_ALIGN_NCU_LINEINFO): nvcc_flags.append("-lineinfo") + if os.name == "nt" and envs.env_flag(envs.KERNEL_ALIGN_ALLOW_UNSUPPORTED_MSVC): + nvcc_flags.append("-allow-unsupported-compiler") + nvcc_flags.append("-D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH") cxx_flags = ["-O3", "-std=c++17", "-DKERNEL_ALIGN_WITH_CUDA"] extra_link_args = [] tma_src = "csrc/cuda/fused_logp_sm90.cu" - enable_sm90 = os.environ.get("KERNEL_ALIGN_FORCE_SM90") == "1" + enable_sm90 = envs.env_flag(envs.KERNEL_ALIGN_FORCE_SM90) if enable_sm90 and os.path.exists(tma_src): tma_arch = f"{cc_major}{cc_minor}a" cuda_sources.append(tma_src) diff --git a/tests/test_deepspeed_training_worker.py b/tests/test_deepspeed_training_worker.py index 49b54e3..e111f0d 100644 --- a/tests/test_deepspeed_training_worker.py +++ b/tests/test_deepspeed_training_worker.py @@ -123,6 +123,26 @@ def fail_import(name, package=None): deepspeed_trainer.DeepSpeedTrainingWorker() +def test_training_config_accepts_batch_invariant_logp_requirement(): + from rl_engine.executors.training_contract import TorchRLTrainingConfig + + TorchRLTrainingConfig(require_batch_invariant_logp=True) + TorchRLTrainingConfig( + logp_backend="deterministic", + require_batch_invariant_logp=True, + ) + + +def test_training_config_rejects_non_deterministic_batch_invariant_backend(): + from rl_engine.executors.training_contract import TorchRLTrainingConfig + + with pytest.raises(ValueError, match="requires a deterministic logp backend"): + TorchRLTrainingConfig( + logp_backend="online", + require_batch_invariant_logp=True, + ) + + def test_deepspeed_loader_preserves_explicit_cuda_home(monkeypatch): import torch.utils.cpp_extension as cpp_extension diff --git a/tests/test_deterministic_logp.py b/tests/test_deterministic_logp.py new file mode 100644 index 0000000..7427408 --- /dev/null +++ b/tests/test_deterministic_logp.py @@ -0,0 +1,569 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from pathlib import Path + +import pytest +import torch + +from rl_engine.kernels.registry import kernel_registry + +CUDA_SHAPE_CASES = ( + pytest.param(1, 1, 1, id="single-token-vocab"), + pytest.param(1, 3, 2, id="tiny-vocab"), + pytest.param(2, 5, 31, id="below-warp"), + pytest.param(2, 7, 32, id="one-warp"), + pytest.param(3, 4, 33, id="above-warp"), + pytest.param(3, 3, 127, id="below-small-bucket"), + pytest.param(3, 3, 128, id="small-bucket-boundary"), + pytest.param(3, 3, 129, id="medium-bucket-start"), + pytest.param(4, 3, 255, id="below-block"), + pytest.param(4, 5, 256, id="one-block"), + pytest.param(4, 5, 257, id="above-block"), + pytest.param(2, 6, 1024, id="multi-block-stride"), + pytest.param(2, 3, 4095, id="below-medium-boundary"), + pytest.param(2, 3, 4096, id="medium-bucket-boundary"), + pytest.param(2, 3, 4097, id="large-bucket-start"), + pytest.param(2, 3, 4099, id="large-prime-vocab"), + pytest.param(1, 2, 8192, id="large-power-two-vocab"), +) + +CUDA_BUCKET_BOUNDARY_VOCABS = (128, 129, 4096, 4097) + + +def _tensor_bytes(tensor: torch.Tensor) -> bytes: + return tensor.detach().cpu().contiguous().numpy().tobytes() + + +def _assert_bitwise_equal(actual: torch.Tensor, expected: torch.Tensor) -> None: + assert actual.shape == expected.shape + assert actual.dtype == expected.dtype + assert _tensor_bytes(actual) == _tensor_bytes(expected) + + +def _deterministic_cuda_op(): + try: + op = kernel_registry.get_op("logp_deterministic") + except RuntimeError as exc: + pytest.skip(f"deterministic logp backend is unavailable: {exc}") + if op.__class__.__name__ != "DeterministicLogpCUDAOp": + pytest.skip("deterministic CUDA logp extension is not compiled") + return op + + +def _skip_if_cuda_dtype_unavailable(dtype: torch.dtype) -> None: + if dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported(): + pytest.skip("CUDA device does not support bfloat16") + + +def _dtype_tolerance(dtype: torch.dtype) -> float: + if dtype is torch.float16: + return 2e-3 + if dtype is torch.bfloat16: + return 2e-2 + if dtype is torch.float64: + return 1e-5 + return 1e-4 + + +def _reference_selected_logp(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + ref = torch.log_softmax(logits.float(), dim=-1) + return torch.gather(ref, dim=-1, index=token_ids.long().unsqueeze(-1)).squeeze(-1) + + +def _assert_close_to_reference( + actual: torch.Tensor, + logits: torch.Tensor, + token_ids: torch.Tensor, + *, + output_dtype: torch.dtype = torch.float32, +) -> None: + expected = _reference_selected_logp(logits, token_ids).to(output_dtype) + tolerance = _dtype_tolerance(output_dtype) + assert torch.allclose( + actual.float(), + expected.float(), + atol=tolerance, + rtol=tolerance, + ) + + +def _make_target(device: torch.device, dtype: torch.dtype, seq_len: int, vocab_size: int): + generator = torch.Generator(device=device).manual_seed(1234) + logits = torch.randn( + seq_len, + vocab_size, + device=device, + dtype=dtype, + generator=generator, + ) + token_ids = torch.randint( + 0, + vocab_size, + (seq_len,), + device=device, + dtype=torch.long, + generator=generator, + ) + return logits, token_ids + + +def _pack_target( + target_logits: torch.Tensor, + target_ids: torch.Tensor, + *, + batch_size: int, + position: int, + seed: int, +): + generator = torch.Generator(device=target_logits.device).manual_seed(seed) + seq_len, vocab_size = target_logits.shape + logits = torch.randn( + batch_size, + seq_len, + vocab_size, + device=target_logits.device, + dtype=target_logits.dtype, + generator=generator, + ) + token_ids = torch.randint( + 0, + vocab_size, + (batch_size, seq_len), + device=target_logits.device, + dtype=torch.long, + generator=generator, + ) + logits[position].copy_(target_logits) + token_ids[position].copy_(target_ids) + return logits, token_ids + + +def test_deterministic_logp_source_locks_reduction_contract(): + source = Path(__file__).resolve().parents[1] / "csrc" / "deterministic_logp_kernel.cu" + text = source.read_text(encoding="utf-8") + + assert "kDeterministicLogpSmallBlockSize = 128" in text + assert "kDeterministicLogpMediumBlockSize = 256" in text + assert "kDeterministicLogpLargeBlockSize = 512" in text + assert "kDeterministicLogpSmallVocabLimit = 128" in text + assert "kDeterministicLogpMediumVocabLimit = 4096" in text + assert "vocab_size <= kDeterministicLogpSmallVocabLimit" in text + assert "vocab_size <= kDeterministicLogpMediumVocabLimit" in text + assert "shared[lane]" not in text + assert "shared[shared_idx]" in text + assert "duplicate row ids" in text + assert "writes are idempotent" in text + assert "atomicAdd" not in text + assert "cub::BlockReduce" not in text + assert "select_deterministic" not in text + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +@pytest.mark.parametrize("batch_size,seq_len,vocab_size", CUDA_SHAPE_CASES) +@pytest.mark.parametrize( + "dtype", + (torch.float16, torch.bfloat16, torch.float32), + ids=("fp16", "bf16", "fp32"), +) +def test_deterministic_logp_shape_dtype_matrix_cuda( + dtype: torch.dtype, + batch_size: int, + seq_len: int, + vocab_size: int, +): + _skip_if_cuda_dtype_unavailable(dtype) + op = _deterministic_cuda_op() + device = torch.device("cuda") + generator = torch.Generator(device=device).manual_seed( + 1000 + batch_size * 101 + seq_len * 17 + vocab_size + ) + logits = torch.randn( + batch_size, + seq_len, + vocab_size, + device=device, + dtype=dtype, + generator=generator, + ) + token_ids = torch.randint( + 0, + vocab_size, + (batch_size, seq_len), + device=device, + dtype=torch.long, + generator=generator, + ) + + actual = op.apply_fp32(logits, token_ids) + + assert actual.shape == token_ids.shape + assert actual.dtype == torch.float32 + _assert_close_to_reference(actual, logits, token_ids) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_repeatability_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + generator = torch.Generator(device=device).manual_seed(2026) + logits = torch.randn(6, 1021, device=device, dtype=torch.float16, generator=generator) + token_ids = torch.randint(0, logits.size(-1), (6,), device=device, dtype=torch.long) + + baseline = op.apply_fp32(logits, token_ids) + for _ in range(20): + actual = op.apply_fp32(logits, token_ids) + torch.cuda.synchronize() + _assert_bitwise_equal(actual, baseline) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +@pytest.mark.parametrize( + "output_dtype", + (torch.float16, torch.bfloat16, torch.float32, torch.float64), + ids=("fp16", "bf16", "fp32", "fp64"), +) +def test_deterministic_logp_out_dtype_matrix_reuses_storage_cuda(output_dtype: torch.dtype): + _skip_if_cuda_dtype_unavailable(output_dtype) + op = _deterministic_cuda_op() + device = torch.device("cuda") + generator = torch.Generator(device=device).manual_seed(303) + logits = torch.randn(3, 4, 257, device=device, dtype=torch.float16, generator=generator) + token_ids = torch.randint(0, logits.size(-1), (3, 4), device=device, dtype=torch.long) + output = torch.full(token_ids.shape, 123.0, device=device, dtype=output_dtype) + + actual = op.out(logits, token_ids, output) + + assert actual.data_ptr() == output.data_ptr() + assert actual.dtype == output_dtype + _assert_close_to_reference(actual, logits, token_ids, output_dtype=output_dtype) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_non_contiguous_inputs_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + generator = torch.Generator(device=device).manual_seed(404) + batch_size, seq_len, vocab_size = 3, 5, 129 + base_logits = torch.randn( + batch_size, + seq_len, + vocab_size * 2, + device=device, + dtype=torch.float16, + generator=generator, + ) + logits = base_logits[..., ::2] + base_token_ids = torch.randint( + 0, + vocab_size, + (batch_size, seq_len * 2), + device=device, + dtype=torch.long, + generator=generator, + ) + token_ids = base_token_ids[:, ::2] + + assert not logits.is_contiguous() + assert not token_ids.is_contiguous() + + actual = op.apply_fp32(logits, token_ids) + + _assert_close_to_reference(actual, logits, token_ids) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_batch_size_invariance_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + target_logits, target_ids = _make_target( + device, + torch.float16, + seq_len=7, + vocab_size=4099, + ) + baseline = op.apply_fp32(target_logits.unsqueeze(0), target_ids.unsqueeze(0))[0] + + for seed, batch_size, position in ( + (11, 1, 0), + (12, 2, 1), + (13, 4, 2), + (14, 8, 5), + (15, 16, 11), + ): + logits, token_ids = _pack_target( + target_logits, + target_ids, + batch_size=batch_size, + position=position, + seed=seed, + ) + actual = op.apply_fp32(logits, token_ids)[position] + _assert_bitwise_equal(actual, baseline) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_batch_position_invariance_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + target_logits, target_ids = _make_target( + device, + torch.float16, + seq_len=5, + vocab_size=2053, + ) + baseline = op.apply_fp32(target_logits.unsqueeze(0), target_ids.unsqueeze(0))[0] + + for position in range(8): + logits, token_ids = _pack_target( + target_logits, + target_ids, + batch_size=8, + position=position, + seed=100 + position, + ) + actual = op.apply_fp32(logits, token_ids)[position] + _assert_bitwise_equal(actual, baseline) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +@pytest.mark.parametrize("vocab_size", CUDA_BUCKET_BOUNDARY_VOCABS) +def test_deterministic_logp_bucket_boundaries_are_batch_and_indexed_invariant_cuda( + vocab_size: int, +): + op = _deterministic_cuda_op() + device = torch.device("cuda") + target_logits, target_ids = _make_target( + device, + torch.float16, + seq_len=6, + vocab_size=vocab_size, + ) + baseline = op.apply_fp32(target_logits.unsqueeze(0), target_ids.unsqueeze(0))[0] + + for seed, batch_size, position in ( + (210 + vocab_size, 1, 0), + (220 + vocab_size, 2, 1), + (230 + vocab_size, 8, 3), + (240 + vocab_size, 16, 9), + ): + logits, token_ids = _pack_target( + target_logits, + target_ids, + batch_size=batch_size, + position=position, + seed=seed, + ) + dense = op.apply_fp32(logits, token_ids) + _assert_bitwise_equal(dense[position], baseline) + + row_start = position * target_ids.numel() + row_indices = torch.arange( + row_start, + row_start + target_ids.numel(), + device=device, + dtype=torch.long, + ) + indexed = op.indexed_fp32(logits, token_ids, row_indices) + indexed_flat = indexed.reshape(-1) + _assert_bitwise_equal(indexed_flat[row_indices], baseline) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_ignores_batch_noise_bitwise_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + target_logits, target_ids = _make_target( + device, + torch.float16, + seq_len=11, + vocab_size=769, + ) + baseline = op.apply_fp32(target_logits.unsqueeze(0), target_ids.unsqueeze(0))[0] + + for seed in range(20, 30): + logits, token_ids = _pack_target( + target_logits, + target_ids, + batch_size=32, + position=seed % 32, + seed=seed, + ) + logits.add_(torch.randn_like(logits) * 0.01) + token_ids.random_(0, logits.size(-1)) + logits[seed % 32].copy_(target_logits) + token_ids[seed % 32].copy_(target_ids) + + actual = op.apply_fp32(logits, token_ids)[seed % 32] + _assert_bitwise_equal(actual, baseline) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_indexed_matches_dense_bits_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + generator = torch.Generator(device=device).manual_seed(707) + logits = torch.randn(4, 5, 1031, device=device, dtype=torch.float16, generator=generator) + token_ids = torch.randint(0, logits.size(-1), (4, 5), device=device, dtype=torch.long) + dense = op.apply_fp32(logits, token_ids) + dense_flat = dense.reshape(-1) + target_row = 7 + target_baseline = None + + index_sets = ( + torch.tensor([target_row], device=device, dtype=torch.long), + torch.tensor([0, 3, target_row, 11, 19], device=device, dtype=torch.long), + torch.arange(dense_flat.numel(), device=device, dtype=torch.long), + ) + + for row_indices in index_sets: + indexed = op.indexed_fp32(logits, token_ids, row_indices) + indexed_flat = indexed.reshape(-1) + + _assert_bitwise_equal(indexed_flat[row_indices], dense_flat[row_indices]) + + active_mask = torch.zeros(dense_flat.numel(), device=device, dtype=torch.bool) + active_mask[row_indices] = True + assert torch.equal(indexed_flat[~active_mask], torch.zeros_like(indexed_flat[~active_mask])) + + current_target = indexed_flat[target_row : target_row + 1] + if target_baseline is None: + target_baseline = current_target.clone() + else: + _assert_bitwise_equal(current_target, target_baseline) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_indexed_out_preserves_inactive_rows_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + generator = torch.Generator(device=device).manual_seed(909) + logits = torch.randn(3, 4, 263, device=device, dtype=torch.float16, generator=generator) + token_ids = torch.randint(0, logits.size(-1), (3, 4), device=device, dtype=torch.long) + dense = op.apply_fp32(logits, token_ids).reshape(-1) + sentinel = torch.tensor(123.0, device=device, dtype=torch.float32) + + output = torch.full(token_ids.shape, sentinel.item(), device=device, dtype=torch.float32) + empty_indices = torch.empty(0, device=device, dtype=torch.long) + empty_result = op.indexed_out(logits, token_ids, empty_indices, output) + + assert empty_result.data_ptr() == output.data_ptr() + assert torch.equal(empty_result, torch.full_like(empty_result, sentinel.item())) + + row_indices = torch.tensor([7, 0, 11, 7, 4], device=device, dtype=torch.long) + output.fill_(sentinel.item()) + indexed = op.indexed_out(logits, token_ids, row_indices, output).reshape(-1) + active = torch.unique(row_indices) + inactive_mask = torch.ones_like(indexed, dtype=torch.bool) + inactive_mask[active] = False + + _assert_bitwise_equal(indexed[active], dense[active]) + assert torch.equal( + indexed[inactive_mask], + torch.full_like(indexed[inactive_mask], sentinel.item()), + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_indexed_fp32_empty_indices_zero_fills_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + logits = torch.randn(2, 3, 17, device=device, dtype=torch.float16) + token_ids = torch.randint(0, logits.size(-1), (2, 3), device=device, dtype=torch.long) + row_indices = torch.empty(0, device=device, dtype=torch.long) + + actual = op.indexed_fp32(logits, token_ids, row_indices) + + assert actual.dtype == torch.float32 + assert torch.equal(actual, torch.zeros_like(actual)) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_invalid_token_ids_zero_fill_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + vocab_size = 17 + logits = torch.randn(1, 5, vocab_size, device=device, dtype=torch.float16) + token_ids = torch.tensor([[-100, vocab_size, 0, vocab_size - 1, -1]], device=device) + + actual = op.apply_fp32(logits, token_ids) + valid = (token_ids >= 0) & (token_ids < vocab_size) + safe_token_ids = token_ids.clamp(0, vocab_size - 1) + expected = _reference_selected_logp(logits, safe_token_ids) + + assert torch.equal(actual[~valid], torch.zeros_like(actual[~valid])) + assert torch.allclose(actual[valid], expected[valid], atol=2e-3, rtol=2e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_extreme_logits_are_stable_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + vocab_size = 4099 + rows = 5 + logits = torch.empty(rows, vocab_size, device=device, dtype=torch.float32) + + logits[0].fill_(0.0) + logits[1].fill_(80.0) + logits[2].fill_(-80.0) + logits[3] = torch.linspace(-80.0, 80.0, vocab_size, device=device) + logits[4] = torch.linspace(80.0, -80.0, vocab_size, device=device) + token_ids = torch.tensor([0, vocab_size - 1, vocab_size // 2, vocab_size - 1, 0], device=device) + + actual = op.apply_fp32(logits, token_ids) + expected = _reference_selected_logp(logits, token_ids) + + assert torch.isfinite(actual).all() + assert torch.allclose(actual, expected, atol=1e-4, rtol=1e-4) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_rejects_bad_shapes_and_output_dtype_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + logits = torch.randn(2, 3, 17, device=device, dtype=torch.float16) + token_ids = torch.randint(0, logits.size(-1), (2, 3), device=device, dtype=torch.long) + + with pytest.raises(RuntimeError, match="token_ids length must match logits rows"): + op.apply_fp32(logits, token_ids[:, :2]) + + with pytest.raises(ValueError, match="output shape"): + op.out(logits, token_ids, torch.empty(2, 2, device=device, dtype=torch.float32)) + + with pytest.raises(RuntimeError, match="output dtype"): + op.out(logits, token_ids, torch.empty(token_ids.shape, device=device, dtype=torch.int32)) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_deterministic_logp_out_of_range_indices_do_not_overwrite_output_cuda(): + op = _deterministic_cuda_op() + device = torch.device("cuda") + logits = torch.randn(2, 3, 29, device=device, dtype=torch.float16) + token_ids = torch.randint(0, logits.size(-1), (2, 3), device=device, dtype=torch.long) + dense = op.apply_fp32(logits, token_ids).reshape(-1) + output = torch.full(token_ids.shape, -77.0, device=device, dtype=torch.float32) + row_indices = torch.tensor([-1, 2, 999], device=device, dtype=torch.long) + + actual = op.indexed_out(logits, token_ids, row_indices, output).reshape(-1) + + assert actual[2].item() == pytest.approx(dense[2].item()) + inactive = torch.ones_like(actual, dtype=torch.bool) + inactive[2] = False + assert torch.equal(actual[inactive], torch.full_like(actual[inactive], -77.0)) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +@pytest.mark.parametrize("dtype", (torch.float16, torch.float32)) +def test_deterministic_logp_matches_reference_tolerance_cuda(dtype: torch.dtype): + op = _deterministic_cuda_op() + device = torch.device("cuda") + generator = torch.Generator(device=device).manual_seed(808) + logits = torch.randn(3, 4, 257, device=device, dtype=dtype, generator=generator) + token_ids = torch.randint(0, logits.size(-1), (3, 4), device=device, dtype=torch.long) + + actual = op.apply_fp32(logits, token_ids) + ref = torch.log_softmax(logits.float(), dim=-1) + ref = torch.gather(ref, dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1) + + tolerance = 2e-3 if dtype is torch.float16 else 1e-4 + assert torch.allclose(actual, ref, atol=tolerance, rtol=tolerance) diff --git a/tests/test_envs.py b/tests/test_envs.py new file mode 100644 index 0000000..6d7c9eb --- /dev/null +++ b/tests/test_envs.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +import pytest + +import envs + + +@pytest.mark.parametrize("value", ("1", "true", "TRUE", "yes", "on")) +def test_env_flag_truthy_values(monkeypatch, value): + monkeypatch.setenv("RL_KERNEL_TEST_FLAG", value) + + assert envs.env_flag("RL_KERNEL_TEST_FLAG") + + +@pytest.mark.parametrize("value", ("0", "false", "FALSE", "no", "off")) +def test_env_flag_falsey_values(monkeypatch, value): + monkeypatch.setenv("RL_KERNEL_TEST_FLAG", value) + + assert not envs.env_flag("RL_KERNEL_TEST_FLAG", default=True) + + +def test_env_flag_uses_default_for_missing_value(monkeypatch): + monkeypatch.delenv("RL_KERNEL_TEST_FLAG", raising=False) + + assert envs.env_flag("RL_KERNEL_TEST_FLAG", default=True) + assert not envs.env_flag("RL_KERNEL_TEST_FLAG", default=False) + + +def test_env_flag_rejects_ambiguous_values(monkeypatch): + monkeypatch.setenv("RL_KERNEL_TEST_FLAG", "maybe") + + with pytest.raises(ValueError, match="RL_KERNEL_TEST_FLAG"): + envs.env_flag("RL_KERNEL_TEST_FLAG") diff --git a/tests/test_grpo_single_gpu_example.py b/tests/test_grpo_single_gpu_example.py index e13e658..8ba2ca0 100644 --- a/tests/test_grpo_single_gpu_example.py +++ b/tests/test_grpo_single_gpu_example.py @@ -6,6 +6,8 @@ import sys from pathlib import Path +from examples.grpo_single_gpu import is_fused_logp_backend + REPO_ROOT = Path(__file__).resolve().parents[1] @@ -79,3 +81,18 @@ def test_grpo_single_gpu_example_require_fused_rejects_cpu_fallback(): assert result.returncode != 0 assert "--require-fused-logp was set" in result.stderr + + +def test_grpo_single_gpu_fused_backend_detection_uses_capability_flag(): + class DummyBackend: + is_fused_logp = True + + class RenamedBackend: + is_fused_logp = True + + class PlainBackend: + pass + + assert is_fused_logp_backend(DummyBackend()) + assert is_fused_logp_backend(RenamedBackend()) + assert not is_fused_logp_backend(PlainBackend()) diff --git a/tests/test_vllm_rollout_sampler.py b/tests/test_vllm_rollout_sampler.py index ab3cc02..810fb66 100644 --- a/tests/test_vllm_rollout_sampler.py +++ b/tests/test_vllm_rollout_sampler.py @@ -287,6 +287,23 @@ def test_rollout_executor_defers_vllm_sampler_config_validation(): executor.generate_candidates(["prompt-a"]) +def test_rollout_executor_accepts_deterministic_logp_backend_alias(): + executor = RolloutExecutor({"backend": "not-vllm", "logp_backend": "deterministic"}) + + assert executor.logp_op_type == "logp_deterministic" + + +def test_rollout_executor_rejects_non_deterministic_required_logp_backend(): + with pytest.raises(ValueError, match="requires a deterministic logp backend"): + RolloutExecutor( + { + "backend": "not-vllm", + "logp_backend": "online", + "require_batch_invariant_logp": True, + } + ) + + def test_rollout_executor_defaults_to_cuda_vmm_manifest_bridge(): executor = RolloutExecutor()