Add @torch.inference_mode() to pipeline __call__ methods #170
Open
garrick99 wants to merge 1 commit into Lightricks:main from
Conversation
…ricks#152) All 7 pipeline classes were missing the decorator on __call__, causing torch to retain autograd graphs when called from Python (not the CLI). This leads to OOM: the text encoder's ~37 GB of activations are not freed before the transformer loads. Only ti2vid_two_stages_hq.py already had the decorator. The main() functions in each file had inference_mode, but __call__ did not, so CLI usage worked while Python API usage OOMed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
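A minimal sketch of the failure mode described above. The `encode` function and tensor shapes are hypothetical stand-ins for a text-encoder forward pass, not code from the repo; the point is only that without `inference_mode` torch records an autograd graph, which keeps intermediate activations alive:

```python
import torch

def encode(x, weight):
    # Hypothetical stand-in for a text-encoder forward pass.
    return (x @ weight).relu()

w = torch.randn(8, 8, requires_grad=True)
x = torch.randn(2, 8)

# Without inference_mode, torch records the autograd graph, so every
# intermediate activation stays alive until the output is released.
y = encode(x, w)
assert y.requires_grad and y.grad_fn is not None

# Under inference_mode, no graph is recorded and activations can be
# freed as soon as the forward pass no longer needs them.
with torch.inference_mode():
    y = encode(x, w)
assert not y.requires_grad and y.grad_fn is None
```

At the ~37 GB scale quoted above, that retained graph is the difference between the encoder's memory being reclaimed before the transformer loads and an OOM.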
Summary
Fixes #152: pipeline __call__ methods are not wrapped in torch.inference_mode(), causing OOM when called from Python code (not the CLI).

Problem
All 7 pipeline classes were missing @torch.inference_mode() on their __call__ method. The main() functions had the decorator, so CLI usage worked fine. But when calling pipelines from Python (the API use case), torch retains autograd graphs: the text encoder's ~37 GB of activations aren't freed before the transformer loads, causing OOM.

Only ti2vid_two_stages_hq.py already had the decorator on __call__.

Fix
Added @torch.inference_mode() to __call__ in all 7 pipeline classes:

- a2vid_two_stage.py
- distilled.py
- ic_lora.py
- keyframe_interpolation.py
- retake.py
- ti2vid_one_stage.py
- ti2vid_two_stages.py

7 files, 7 lines added. No behavioral change for CLI users (already covered by main()'s decorator).

🤖 Generated with Claude Code
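The shape of the one-line change, sketched on a simplified pipeline class (the class name, signature, and body here are illustrative, not the actual repo code). torch.inference_mode() works both as a context manager and, as used in this PR, as a method decorator:

```python
import torch

class TI2VidPipeline:
    # Simplified stand-in for the real pipeline classes; only the
    # decorator line corresponds to the actual change.
    @torch.inference_mode()
    def __call__(self, prompt: str) -> torch.Tensor:
        # With the decorator, every tensor created inside __call__ is an
        # inference tensor: no autograd graph is recorded, so encoder
        # activations can be freed before later stages load.
        latents = torch.randn(1, 4)
        return latents * 2.0

pipe = TI2VidPipeline()
out = pipe("a cat")
assert not out.requires_grad and torch.is_inference(out)
```

Decorating __call__ rather than relying on main()'s decorator covers both entry points: CLI runs are unchanged (nesting inference_mode is harmless), and Python API callers now get the same no-grad behavior.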