-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_ae.py
More file actions
67 lines (52 loc) · 1.84 KB
/
train_ae.py
File metadata and controls
67 lines (52 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import annotations
import os
import shutil
from pathlib import Path
from backend.app.training.runner import train_run
# Hyperparameters for the autoencoder run. train() forwards each value to the
# training runner under the config key named in its comment.
EPOCHS: int = 10  # config["epochs"]
BATCH_SIZE: int = 64  # config["batch_size"]
LEARNING_RATE: float = 1e-3  # config["learning_rate"]
EMBEDDING_DIM: int = 64  # config["embedding_dim"]
CLASS_WEIGHT: float = 1.0  # config["class_weight"]
def train() -> None:
    """Run an autoencoder training job and publish its artifacts under legacy names.

    Assembles the run configuration from the module-level hyperparameters,
    passes it to ``train_run`` with ``outputs`` as the run directory, and then
    copies the final checkpoint and the latest reconstruction image to their
    legacy file names inside that same directory.

    Raises:
        RuntimeError: If the result returned by ``train_run`` does not carry
            ``status == "completed"``.
    """
    out_root = Path("outputs")
    os.makedirs(out_root, exist_ok=True)

    def _never_cancel() -> bool:
        # This script offers no way to cancel, so the runner is always told "no".
        return False

    settings = {
        "model_type": "ae",
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "capacity": 16,
        "embedding_dim": EMBEDDING_DIM,
        "class_weight": CLASS_WEIGHT,
        "checkpoint_interval_epochs": 1,
        "save_final_checkpoint": True,
        "final_checkpoint_name": "final.pth",
        "latest_recon_name": "recon_latest.png",
        "dataset_name": "mnist",
    }

    print("Starting AE training...")
    outcome = train_run(
        config=settings,
        run_dir=out_root,
        emit=_print_event,
        is_cancel_requested=_never_cancel,
    )

    status = outcome.get("status")
    if status != "completed":
        raise RuntimeError(f"AE training did not complete: {outcome}")

    # Preserve legacy output filenames: (file read from the run, legacy copy).
    legacy_copies = (
        (out_root / "checkpoints" / "final.pth", out_root / "autoencoder.pth"),
        (
            out_root / "samples" / "recon_latest.png",
            out_root / "reconstruction_results.png",
        ),
    )
    for produced, legacy in legacy_copies:
        shutil.copyfile(produced, legacy)

    print("Model weights saved to outputs/autoencoder.pth")
    print("Visual comparison saved as outputs/reconstruction_results.png")
def _print_event(event_type: str, payload: dict) -> None:
    """Print a one-line loss summary for ``train.metrics`` events.

    Any other event type is ignored without output.

    Args:
        event_type: Name of the emitted event.
        payload: Event data. For ``train.metrics`` it must contain the keys
            ``epoch``, ``total_loss``, ``recon_loss`` and ``class_loss``.

    Raises:
        KeyError: If a ``train.metrics`` payload lacks one of those keys.
    """
    if event_type != "train.metrics":
        return

    # Keys are read in this fixed order, so a missing key is reported
    # in the same order as a sequence of direct lookups would report it.
    epoch_no, loss_total, loss_recon, loss_cls = (
        payload[key]
        for key in ("epoch", "total_loss", "recon_loss", "class_loss")
    )
    print(
        f"Epoch [{epoch_no}] Total: {loss_total:.4f} "
        f"Recon: {loss_recon:.4f} Class: {loss_cls:.4f}"
    )
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()