feat(ws1): NativeEmbeddingOp pure-PyTorch ground-truth reference + numerical contract tests#169
Conversation
WS1 ground-truth token-embedding op for issue RL-Align#108 (Qwen3-8B input embedding table, vocab=151936 x hidden=4096, tie_word_embeddings=false): - NativeEmbeddingOp: out = weight[token_ids], a lossless row gather exposing the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path); pure function, no in-place mutation. - register PYTORCH_NATIVE_EMBEDDING in OpBackend and the cuda/rocm/cpu priority maps. - tests/test_embedding.py: bitwise correctness vs direct indexing, dtype paths, non-int64 id tolerance, Axis-A batch invariance (slice + padding), purity, sparse gradient flow to weight, registry dispatch, and a GPU-only real-shape smoke test (vocab=151936, boundary ids). - docs/operators/embedding.md + nav/index wiring.
📝 WalkthroughWalkthroughA new ChangesNativeEmbeddingOp
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
| *, | ||
| output_dtype: torch.dtype, | ||
| ) -> torch.Tensor: | ||
| out = F.embedding(token_ids.long(), weight.float()) |
There was a problem hiding this comment.
I’m worried about this weight.float() path: it upcasts the entire embedding table before gathering, so real fp16/bf16 Qwen3-size weights will allocate an extra multi-GB fp32 copy for a tiny lookup. Since this is the only registered embedding backend today, this could make the default fallback OOM in normal GPU use.
There was a problem hiding this comment.
@inaniloquentee Thanks for the advice!
The gather now runs in the weight's native dtype and only the gathered rows are upcast (F.embedding(token_ids.long(), weight).to(output_dtype)), so there's no longer a multi-GB fp32 copy of the full vocab table for a tiny lookup. Since a row gather is lossless (pure indexing, no arithmetic), this is bitwise-identical to the previous path — all 11 tests in tests/test_embedding.py still pass.
Gathering with weight.float() upcast the entire vocab table to fp32 before the lookup, allocating a multi-GB fp32 copy of the Qwen3-8B embedding table just for a tiny row gather and risking OOM on the default fallback path. A row gather is lossless (pure indexing, no arithmetic), so gather in the weight's native dtype and upcast only the gathered rows -- bitwise-identical to the previous path. All 11 tests in tests/test_embedding.py still pass.
Summary
Adds the pure-PyTorch ground-truth reference op for the token embedding — the
input layer of the WS1 batch-invariant forward chain — built on top of the numerical
contract defined in #108. Ships the op, its registry wiring, docs, and an 11-case test
suite that pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B
per-dtype path), plus a GPU-only smoke test at the real Qwen3-8B table dims.
Refs #108
Terminology
This PR uses the WS1 alignment vocabulary from #108:
on how many tokens share the batch (batch size, slicing, padding). Asserted bitwise
(
torch.equal). This is what keeps train-time (large batch) and sample-time(small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
lossless row gather — no reduction, no fp32 accumulation — so the dtype output is
bitwise equal to direct indexing at every dtype. No tolerance window is needed.
Motivation / Context
#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. The first stage of the Qwen3-8B stack maps integer token
ids to their hidden-state rows:
This PR provides the deterministic fp32 reference path that downstream kernels
(Triton / CUDA / ROCm) will be validated against. For Qwen3-8B the table is the input
embedding
[vocab=151936, hidden=4096]and is independent from the lm_head weight(
tie_word_embeddings=false) — the two weights are not shared.Changes
rl_engine/kernels/ops/pytorch/linear/embedding.py—NativeEmbeddingOpforward()— native-dtype gather, cast the gathered rows back to the weight dtype (Axis-B path)forward_fp32()— native-dtype gather, upcast the result to fp32 (ground-truth / backward golden source)weightrl_engine/kernels/registry.py— registerPYTORCH_NATIVE_EMBEDDINGinOpBackendand add
embeddingdispatch to the cuda / rocm / cpu priority mapstests/test_embedding.py— 11 tests (details below)docs/operators/embedding.md+ nav / index wiringHow this satisfies the #108 contract
forward_fp32()gathers in fp32; tests use fixed-seedtorch.Generatorso outputs are reproducibletorch.equal); Axis-B dtype output is a lossless gather, asserted bitwise against direct indexing (a gather needs no tolerance window)vocab=151936, hidden=4096), exercising boundary ids0andvocab-1; skips when CUDA / GPU memory is unavailableTest Environment
Test Results
The 11 tests cover:
token_ids.shape + (hidden,)).long())token_idsnorweightmutated in place)weight(fp32 autograd = backward golden source), includingsparse-grad: rows never indexed stay exactly zero
embedding→NativeEmbeddingOpvocab=151936, hidden=4096, boundary ids)Checklist
Summary by CodeRabbit
Release Notes
New Features
Documentation
Tests