Skip to content

Commit ba8e68f

Browse files
feat: add VAE download infrastructure with metadata registry
feat: add VAE download infrastructure with metadata registry Adds complete infrastructure for downloading and managing VAE models independently of pipeline downloads. This change has no effect on the current code. When the ability to select different VAE types is added, the system will check if the selected VAE is downloaded and prompt for download if missing. Backend: - Add /api/v1/vae/status endpoint to check VAE download status - Add /api/v1/vae/download endpoint to trigger VAE downloads - Add VAEMetadata dataclass and VAE_METADATA registry as single source of truth for VAE filenames and download sources - Add vae_file_exists() and get_vae_file_path() using metadata registry - Add download_vae() and download_downloadable_vaes() functions Frontend: - Add checkVaeStatus() and downloadVae() API functions - Extend DownloadDialog to show VAE-specific download prompts - Add VAE status checking before stream start in StreamPage - Add separate VAE download flow with polling for completion Prepares codebase for upcoming PRs adding additional VAE types. Signed-off-by: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com>
1 parent 77e02fd commit ba8e68f

File tree

7 files changed

+328
-6
lines changed

7 files changed

+328
-6
lines changed

frontend/src/components/DownloadDialog.tsx

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@ interface DownloadDialogProps {
1616
pipelineId: PipelineId;
1717
onClose: () => void;
1818
onDownload: () => void;
19+
vaeNeedsDownload?: {
20+
vaeType: string;
21+
modelName: string;
22+
} | null;
1923
}
2024

