Skip to content

Commit 5889c40

Browse files
committed
Remove compoments_loader.py
Signed-off-by: Rafal Leszko <rafal.leszko@gmail.com>
1 parent 2a91193 commit 5889c40

File tree

2 files changed

+102
-109
lines changed

2 files changed

+102
-109
lines changed

pipelines/streamdiffusionv2/components_loader.py

Lines changed: 0 additions & 108 deletions
This file was deleted.

pipelines/streamdiffusionv2/pipeline.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
import os
3+
import time
24

35
import torch
46
from diffusers.modular_pipelines import PipelineState
@@ -7,8 +9,10 @@
79
from ..blending import PromptBlender, handle_transition_prepare
810
from ..interface import Pipeline, Requirements
911
from ..process import postprocess_chunk, preprocess_chunk
10-
from .components_loader import load_stream_component, ComponentProvider
1112
from .modular_blocks import StreamDiffusionV2Blocks
13+
from .vendor.causvid.models.wan.causal_stream_inference import (
14+
CausalStreamInferencePipeline,
15+
)
1216

1317
# https://github.com/daydreamlive/scope/blob/0cf1766186be3802bf97ce550c2c978439f22068/pipelines/streamdiffusionv2/vendor/causvid/models/wan/causal_model.py#L306
1418
MAX_ROPE_FREQ_TABLE_SEQ_LEN = 1024
@@ -23,6 +27,103 @@
2327
logger = logging.getLogger(__name__)
2428

2529

30+
class ComponentProvider:
31+
"""Simple wrapper to provide component access from ComponentsManager to blocks."""
32+
33+
def __init__(self, components_manager: ComponentsManager, component_name: str, collection: str = "streamdiffusionv2"):
34+
"""
35+
Initialize the component provider.
36+
37+
Args:
38+
components_manager: The ComponentsManager instance
39+
component_name: Name of the component to provide
40+
collection: Collection name for retrieving the component
41+
"""
42+
self.components_manager = components_manager
43+
self.component_name = component_name
44+
self.collection = collection
45+
# Cache the component to avoid repeated lookups
46+
self._component = None
47+
48+
@property
49+
def stream(self):
50+
"""Provide access to the stream component."""
51+
if self._component is None:
52+
self._component = self.components_manager.get_one(
53+
name=self.component_name, collection=self.collection
54+
)
55+
return self._component
56+
57+
58+
def load_stream_component(
59+
config,
60+
device,
61+
dtype,
62+
model_dir,
63+
components_manager: ComponentsManager,
64+
collection: str = "streamdiffusionv2",
65+
) -> ComponentProvider:
66+
"""
67+
Load the CausalStreamInferencePipeline and add it to ComponentsManager.
68+
69+
Args:
70+
config: Configuration dictionary for the pipeline
71+
device: Device to run the pipeline on
72+
dtype: Data type for the pipeline
73+
model_dir: Directory containing the model files
74+
components_manager: ComponentsManager instance to add component to
75+
collection: Collection name for organizing components
76+
77+
Returns:
78+
ComponentProvider: A provider that gives access to the stream component
79+
"""
80+
# Check if component already exists in ComponentsManager
81+
try:
82+
existing = components_manager.get_one(name="stream", collection=collection)
83+
# Component exists, create provider for it
84+
print(f"Reusing existing stream component from collection '{collection}'")
85+
return ComponentProvider(components_manager, "stream", collection)
86+
except Exception:
87+
# Component doesn't exist, create and add it
88+
pass
89+
90+
# Create and initialize the stream pipeline
91+
stream = CausalStreamInferencePipeline(config, device).to(
92+
device=device, dtype=dtype
93+
)
94+
95+
# Load the generator state dict
96+
start = time.time()
97+
model_path = os.path.join(model_dir, "StreamDiffusionV2/model.pt")
98+
if not os.path.exists(model_path):
99+
raise FileNotFoundError(
100+
f"Model file not found at {model_path}. "
101+
"Please ensure StreamDiffusionV2/model.pt exists in the model directory."
102+
)
103+
104+
state_dict_data = torch.load(model_path, map_location="cpu")
105+
106+
# Handle both dict with "generator" key and direct state dict
107+
if isinstance(state_dict_data, dict) and "generator" in state_dict_data:
108+
state_dict = state_dict_data["generator"]
109+
else:
110+
state_dict = state_dict_data
111+
112+
stream.generator.load_state_dict(state_dict, strict=True)
113+
print(f"Loaded diffusion state dict in {time.time() - start:.3f}s")
114+
115+
# Add component to ComponentsManager
116+
component_id = components_manager.add(
117+
"stream",
118+
stream,
119+
collection=collection,
120+
)
121+
print(f"Added stream component to ComponentsManager with ID: {component_id}")
122+
123+
# Create and return provider
124+
return ComponentProvider(components_manager, "stream", collection)
125+
126+
26127
class StreamDiffusionV2Pipeline(Pipeline):
27128
def __init__(
28129
self,

0 commit comments

Comments
 (0)