feat(ws1): NativeEmbeddingOp pure-PyTorch ground-truth reference + numerical contract tests #169

Open
maxiaosong1124 wants to merge 2 commits into RL-Align:main from maxiaosong1124:feat/ws1-embedding-pytorch-op

Conversation

@maxiaosong1124 maxiaosong1124 commented Jun 22, 2026

Collaborator

Summary

Adds the pure-PyTorch ground-truth reference op for the token embedding — the
input layer of the WS1 batch-invariant forward chain — built on top of the numerical
contract defined in #108. Ships the op, its registry wiring, docs, and an 11-case test
suite that pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B
per-dtype path), plus a GPU-only smoke test at the real Qwen3-8B table dims.

Refs #108

Terminology

This PR uses the WS1 alignment vocabulary from #108:

  • Axis-A — batch invariance (reproducibility). A token's output row must not depend
    on how many tokens share the batch (batch size, slicing, padding). Asserted bitwise
    (torch.equal). This is what keeps train-time (large batch) and sample-time
    (small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
  • Axis-B — accuracy. The low-precision (bf16 / fp16) forward path. Embedding is a
    lossless row gather — no reduction, no fp32 accumulation — so the dtype output is
    bitwise equal to direct indexing at every dtype. No tolerance window is needed.
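
The two axes can be illustrated with a minimal CPU sketch. It assumes a toy table (32 x 8 rather than the real Qwen3-8B dims), an arbitrary pad id of 0, and plain `F.embedding` in place of the PR's op; none of these names come from the test suite.

```python
import torch
import torch.nn.functional as F

# Toy table and ids (assumed sizes; the real table is vocab=151936 x hidden=4096).
gen = torch.Generator().manual_seed(0)
weight = torch.randn(32, 8, generator=gen).to(torch.bfloat16)
token_ids = torch.randint(0, 32, (4, 6), generator=gen)

# Full-batch result, shape (4, 6, 8).
full = F.embedding(token_ids, weight)

# Axis-A (slice): running only the first two sequences must reproduce
# the same rows bit for bit.
sliced = F.embedding(token_ids[:2], weight)
axis_a_slice_ok = torch.equal(sliced, full[:2])

# Axis-A (padding): appending pad ids (0 here, an assumption) must not
# change the rows at the real token positions.
pad = torch.zeros(4, 3, dtype=torch.long)
padded = F.embedding(torch.cat([token_ids, pad], dim=1), weight)
axis_a_pad_ok = torch.equal(padded[:, :6], full)

# Axis-B: the gather does no arithmetic, so the bf16 output equals
# direct indexing exactly and needs no tolerance.
axis_b_ok = torch.equal(full, weight[token_ids])
```

All three flags are expected to be `True`; a kernel that fails the slice or padding check would break Axis-A even if its values are numerically close.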

Motivation / Context

#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. The first stage of the Qwen3-8B stack maps integer token
ids to their hidden-state rows:

hidden = embedding_table[token_ids]

This PR provides the deterministic fp32 reference path that downstream kernels
(Triton / CUDA / ROCm) will be validated against. For Qwen3-8B the table is the input
embedding [vocab=151936, hidden=4096] and is independent from the lm_head weight
(tie_word_embeddings=false) — the two weights are not shared.
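
As a small illustration of the lookup above, the sketch below uses an assumed toy table (vocab=16, hidden=4) whose values make each row easy to recognise. It includes the boundary ids 0 and vocab-1.

```python
import torch

# Toy dims (assumptions); row r of the table holds the values 4r .. 4r+3.
vocab, hidden = 16, 4
embedding_table = torch.arange(vocab * hidden, dtype=torch.float32).reshape(vocab, hidden)

# A (batch=2, seq=2) id tensor that touches both boundary ids.
token_ids = torch.tensor([[0, 3], [vocab - 1, 3]])

# Each id selects one row, so the output shape is token_ids.shape + (hidden,).
hidden_states = embedding_table[token_ids]
```

The same id (3) appears in both sequences and yields the same row in both places, which is the property Axis-A later asserts bitwise.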

Changes

  • rl_engine/kernels/ops/pytorch/linear/embedding.py → NativeEmbeddingOp
    • forward() — native-dtype gather, cast the gathered rows back to the weight dtype (Axis-B path)
    • forward_fp32() — native-dtype gather, upcast the result to fp32 (ground-truth / backward golden source)
    • Formula: out = weight[token_ids] (via F.embedding(token_ids.long(), weight))
    • Pure function — inputs never mutated in place; output dtype follows weight
  • rl_engine/kernels/registry.py — register PYTORCH_NATIVE_EMBEDDING in OpBackend
    and add embedding dispatch to the cuda / rocm / cpu priority maps
  • tests/test_embedding.py — 11 tests (details below)
  • docs/operators/embedding.md + nav / index wiring

How this satisfies the #108 contract

| #108 requirement | How it's met here |
| --- | --- |
| Deterministic reference path | forward_fp32() gathers in the weight's native dtype and upcasts only the gathered rows to fp32; tests use a fixed-seed torch.Generator so outputs are reproducible |
| Per-dtype tolerance policy (bitwise vs tight-tolerance) | Axis-A asserted bitwise (torch.equal); Axis-B dtype output is a lossless gather, asserted bitwise against direct indexing (a gather needs no tolerance window) |
| Batch-config sweep / validation helper | Batch-invariance checks compute on the full batch, then assert sliced/padded rows are bitwise identical to their full-batch counterparts |
| Realistic shapes covered | GPU-only smoke test at the real Qwen3-8B table dims (vocab=151936, hidden=4096), exercising boundary ids 0 and vocab-1; skips when CUDA is unavailable or free GPU memory is insufficient |

Test Environment

| Component | Details |
| --- | --- |
| OS | Ubuntu (kernel 5.15.0-124-generic) |
| Python | 3.12.3 |
| PyTorch | 2.8.0+cu128 |
| CUDA / cuDNN | 12.8 / 9.10.02 (driver 580.76.05) |
| GPU | NVIDIA vGPU-32GB |

Test Results

python -m pytest tests/test_embedding.py -v

The 11 tests cover:

  • correctness vs direct indexing across fp32 / bf16 / fp16, asserted bitwise
  • output shape (token_ids.shape + (hidden,))
  • non-int64 id tolerance (cast via .long())
  • Axis-A batch invariance — slice + padding variants, asserted bitwise
  • purity (neither token_ids nor weight mutated in place)
  • gradient flow to weight (fp32 autograd = backward golden source), including
    sparse-grad: rows never indexed stay exactly zero
  • registry dispatch resolves embedding → NativeEmbeddingOp
  • GPU-only real-shape smoke test (Qwen3-8B vocab=151936, hidden=4096, boundary ids)
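
The purity and gradient-flow items can be sketched as follows. This is an assumed toy setup (8 x 4 fp32 table, ids chosen by hand) using plain `F.embedding`, not the code in tests/test_embedding.py.

```python
import torch
import torch.nn.functional as F

# fp32 leaf weight so autograd produces the golden gradient.
weight = torch.randn(8, 4, requires_grad=True)
token_ids = torch.tensor([1, 1, 5])  # row 1 is indexed twice, row 5 once

# Snapshots for the purity check.
weight_before = weight.detach().clone()
ids_before = token_ids.clone()

out = F.embedding(token_ids, weight)
out.sum().backward()

# With the default sparse=False, weight.grad is a dense tensor:
# each indexed row accumulates one unit of gradient per lookup,
# and rows that were never indexed stay exactly zero.
grad = weight.grad
untouched = [r for r in range(8) if r not in (1, 5)]
untouched_nonzero = int(torch.count_nonzero(grad[untouched]))
```

Row 1 ends up with gradient 2.0 in every column, row 5 with 1.0, and `untouched_nonzero` is 0, while both snapshots still match their originals.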

Checklist

  • Pure-PyTorch reference, no custom extension required
  • Covered at the real Qwen3-8B table dims (vocab=151936, hidden=4096)
  • Axis-A bitwise batch invariance enforced
  • Axis-B lossless-gather dtype path tested (bitwise, no tolerance window)
  • Registered in OpBackend + cuda/rocm/cpu priority maps
  • All 11 tests pass locally

Summary by CodeRabbit

Release Notes

  • New Features

    • Added a Token Embedding operator for mapping integer token IDs to embedding vectors, with device support (CPU/CUDA/ROCm) using a PyTorch reference backend.
    • Supports dtype-aware forward paths with lossless, bit-exact gather semantics.
  • Documentation

    • Added documentation and navigation entries for the new embedding operator.
  • Tests

    • Added a comprehensive test suite validating correctness across float32/bfloat16/float16, batch invariance, non-mutation, and gradient behavior.

WS1 ground-truth token-embedding op for issue RL-Align#108 (Qwen3-8B input
embedding table, vocab=151936 x hidden=4096, tie_word_embeddings=false):
- NativeEmbeddingOp: out = weight[token_ids], a lossless row gather
  exposing the forward / forward_fp32 dual-path contract (fp32 ground
  truth + dtype-behavior path); pure function, no in-place mutation.
- register PYTORCH_NATIVE_EMBEDDING in OpBackend and the cuda/rocm/cpu
  priority maps.
- tests/test_embedding.py: bitwise correctness vs direct indexing, dtype
  paths, non-int64 id tolerance, Axis-A batch invariance (slice +
  padding), purity, sparse gradient flow to weight, registry dispatch,
  and a GPU-only real-shape smoke test (vocab=151936, boundary ids).
- docs/operators/embedding.md + nav/index wiring.
coderabbitai Bot commented Jun 22, 2026

Walkthrough

A new NativeEmbeddingOp class is added as a pure-PyTorch fp32-gather reference embedding kernel. It is registered in KernelRegistry under the "embedding" operator key for cuda, rocm, and cpu via a new OpBackend.PYTORCH_NATIVE_EMBEDDING enum entry. A 179-line test suite and a full operator documentation page are also added.

Changes

NativeEmbeddingOp

| Layer | File(s) | Summary |
| --- | --- | --- |
| NativeEmbeddingOp implementation and registry wiring | rl_engine/kernels/ops/pytorch/linear/__init__.py, rl_engine/kernels/ops/pytorch/linear/embedding.py, rl_engine/kernels/registry.py | Adds NativeEmbeddingOp with forward (casts output to weight.dtype) and forward_fp32 (keeps fp32 output) paths sharing a _embedding helper. Adds OpBackend.PYTORCH_NATIVE_EMBEDDING enum entry and extends KernelRegistry._priority_map to route "embedding" requests on cuda, rocm, and cpu to that backend. |
| Test suite | tests/test_embedding.py | Covers bitwise correctness vs direct indexing across fp32/bf16/fp16, output shape, non-int64 id casting, batch and padding invariance, input immutability, gradient flow with zero-grad on unused rows, registry dispatch, and a CUDA-only Qwen3-8B shape smoke test gated by free GPU memory. |
| Operator documentation and navigation | docs/.nav.yml, docs/operators/README.md, docs/operators/embedding.md | Adds the embedding.md operator reference page documenting the dual-path contract, tensor shapes/dtypes, backend dispatch semantics, accuracy and batch-invariance requirements, test coverage summary, related files, and known limitations. Registers the new page in nav and operator README. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

🐇 A table of tokens, so neatly arrayed,
Each id a row — in perfect cascade!
fp32 golden, bit-exact and true,
The registry knows just what to do.
Hop, gather, cast — the embedding's born,
With tests and docs to greet the morn! 🌟

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 16.67%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately describes the main contribution, a pure-PyTorch ground-truth embedding operator with numerical contract tests, which is the core of the changeset. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |

*,
output_dtype: torch.dtype,
) -> torch.Tensor:
out = F.embedding(token_ids.long(), weight.float())

Collaborator

I’m worried about this weight.float() path: it upcasts the entire embedding table before gathering, so real fp16/bf16 Qwen3-size weights will allocate an extra multi-GB fp32 copy for a tiny lookup. Since this is the only registered embedding backend today, this could make the default fallback OOM in normal GPU use.
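
For a sense of scale, the arithmetic behind this concern can be sketched as below. The table dims are the ones quoted in the PR description; the 4096-token batch is an assumed figure chosen only for illustration.

```python
# Table dims from the PR description (Qwen3-8B input embedding).
VOCAB, HIDDEN = 151936, 4096
BYTES_PER_ELEMENT = {"fp32": 4, "bf16": 2, "fp16": 2}


def table_gib(dtype):
    # Size of the full embedding table in GiB for the given dtype.
    return VOCAB * HIDDEN * BYTES_PER_ELEMENT[dtype] / 2**30


def gathered_gib(num_tokens, dtype):
    # Size of only the gathered rows for a batch of num_tokens ids.
    return num_tokens * HIDDEN * BYTES_PER_ELEMENT[dtype] / 2**30


# Extra allocation when the whole table is upcast before the gather.
full_fp32_copy = table_gib("fp32")       # about 2.32 GiB

# Extra allocation when only the gathered rows are upcast
# (4096 tokens is an assumed batch size).
gathered_fp32 = gathered_gib(4096, "fp32")  # 0.0625 GiB, i.e. 64 MiB
```

Under these assumptions the full-table upcast costs roughly 37 times more transient memory than upcasting the gathered rows, and that gap grows as the batch shrinks.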

Collaborator Author

@inaniloquentee Thanks for the advice!
The gather now runs in the weight's native dtype and only the gathered rows are upcast (F.embedding(token_ids.long(), weight).to(output_dtype)), so there's no longer a multi-GB fp32 copy of the full vocab table for a tiny lookup. Since a row gather is lossless (pure indexing, no arithmetic), this is bitwise-identical to the previous path — all 11 tests in tests/test_embedding.py still pass.
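
The bitwise-equivalence claim can be checked with a small CPU sketch. It assumes a toy bf16 table (64 x 16) and compares the two expressions quoted in this thread directly, outside the PR's op.

```python
import torch
import torch.nn.functional as F

# Toy bf16 table and ids (assumed sizes, fixed seed for reproducibility).
gen = torch.Generator().manual_seed(0)
weight = torch.randn(64, 16, generator=gen).to(torch.bfloat16)
token_ids = torch.randint(0, 64, (3, 5), generator=gen)

# Previous path: upcast the whole table to fp32, then gather.
old_path = F.embedding(token_ids.long(), weight.float())

# New path: gather in bf16, then upcast only the gathered rows.
new_path = F.embedding(token_ids.long(), weight).to(torch.float32)

# bf16 -> fp32 is an exact widening and the gather copies rows unchanged,
# so the order of the two steps does not affect the bits.
paths_match = torch.equal(old_path, new_path)
```

`paths_match` is expected to be `True` for any ids and any bf16 or fp16 table, since neither step rounds.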

Gathering with weight.float() upcast the entire vocab table to fp32 before
the lookup, allocating a multi-GB fp32 copy of the Qwen3-8B embedding table
just for a tiny row gather and risking OOM on the default fallback path.

A row gather is lossless (pure indexing, no arithmetic), so gather in the
weight's native dtype and upcast only the gathered rows -- bitwise-identical
to the previous path. All 11 tests in tests/test_embedding.py still pass.
3 participants