Skip to content

Fixing up model wrapping and tracking of metrics code, learning scheduler#707

Open
mtauraso wants to merge 2 commits intomainfrom
mtauraso/fixup-model-vs-wrapped-model
Open

Fixing up model wrapping and tracking of metrics code, learning scheduler#707
mtauraso wants to merge 2 commits intomainfrom
mtauraso/fixup-model-vs-wrapped-model

Conversation

@mtauraso
Copy link
Collaborator

This is a follow-up on #706 which was a targeted fix for the bug reported in dirac slack: https://uw-dirac.slack.com/archives/C08F5FLEY5A/p1771350453909469

This is the more complete/thorough fix.

To avoid putting data/calling functions on the wrong model, the create_* methods now have two local variables "model" and "wrapped_model" In the case where there is no wrapping done by idist.auto_model() these are the same. I've tried to update the relevant accesses, but I want both @drewoldag and @SamSandwich07 to sign off before I merge.

I've also removed local variables of scheduler and optimizer in favor of model.scheduler and model.optimizer which are always on the inner model.

This has also exposed a larger issue that our CI doesn't have real GPUs, and we would have caught the original bug, and perhaps several other issues where we memoize data onto the model if we had CI with even an old and cheap GPU.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a critical bug in model wrapping for distributed training in Hyrax. The issue stems from PyTorch Ignite's idist.auto_model() which wraps models for distributed execution, but the previous code was accessing optimizer, scheduler, and storing state on the wrapped model instead of the unwrapped model. This caused failures in GPU/distributed scenarios.

Changes:

  • Introduced explicit wrapped_model and model variables in all engine creation functions to distinguish between wrapped (for execution) and unwrapped (for state access) models
  • Changed optimizer and scheduler access from local variables to direct attribute access on the unwrapped model (model.optimizer, model.scheduler)
  • Ensured all model state modifications (metrics, learning rate history) are performed on the unwrapped model

@codecov
Copy link

codecov bot commented Feb 17, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 64.15%. Comparing base (4c93171) to head (5eacc3e).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #707      +/-   ##
==========================================
- Coverage   64.17%   64.15%   -0.02%     
==========================================
  Files          61       61              
  Lines        5892     5890       -2     
==========================================
- Hits         3781     3779       -2     
  Misses       2111     2111              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@github-actions
Copy link

Before [4c93171] After [f8cca71] Ratio Benchmark (Parameter)
failed failed n/a vector_db_benchmarks.VectorDBInsertBenchmarks.time_load_vector_db(16384, 'qdrant')
6.71±0.01s 7.13±0.09s 1.06 vector_db_benchmarks.VectorDBInsertBenchmarks.time_load_vector_db(256, 'qdrant')
9.31±0.04ms 9.64±0.09ms 1.04 vector_db_benchmarks.VectorDBSearchBenchmarks.time_search_by_vector_many_shards(64, 'chromadb')
12.3±0.01s 12.5±0.1s 1.02 vector_db_benchmarks.VectorDBInsertBenchmarks.time_load_vector_db(2048, 'qdrant')
1.94±0.01s 1.95±0.01s 1.01 benchmarks.time_database_connection_help
200±0.9ms 202±1ms 1.01 benchmarks.time_import
38.1±0.1ms 38.5±0.3ms 1.01 benchmarks.time_nb_obj_construct
1.92±0.01s 1.93±0.01s 1.01 benchmarks.time_prepare_help
1.93±0.01s 1.95±0.01s 1.01 benchmarks.time_train_help
3.91G 3.94G 1.01 vector_db_benchmarks.VectorDBInsertBenchmarks.peakmem_load_vector_db(16384, 'qdrant')

Click here to view all benchmarks.

Copy link
Collaborator

@drewoldag drewoldag left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments