Minimal code to reproduce bi-level, alternating, and joint optimization of trees and sequences #1
Consolidates ~2000 lines across 3 duplicated training scripts + modules into ~1000 lines with shared infrastructure in common/.

Three optimization modes:
- search_alt.py: alternating tree/seq optimization
- search_bilevel.py: bilevel optimization with implicit differentiation
- search_joint.py: joint optimization (single optimizer, both param sets)

Shared code in common/:
- setup.py: arg parsing, metadata, device config, GT generation, param init
- search_loop.py: training loop, visualization, cost tracking
- tree_func.py, sankoff.py, gt_tree_gen.py, vis_utils.py: cleaned modules

Changes from original:
- No wandb dependency (console + saved figures only)
- Auto GPU detection
- Prints both surrogate and hard cost side-by-side
- Compatible with JAX 0.9+ / jaxopt 0.8.5
- Removed WIP code (diff sankoff), dead code, Colab paths

Verified: reproduces paper results (e.g. 16 leaves, seed 43: 1393 = Sankoff optimal)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
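As a rough illustration of the joint mode above (a single optimizer stepping both parameter sets at once), here is a minimal, hypothetical JAX sketch. The loss function, shapes, and plain gradient-descent step are placeholders, not the repo's actual search_joint.py code.

```python
import jax
import jax.numpy as jnp

# Placeholder surrogate loss coupling both parameter sets; the real
# project optimizes a tree-reconstruction cost, not this toy quadratic.
def loss_fn(params):
    return jnp.sum(params["tree"] ** 2) + jnp.sum(jax.nn.softmax(params["seq"]) ** 2)

# Both parameter sets live in one pytree, so a single gradient step
# updates tree and sequence parameters together (the "joint" mode).
params = {
    "tree": jnp.ones((3, 3)),
    "seq": jax.random.normal(jax.random.PRNGKey(0), (3, 5)),
}
initial_loss = loss_fn(params)

for _ in range(100):
    grads = jax.grad(loss_fn)(params)
    # Plain SGD step over the whole pytree at once.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

By contrast, the alternating mode would freeze one of the two entries per phase and step only the other.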
Old training scripts and modules moved to legacy/ for reference. New consolidated entry points (search_alt, search_bilevel, search_joint) and shared common/ module are now at root level.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Replace plotly heatmaps with matplotlib (no Chrome/kaleido needed)
- Remove netgraph and animate_tree (unused)
- Add loss curve plot (surrogate vs hard cost, total loss vs tree constraint)
- Save history.json for cross-method comparison plots
- Fix seq heatmap perf: 2.9s -> 117ms per save

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…e code

- Delete compute_loss (unused legacy function with fix_seqs/fix_tree flags)
- Remove metadata parameter from `compute_loss_optimized`, `compute_detailed_loss_optimized`, and `enforce_graph`
- Remove dead gumbel_noise * 0.0 in update_tree
- Update all callers (`search_alt`, `search_bilevel`, `search_joint`, `search_loop`)
Replace seq_params dict {"0": array, "1": array, ...} with {"s": array}
of shape (init_count, n_ancestors, seq_length, n_letters). This eliminates
the sequential loop in update_seq (single softmax + slice set), removes
generate_vmap_keys helper, and simplifies get_one_tree_and_seq.
~25% speedup at 32 leaves, scaling with tree size.
Note: naive single-call init changes kaiming_normal variance scaling
due to different fan-in/fan-out — fixed in next commit.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
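A hypothetical before/after sketch of the layout change described above (shapes are illustrative, and `update_seq`'s real body lives in the repo): the dict of per-ancestor arrays becomes one stacked 4D array, so a single softmax replaces the per-ancestor loop.

```python
import jax
import jax.numpy as jnp

init_count, n_ancestors, seq_length, n_letters = 4, 3, 10, 5
key = jax.random.PRNGKey(0)

# Old layout: one (init_count, seq_length, n_letters) array per ancestor,
# keyed by the ancestor index as a string.
keys = jax.random.split(key, n_ancestors)
old_params = {str(i): jax.random.normal(k, (init_count, seq_length, n_letters))
              for i, k in enumerate(keys)}

# Old update: a sequential Python loop over ancestors.
old_probs = {k: jax.nn.softmax(v, axis=-1) for k, v in old_params.items()}

# New layout: a single 4D array under one key.
new_params = {"s": jnp.stack([old_params[str(i)] for i in range(n_ancestors)], axis=1)}

# New update: one softmax over the whole stacked array, no loop.
new_probs = jax.nn.softmax(new_params["s"], axis=-1)
```

Since softmax acts on the last axis only, slicing out any ancestor from `new_probs` recovers exactly the per-ancestor result of the old loop.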
Init each ancestor slice separately with its original per-ancestor PRNG key and shape (init_count, seq_length, n_letters), then jnp.stack. This preserves the same fan-in/fan-out that kaiming_normal used before stacking, avoiding ~2.5x smaller init variance that degraded convergence.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
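A sketch of the fix, under the assumption that initialization uses jax.nn.initializers.kaiming_normal (which derives fan-in from the array shape): initializing each ancestor slice at its original 3D shape and then stacking preserves the variance that a naive single 4D call would shrink.

```python
import jax
import jax.numpy as jnp

init_count, n_ancestors, seq_length, n_letters = 4, 3, 10, 5
init = jax.nn.initializers.kaiming_normal()
key = jax.random.PRNGKey(42)

# Fixed version: each ancestor slice gets its own PRNG key and the
# 3D shape, so kaiming_normal sees the original (smaller) fan-in.
keys = jax.random.split(key, n_ancestors)
slices = [init(k, (init_count, seq_length, n_letters)) for k in keys]
params = jnp.stack(slices, axis=1)  # (init_count, n_ancestors, seq_length, n_letters)

# Naive version: a single call over the 4D shape folds n_ancestors into
# the fan-in, shrinking the init variance by roughly that factor.
naive = init(key, (init_count, n_ancestors, seq_length, n_letters))
```

With these toy shapes the stacked init's standard deviation is about sqrt(n_ancestors) times larger than the naive one, matching the variance-shrinkage effect the commit message describes.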
Apart from requirement updates, the most notable change is that ancestor sequences are no longer represented as a dictionary of per-ancestor arrays, but as a single stacked array.
We used to loop over ancestors manually.
There was an unused argument (`metadata`) that was removed from `compute_loss_optimized`.