Binary file added .DS_Store
Binary file not shown.
131 changes: 129 additions & 2 deletions README.md
@@ -18,8 +18,135 @@ Abstract: Recent advancements in world models have revolutionized dynamic enviro
## TODO List

- **[2025-05-30]** ✅ Release [paper](https://arxiv.org/abs/2505.22421)
- [ ] Release inference code
- [ ] Release model checkpoints
- **[2026-02-20]** ✅ Release inference code
- **[2026-02-20]** ✅ Release model checkpoints

## Getting Started

<summary><b>Environment Requirements</b></summary>

We recommend first using `conda` to create a virtual environment and then installing the required libraries:

```bash
conda create -n geodrive python=3.10 -y
conda activate geodrive
pip install -r requirements.txt
```

Then, install the bundled custom `diffusers` package in editable mode:

```bash
cd ./diffusers
pip install -e .
```
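
To confirm that the editable install is the one being imported (assuming PyTorch is installed via `requirements.txt`), a minimal sanity check:

```bash
python -c "import torch, diffusers; print('torch', torch.__version__, 'cuda:', torch.cuda.is_available()); print('diffusers', diffusers.__version__, diffusers.__file__)"
```

The reported `diffusers` path should point inside the local `./diffusers` checkout.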

Next, install the required `ffmpeg`:

```bash
conda install -c conda-forge ffmpeg -y
```
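
As an optional check that the conda-provided binary is the one on your `PATH`:

```bash
which ffmpeg                 # should resolve inside the geodrive conda environment
ffmpeg -version | head -n 1
```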

<summary><b>Data Preparation</b></summary>


We use processed data from nuScenes. Please follow the instructions in the `monst3r` folder to prepare the data.
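
The exact layout is produced by the `monst3r` preprocessing; as a minimal sanity check that the roots referenced by the training and inference scripts below are populated (placeholder paths, same as in those scripts):

```bash
ls PATH/TO/CONDITION_TRAIN | head -n 5
ls PATH/TO/NUSCENES | head -n 5
```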


## 🏃🏼 Running Scripts

<summary><b>Training</b></summary>

Train GeoDrive with the following script:

```bash
export MODEL_PATH="THUDM/CogVideoX-5b-I2V"
export CACHE_PATH="~/.cache"
export CONDITION_DATASET_PATH="PATH/TO/CONDITION_TRAIN"
export VIDEO_DATASET_PATH="PATH/TO/NUSCENES"
export VAL_CONDITION_DATASET_PATH="PATH/TO/CONDITION_VAL"
export PROJECT_NAME="GeoDrive"
export RUNS_NAME="GeoDrive"
export OUTPUT_PATH="./${PROJECT_NAME}/${RUNS_NAME}"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export TOKENIZERS_PARALLELISM=false
export ACCELERATE_LAUNCH_WAIT_TIMEOUT=300
export DS_SKIP_CUDA_CHECK=1

export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}

torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=10086 \
train.py \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--meta_file_path PATH/TO/META_TRAIN \
--val_meta_file_path PATH/TO/META_VAL \
--video_data_root $VIDEO_DATASET_PATH \
--condition_data_root $CONDITION_DATASET_PATH \
--val_condition_data_root $VAL_CONDITION_DATASET_PATH \
--dataloader_num_workers 1 \
--num_validation_videos 1 \
--validation_epochs 2 \
--seed 42 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
--width 720 \
--fps 10 \
--video_reshape_mode "center" \
--branch_layer_num 2 \
--train_batch_size 1 \
--num_train_epochs 20 \
--checkpointing_steps 2000 \
--validating_steps 1000 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-5 \
--lr_scheduler cosine_with_restarts \
--lr_warmup_steps 1000 \
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--noised_image_dropout 0.05 \
--gradient_checkpointing \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--max_grad_norm 1.0 \
--allow_tf32 \
--report_to wandb \
--tracker_name $PROJECT_NAME \
--runs_name $RUNS_NAME \
--mix_train_ratio 0 \
--first_frame_gt

```
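
The launch command assumes 8 visible GPUs and bf16 mixed precision. A small pre-flight check (a sketch, not part of the training script) can catch a mismatched setup before launching:

```bash
python - <<'EOF'
import torch

# torchrun above is launched with --nproc_per_node=8 and --mixed_precision bf16,
# so we expect 8 visible devices that support bfloat16.
print("visible GPUs:", torch.cuda.device_count())
print("bf16 supported:", torch.cuda.is_bf16_supported())
EOF
```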

<summary><b>Inference</b></summary>

You can run inference with the following script:

```bash
cd infer

pretrained_model_name_or_path="THUDM/CogVideoX-5b-I2V"
checkpoint_path="PATH/TO/CKPT"
meta_file_path="PATH/TO/META"
condition_data_root="PATH/TO/CONDITION"
video_data_root="PATH/TO/NUSCENES"

python run_validation.py \
--checkpoint_path $checkpoint_path \
--pretrained_model_name_or_path $pretrained_model_name_or_path \
--meta_file_path $meta_file_path \
--condition_data_root $condition_data_root \
--video_data_root $video_data_root \
--output_dir /PATH/TO/OUTPUT \
--height 480 \
--width 720 \
--max_num_frames 49 \
--mixed_precision bf16 \
--target_frames 25

```
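
Generated videos are written under `--output_dir`. The exact filenames depend on the script, but assuming it writes `.mp4` clips, a quick `ffprobe` pass (a sketch) can verify that resolution and frame count match the requested height/width and `--target_frames`:

```bash
for f in /PATH/TO/OUTPUT/*.mp4; do
  # Print width, height, and decoded frame count for each generated clip.
  ffprobe -v error -select_streams v:0 -count_frames \
    -show_entries stream=width,height,nb_read_frames -of csv=p=0 "$f"
done
```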


## Citation
178 changes: 178 additions & 0 deletions diffusers/.gitignore
@@ -0,0 +1,178 @@
# Initially taken from GitHub's Python gitignore file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# tests and logs
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a Python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# vscode
.vs
.vscode

# Pycharm
.idea

# TF code
tensorflow_code

# Models
proc_data

# examples
runs
/runs_old
/wandb
/examples/runs
/examples/**/*.args
/examples/rag/sweep

# data
/data
serialization_dir

# emacs
*.*~
debug.env

# vim
.*.swp

# ctags
tags

# pre-commit
.pre-commit*

# .lock
*.lock

# DS_Store (MacOS)
.DS_Store

# RL pipelines may produce mp4 outputs
*.mp4

# dependencies
/transformers

# ruff
.ruff_cache

# wandb
wandb
52 changes: 52 additions & 0 deletions diffusers/CITATION.cff
@@ -0,0 +1,52 @@
cff-version: 1.2.0
title: 'Diffusers: State-of-the-art diffusion models'
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Patrick
family-names: von Platen
- given-names: Suraj
family-names: Patil
- given-names: Anton
family-names: Lozhkov
- given-names: Pedro
family-names: Cuenca
- given-names: Nathan
family-names: Lambert
- given-names: Kashif
family-names: Rasul
- given-names: Mishig
family-names: Davaadorj
- given-names: Dhruv
family-names: Nair
- given-names: Sayak
family-names: Paul
- given-names: Steven
family-names: Liu
- given-names: William
family-names: Berman
- given-names: Yiyi
family-names: Xu
- given-names: Thomas
family-names: Wolf
repository-code: 'https://github.com/huggingface/diffusers'
abstract: >-
Diffusers provides pretrained diffusion models across
multiple modalities, such as vision and audio, and serves
as a modular toolbox for inference and training of
diffusion models.
keywords:
- deep-learning
- pytorch
- image-generation
- hacktoberfest
- diffusion
- text2image
- image2image
- score-based-generative-modeling
- stable-diffusion
- stable-diffusion-diffusers
license: Apache-2.0
version: 0.12.1