feat(ws1): Add PyTorch matmul reference operator#168
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 WalkthroughWalkthroughAdds ChangesNativeMatmulOp Feature
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Summary
Adds the PyTorch reference GEMM/Matmul operator for Issue #108.
This implements
NativeMatmulOpas the fp32 ground-truth baseline for dense projectionmatmuls. The operator follows the frozen #108 interface contract:
forward_fp32(a, b)casts inputs to fp32 and callstorch.matmulonceop_class = "reduction"kernel_registry.get_op("matmul")Also adds operator documentation under
docs/operators/matmul.md.Implementation
rl_engine/kernels/ops/pytorch/linear/matmul.pyrl_engine/kernels/ops/pytorch/linear/__init__.pyPYTORCH_NATIVE_MATMULinrl_engine/kernels/registry.pytests/test_matmul.pydocs/operators/README.mdSummary by CodeRabbit
New Features
matmuloperator with a native PyTorch reference implementation.matmuldispatch on CPU, CUDA, and ROCm.Documentation
matmuloperator documentation page.matmul.Tests
matmultests covering correctness, dtype behavior (including FP32 casting), supported shapes, batch invariance, and registry integration.