2125
export function DownloadDialog({
2226
open,
2327
pipelineId,
2428
onClose,
2529
onDownload,
30+
vaeNeedsDownload,
2631
}: DownloadDialogProps) {
2732
const pipelineInfo = PIPELINES[pipelineId];
2833
if (!pipelineInfo) return null;
@@ -31,13 +36,22 @@ export function DownloadDialog({
3136
<Dialog open={open} onOpenChange={isOpen => !isOpen && onClose()}>
3237
<DialogContent className="sm:max-w-md">
3338
<DialogHeader>
34-
<DialogTitle>Download Models</DialogTitle>
39+
<DialogTitle>
40+
{vaeNeedsDownload ? "Download VAE Model" : "Download Models"}
41+
</DialogTitle>
3542
<DialogDescription className="mt-3">
36-
This pipeline requires model weights to be downloaded.
43+
{vaeNeedsDownload ? (
44+
<>
45+
The selected VAE model ({vaeNeedsDownload.vaeType}) is missing
46+
and needs to be downloaded.
47+
</>
48+
) : (
49+
<>This pipeline requires model weights to be downloaded.</>
50+
)}
3751
</DialogDescription>
3852
</DialogHeader>
3953

40-
{pipelineInfo.estimatedVram && (
54+
{!vaeNeedsDownload && pipelineInfo.estimatedVram && (
4155
<p className="text-sm text-muted-foreground mb-3">
4256
<span className="font-semibold">
4357
Estimated GPU VRAM Requirement:

frontend/src/lib/api.ts

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,47 @@ export const getPipelineSchemas =
307307
const result = await response.json();
308308
return result;
309309
};
310+
311+
export const checkVaeStatus = async (
312+
vaeType: string,
313+
modelName: string = "Wan2.1-T2V-1.3B"
314+
): Promise<{ downloaded: boolean }> => {
315+
const response = await fetch(
316+
`/api/v1/vae/status?vae_type=${encodeURIComponent(vaeType)}&model_name=${encodeURIComponent(modelName)}`,
317+
{
318+
method: "GET",
319+
headers: { "Content-Type": "application/json" },
320+
}
321+
);
322+
323+
if (!response.ok) {
324+
const errorText = await response.text();
325+
throw new Error(
326+
`VAE status check failed: ${response.status} ${response.statusText}: ${errorText}`
327+
);
328+
}
329+
330+
const result = await response.json();
331+
return result;
332+
};
333+
334+
export const downloadVae = async (
335+
vaeType: string,
336+
modelName: string = "Wan2.1-T2V-1.3B"
337+
): Promise<{ message: string }> => {
338+
const response = await fetch("/api/v1/vae/download", {
339+
method: "POST",
340+
headers: { "Content-Type": "application/json" },
341+
body: JSON.stringify({ vae_type: vaeType, model_name: modelName }),
342+
});
343+
344+
if (!response.ok) {
345+
const errorText = await response.text();
346+
throw new Error(
347+
`VAE download failed: ${response.status} ${response.statusText}: ${errorText}`
348+
);
349+
}
350+
351+
const result = await response.json();
352+
return result;
353+
};

frontend/src/pages/StreamPage.tsx

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ import type {
2424
LoraMergeStrategy,
2525
} from "../types";
2626
import type { PromptItem, PromptTransition } from "../lib/api";
27-
import { checkModelStatus, downloadPipelineModels } from "../lib/api";
27+
import {
28+
checkModelStatus,
29+
checkVaeStatus,
30+
downloadPipelineModels,
31+
downloadVae,
32+
} from "../lib/api";
2833
import { sendLoRAScaleUpdates } from "../utils/loraHelpers";
2934

3035
// Delay before resetting video reinitialization flag (ms)
@@ -96,6 +101,10 @@ export function StreamPage() {
96101
const [pipelineNeedsModels, setPipelineNeedsModels] = useState<string | null>(
97102
null
98103
);
104+
const [vaeNeedsDownload, setVaeNeedsDownload] = useState<{
105+
vaeType: string;
106+
modelName: string;
107+
} | null>(null);
99108

100109
// Ref to access timeline functions
101110
const timelineRef = useRef<{
@@ -274,6 +283,15 @@ export function StreamPage() {
274283
};
275284

276285
const handleDownloadModels = async () => {
286+
// Check if we need to download VAE first
287+
if (vaeNeedsDownload) {
288+
await handleDownloadVae();
289+
} else if (pipelineNeedsModels) {
290+
await handleDownloadPipelineModels();
291+
}
292+
};
293+
294+
const handleDownloadPipelineModels = async () => {
277295
if (!pipelineNeedsModels) return;
278296

279297
setIsDownloading(true);
@@ -354,9 +372,74 @@ export function StreamPage() {
354372
}
355373
};
356374

375+
const handleDownloadVae = async () => {
376+
if (!vaeNeedsDownload) return;
377+
378+
setIsDownloading(true);
379+
setShowDownloadDialog(false);
380+
381+
try {
382+
await downloadVae(vaeNeedsDownload.vaeType, vaeNeedsDownload.modelName);
383+
384+
// Start polling to check when download is complete
385+
const checkDownloadComplete = async () => {
386+
try {
387+
const status = await checkVaeStatus(
388+
vaeNeedsDownload.vaeType,
389+
vaeNeedsDownload.modelName
390+
);
391+
if (status.downloaded) {
392+
setIsDownloading(false);
393+
setVaeNeedsDownload(null);
394+
395+
// After VAE download, check if pipeline models are also needed
396+
const pipelineIdToUse = pipelineNeedsModels || settings.pipelineId;
397+
const pipelineInfo = PIPELINES[pipelineIdToUse];
398+
if (pipelineInfo?.requiresModels) {
399+
try {
400+
const pipelineStatus = await checkModelStatus(pipelineIdToUse);
401+
if (!pipelineStatus.downloaded) {
402+
// Still need pipeline models, show dialog for that
403+
setPipelineNeedsModels(pipelineIdToUse);
404+
setShowDownloadDialog(true);
405+
return;
406+
}
407+
} catch (error) {
408+
console.error("Error checking model status:", error);
409+
}
410+
}
411+
412+
// All downloads complete, start the stream
413+
setTimeout(async () => {
414+
const started = await handleStartStream();
415+
if (started && timelinePlayPauseRef.current) {
416+
setTimeout(() => {
417+
timelinePlayPauseRef.current?.();
418+
}, 2000);
419+
}
420+
}, 100);
421+
} else {
422+
// Check again in 2 seconds
423+
setTimeout(checkDownloadComplete, 2000);
424+
}
425+
} catch (error) {
426+
console.error("Error checking VAE download status:", error);
427+
setIsDownloading(false);
428+
}
429+
};
430+
431+
// Start checking for completion
432+
setTimeout(checkDownloadComplete, 5000);
433+
} catch (error) {
434+
console.error("Error downloading VAE:", error);
435+
setIsDownloading(false);
436+
}
437+
};
438+
357439
const handleDialogClose = () => {
358440
setShowDownloadDialog(false);
359441
setPipelineNeedsModels(null);
442+
setVaeNeedsDownload(null);
360443

361444
// When user cancels, no stream or timeline has started yet, so nothing to clean up
362445
// Just close the dialog and return early without any state changes
@@ -569,6 +652,24 @@ export function StreamPage() {
569652
}
570653
}
571654

655+
// Check if VAE is needed but not downloaded
656+
// Default to "wan" VAE type (backend will handle VAE selection)
657+
// NOTE: support for other vae types will be added later. const vaeType = settings.vaeType ?? "wan";
658+
const vaeType = "wan";
659+
try {
660+
const vaeStatus = await checkVaeStatus(vaeType);
661+
if (!vaeStatus.downloaded) {
662+
// Show download dialog for VAE (use pipeline ID for dialog, but track VAE separately)
663+
setVaeNeedsDownload({ vaeType, modelName: "Wan2.1-T2V-1.3B" });
664+
setPipelineNeedsModels(pipelineIdToUse);
665+
setShowDownloadDialog(true);
666+
return false; // Stream did not start
667+
}
668+
} catch (error) {
669+
console.error("Error checking VAE status:", error);
670+
// Continue anyway if check fails
671+
}
672+
572673
// Always load pipeline with current parameters - backend will handle the rest
573674
console.log(`Loading ${pipelineIdToUse} pipeline...`);
574675

@@ -959,6 +1060,7 @@ export function StreamPage() {
9591060
pipelineId={pipelineNeedsModels as PipelineId}
9601061
onClose={handleDialogClose}
9611062
onDownload={handleDownloadModels}
1063+
vaeNeedsDownload={vaeNeedsDownload}
9621064
/>
9631065
)}
9641066
</div>

src/scope/core/pipelines/wan2_1/vae/__init__.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,39 @@
1616
vae = create_vae(model_dir="wan_models", vae_path="/path/to/custom_vae.pth")
1717
"""
1818

19+
from dataclasses import dataclass
20+
1921
from .wan import WanVAEWrapper
2022

23+
24+
@dataclass(frozen=True)
25+
class VAEMetadata:
26+
"""Metadata for a VAE type (filenames, download sources)."""
27+
28+
filename: str
29+
download_repo: str | None = None # None = bundled with main model repo
30+
download_file: str | None = None # None = no separate download needed
31+
32+
33+
# Single source of truth for VAE metadata
34+
VAE_METADATA: dict[str, VAEMetadata] = {
35+
"wan": VAEMetadata(
36+
filename="Wan2.1_VAE.pth",
37+
download_repo="Wan-AI/Wan2.1-T2V-1.3B",
38+
download_file="Wan2.1_VAE.pth",
39+
),
40+
"lightvae": VAEMetadata(
41+
filename="lightvaew2_1.pth",
42+
download_repo="lightx2v/Autoencoders",
43+
download_file="lightvaew2_1.pth",
44+
),
45+
"tae": VAEMetadata(
46+
filename="taew2_1.pth",
47+
download_repo="lightx2v/Autoencoders",
48+
download_file="taew2_1.pth",
49+
),
50+
}
51+
2152
# Registry mapping type names to VAE classes
2253
# UI dropdowns will use these keys
2354
VAE_REGISTRY: dict[str, type] = {
@@ -69,8 +100,10 @@ def list_vae_types() -> list[str]:
69100

70101
__all__ = [
71102
"WanVAEWrapper",
72-
"create_vae",
73-
"list_vae_types",
103+
"VAEMetadata",
104+
"VAE_METADATA",
74105
"VAE_REGISTRY",
75106
"DEFAULT_VAE_TYPE",
107+
"create_vae",
108+
"list_vae_types",
76109
]

src/scope/server/app.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,15 @@ class DownloadModelsRequest(BaseModel):
367367
pipeline_id: str
368368

369369

370+
class VaeStatusResponse(BaseModel):
371+
downloaded: bool
372+
373+
374+
class DownloadVaeRequest(BaseModel):
375+
vae_type: str
376+
model_name: str = "Wan2.1-T2V-1.3B"
377+
378+
370379
class LoRAFileInfo(BaseModel):
371380
"""Metadata for an available LoRA file on disk."""
372381

@@ -452,6 +461,46 @@ def download_in_background():
452461
raise HTTPException(status_code=500, detail=str(e)) from e
453462

454463

464+
@app.get("/api/v1/vae/status", response_model=VaeStatusResponse)
465+
async def get_vae_status(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B"):
466+
"""Check if a VAE file is downloaded."""
467+
try:
468+
from .models_config import vae_file_exists
469+
470+
downloaded = vae_file_exists(vae_type, model_name)
471+
return VaeStatusResponse(downloaded=downloaded)
472+
except Exception as e:
473+
logger.error(f"Error checking VAE status: {e}")
474+
raise HTTPException(status_code=500, detail=str(e)) from e
475+
476+
477+
@app.post("/api/v1/vae/download")
478+
async def download_vae(request: DownloadVaeRequest):
479+
"""Download a specific VAE file."""
480+
try:
481+
if not request.vae_type:
482+
raise HTTPException(status_code=400, detail="vae_type is required")
483+
484+
# Download in a background thread to avoid blocking
485+
import threading
486+
487+
from .download_models import download_vae as download_vae_func
488+
489+
def download_in_background():
490+
download_vae_func(request.vae_type, request.model_name)
491+
492+
thread = threading.Thread(target=download_in_background)
493+
thread.daemon = True
494+
thread.start()
495+
496+
return {
497+
"message": f"VAE download started for {request.vae_type} (model: {request.model_name})"
498+
}
499+
except Exception as e:
500+
logger.error(f"Error starting VAE download: {e}")
501+
raise HTTPException(status_code=500, detail=str(e)) from e
502+
503+
455504
@app.get("/api/v1/hardware/info", response_model=HardwareInfoResponse)
456505
async def get_hardware_info():
457506
"""Get hardware information including available VRAM."""

0 commit comments

Comments
 (0)