Short sequence prefix-invariant evo2 implementation#1580
Conversation
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
|
Important Review skippedAuto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Plus Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
farhadrgh
left a comment
There was a problem hiding this comment.
Two bugs, two questions, and the last section flags that this PR regresses subq-ops inference support that already landed in #1565.
Bugs
1. hyena_utils.fftconv_func fix is incomplete, bidirectional path still broken.
The fix lands only inside the else: # causal branch:
if use_subquadratic_ops:
y = fft_causal_conv1d(u, k.squeeze(0))
else:
fft_size = max(fft_size, 2 * k.shape[-1]) # <-- only here
k_f = torch.fft.rfft(k, n=fft_size) / fft_sizeThe if bidirectional: branch immediately above still does torch.fft.rfft(k, n=fft_size) with the original fft_size = 2 * seqlen. Same truncation bug if anyone runs the bidirectional path with seqlen < K. Suggest hoisting the max(...) line to right after fft_size = 2 * seqlen so both branches benefit.
2. The short-filter causal_conv1d subq path was reverted, but the xfail only covers the fused B2B path.
Two separate code paths got removed in this PR, but the xfail (test_b2b_causal_conv1d_module_matches_sequential_reference) only documents one:
engine.parallel_firlost itsif use_subquadratic_ops: _subq_causal_conv1d(...)arm in the< 128branch.ParallelCausalDepthwiseConv1d.forwardnow always usescausal_conv1d_fninstead of dispatching to subq whenuse_subquadratic_ops=True.
Neither of those is the fused B2B kernel, they're plain depthwise short-filter convolutions. Is the issue actually with subq's causal_conv1d under causal-conv1d 1.6+, or did these get caught in the same revert? If it's the latter, worth keeping them — they're the easier speedup with no fusion semantics to verify.
Questions
3. @torch.compile removed from ImplicitModalFilter.filter does the comment refer to a specific reproducer? A pointer in the comment would help future readers, and if the bad-interaction scope is narrow we may be able to keep @torch.compile with dynamic=False or wrap the offending call site in torch.compiler.disable instead of dropping it altogether.
4. hyena_block.py variable-arity get_cpu_offload_context call, clean fix for the 6-vs-7-arg drift, but len(inspect.signature(...).parameters) is a brittle proxy (it counts a *args parameter as 1, which would silently break the slice). Worth a # tied to MCore <= 0.x note so future readers know to revisit if MCore changes the signature again.
Regression of #1565 (already on main)
This PR removes the two inference subq-ops code paths that landed in #1565 (merged 2026-04-30):
engine.parallel_firshort branch: theif use_subquadratic_ops: _subq_causal_conv1d(...)arm from #1565 is removed (item 2 above).HyenaMixer.forwardprefill: #1565 added_populate_b2b_inference_stateand gated the fused b2b kernel onuse_subquadratic_ops. This PR forces the gate off viaself.use_fused_b2b_causal_conv1d = False(hardcoded), so the fused path can never fire even when the user passes--use-subquadratic-ops. This also disables the original training andpredict_evo2b2b path that predates #1565.
Net effect for infer_evo2 --use-subquadratic-ops after this PR lands:
- The flag still routes long-filter FFT convs through subq-ops (
_subq_fft_causal_conv1d), so the existingtest_subquadratic_ops_matches_baselinecorrectness test will still pass. - But the short-filter and fused-B2B prefill paths are gone, so the measured ~15% prefill speedup at 8K prompt on the 1B model (single A6000) goes back to zero. Users get the CLI flag without the performance it was added for.
I get why this is happening, the xfail in test_hyena_utils.py shows the fused B2B kernel doesn't match the reference under causal-conv1d 1.6+. That's a real kernel-side bug. But two things:
(a) The fix for the fused-B2B mismatch shouldn't take out the short-filter causal_conv1d path too. They're independent (see item 2 above). If the subq short-filter kernel is also broken under 1.6+, a passing/failing test would clarify; if it isn't broken, please keep that path.
(b) Disabling the fused B2B path is reasonable as a temporary measure, but hardcoding the flag to False makes the regression permanent until someone re-edits the file. Please make it a real config attribute so it can be flipped back on once subquadratic-ops ships the 1.6+ fix, without another PR. Suggested:
self.use_fused_b2b_causal_conv1d = getattr(
transformer_config, "use_fused_b2b_causal_conv1d", False
)That way #1565's runtime behavior is recoverable via config, and we don't lose the speedup permanently. (And anyone hitting a predict_evo2 perf regression after this lands can re-enable it for the training/predict path independently.)
…nd fail loudly if the CUDA_ERROR_UNSUPPORTED_PTX_VERSION error comes up Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St John <jstjohn@nvidia.com>
…to jstjohn/prefix_invariance_evo2
Signed-off-by: John St John <jstjohn@nvidia.com>
Signed-off-by: John St John <jstjohn@nvidia.com>
|
/ok to test 121c57e |
@jstjohn, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/ |
|
/ok to test |
@jstjohn, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/1/ |
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In @.devcontainer/postCreateCommand.sh:
- Around line 4-7: The current install silently ignores failures because of the
"|| true" after the curl/install pipeline; update the block in
.devcontainer/postCreateCommand.sh that runs "curl -fsSL
https://chatgpt.com/codex/install.sh | sh || true" to instead capture the
installer exit status, remove the "|| true", and if the install fails (non-zero
exit) or "command -v codex" still does not find the binary, emit a clear warning
to stderr (e.g., echo to >&2) describing the failure and that codex is not
available; then re-check "command -v codex" after the install attempt and log
the warning if missing so the failure is visible.
In `@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py`:
- Around line 1287-1292: The call to register_allowed_target_prefix("bionemo.")
is only in main(), but predict() calls instantiate(run_config["model"]) and may
run before the allowlist is registered; add a small helper (e.g.,
ensure_bionemo_allowed()) that wraps the try/except import of
megatron.bridge.utils.instantiate_utils and calls
register_allowed_target_prefix("bionemo.") if available, then invoke that helper
at the very start of predict() (in addition to leaving the existing call in
main()) so imports of predict() or direct predict() calls always register the
prefix before instantiate() is used.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 3644a2b5-83a5-4215-bebe-b26858ce3cd1
📒 Files selected for processing (26)
.devcontainer/Dockerfile.devcontainer/devcontainer.json.devcontainer/initializeCommand.sh.devcontainer/postCreateCommand.sh.devcontainer/start.shbionemo-recipes/recipes/evo2_megatron/.ci_build.shbionemo-recipes/recipes/evo2_megatron/build_requirements.txtbionemo-recipes/recipes/evo2_megatron/pyproject.tomlbionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/fft_utils.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_layer.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.pybionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.pybionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.pybionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.pybionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.pybionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.pybionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.pybionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.pybionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py
💤 Files with no reviewable changes (1)
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py
|
/ok to test 9d2f598 |
Description
Changes:
Summary by CodeRabbit
Chores
Bug Fixes
Tests