
Minimal code to reproduce bi-level, alternating, and joint optimization of trees and sequences#1

Open
ramithuh wants to merge 9 commits into main from minimal

Conversation


ramithuh commented Mar 5, 2026

Apart from requirement updates, the most notable change is that ancestor sequences are no longer represented as a series of 1D dictionaries, but as a single matrix.

We used to loop over ancestors manually.
An unused argument (metadata) was removed from compute_loss_optimized.
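The layout change can be sketched as follows. Dimension sizes and the variable names `old` and `new` are illustrative; the `"s"` key and the (init_count, n_ancestors, seq_length, n_letters) shape follow the commit messages in this PR:

```python
import jax
import jax.numpy as jnp

# Illustrative sizes, not the repo's defaults.
init_count, n_ancestors, seq_length, n_letters = 2, 3, 5, 4
keys = jax.random.split(jax.random.PRNGKey(0), n_ancestors)

# Old layout: one dict entry per ancestor, looped over manually.
old = {str(i): jax.random.normal(k, (init_count, seq_length, n_letters))
       for i, k in enumerate(keys)}

# New layout: one stacked tensor under a single key, so a single
# vectorized op can touch every ancestor at once.
new = {"s": jnp.stack([old[str(i)] for i in range(n_ancestors)], axis=1)}

assert new["s"].shape == (init_count, n_ancestors, seq_length, n_letters)
```

Stacking along axis 1 keeps each ancestor recoverable as `new["s"][:, i]`, which is what makes the later vectorized softmax equivalent to the old per-ancestor loop.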

ramithuh and others added 9 commits March 4, 2026 14:31
Consolidates ~2000 lines across 3 duplicated training scripts + modules
into ~1000 lines with shared infrastructure in common/.

Three optimization modes:
- search_alt.py: alternating tree/seq optimization
- search_bilevel.py: bilevel optimization with implicit differentiation
- search_joint.py: joint optimization (single optimizer, both param sets)

Shared code in common/:
- setup.py: arg parsing, metadata, device config, GT generation, param init
- search_loop.py: training loop, visualization, cost tracking
- tree_func.py, sankoff.py, gt_tree_gen.py, vis_utils.py: cleaned modules

Changes from original:
- No wandb dependency (console + saved figures only)
- Auto GPU detection
- Prints both surrogate and hard cost side-by-side
- Compatible with JAX 0.9+ / jaxopt 0.8.5
- Removed WIP code (diff sankoff), dead code, Colab paths

Verified: reproduces paper results (e.g. 16 leaves, seed 43: 1393 = Sankoff optimal)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Old training scripts and modules moved to legacy/ for reference.
New consolidated entry points (search_alt, search_bilevel, search_joint)
and shared common/ module are now at root level.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Replace plotly heatmaps with matplotlib (no Chrome/kaleido needed)
- Remove netgraph and animate_tree (unused)
- Add loss curve plot (surrogate vs hard cost, total loss vs tree constraint)
- Save history.json for cross-method comparison plots
- Fix seq heatmap perf: 2.9s -> 117ms per save

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…e code

- Delete compute_loss (unused legacy function with fix_seqs/fix_tree flags)
- Remove metadata parameter from `compute_loss_optimized`, `compute_detailed_loss_optimized`, and `enforce_graph`
- Remove dead gumbel_noise * 0.0 in update_tree
- Update all callers (`search_alt`, `search_bilevel`, `search_joint`, `search_loop`)
Replace seq_params dict {"0": array, "1": array, ...} with {"s": array}
of shape (init_count, n_ancestors, seq_length, n_letters). This eliminates
the sequential loop in update_seq (single softmax + slice set), removes
generate_vmap_keys helper, and simplifies get_one_tree_and_seq.

~25% speedup at 32 leaves, scaling with tree size.

Note: naive single-call init changes kaiming_normal variance scaling
due to different fan-in/fan-out — fixed in next commit.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
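A minimal sketch of why the single softmax is equivalent to the old per-ancestor loop (this only shows the softmax step; the real update_seq presumably also includes the optimizer update and slice set mentioned above):

```python
import jax
import jax.numpy as jnp

# Illustrative sizes.
init_count, n_ancestors, seq_length, n_letters = 2, 3, 5, 4
s = jax.random.normal(jax.random.PRNGKey(0),
                      (init_count, n_ancestors, seq_length, n_letters))

# Old approach: softmax each (init_count, seq_length, n_letters)
# ancestor slice inside a Python loop, then restack.
looped = jnp.stack([jax.nn.softmax(s[:, i], axis=-1)
                    for i in range(n_ancestors)], axis=1)

# New approach: one softmax over the whole stacked tensor. axis=-1
# normalizes over letters only, so each ancestor slice is unchanged
# relative to the looped version.
vectorized = jax.nn.softmax(s, axis=-1)

assert jnp.allclose(looped, vectorized, atol=1e-6)
```

Because the loop disappears from the traced computation, XLA sees one fused op instead of n_ancestors unrolled ones, which is consistent with the reported speedup growing with tree size.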
Init each ancestor slice separately with its original per-ancestor PRNG
key and shape (init_count, seq_length, n_letters), then jnp.stack. This
preserves the same fan-in/fan-out that kaiming_normal used before
stacking, avoiding ~2.5x smaller init variance that degraded convergence.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
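The fan-in issue can be illustrated with `jax.nn.initializers.kaiming_normal`, whose variance scales inversely with the fan-in it computes from the shape it is handed. Dimension sizes here are illustrative, and the exact variance ratio depends on n_ancestors:

```python
import jax
import jax.numpy as jnp

# Illustrative sizes.
init_count, n_ancestors, seq_length, n_letters = 2, 4, 50, 20
init = jax.nn.initializers.kaiming_normal()  # variance ~ 2 / fan_in

# Fixed approach: init each ancestor slice with the original 3D shape
# and its own PRNG key, then stack along axis 1, so the fan-in matches
# what kaiming_normal saw before the layout change.
keys = jax.random.split(jax.random.PRNGKey(0), n_ancestors)
stacked = jnp.stack(
    [init(k, (init_count, seq_length, n_letters)) for k in keys], axis=1)

# Naive approach: one call on the 4D shape. The extra n_ancestors dim
# enters the fan-in computation, shrinking the init variance.
naive = init(jax.random.PRNGKey(1),
             (init_count, n_ancestors, seq_length, n_letters))

assert stacked.shape == naive.shape
assert float(stacked.var()) > 2 * float(naive.var())
```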