Skip to content

Commit 8772496

Browse files
authored
[Model Card] standardize T2I model card (#6939)
* standardize model card * fix base_model
1 parent 35fd84b commit 8772496

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

examples/text_to_image/train_text_to_image.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from diffusers.optimization import get_scheduler
4646
from diffusers.training_utils import EMAModel, compute_snr
4747
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
48+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
4849
from diffusers.utils.import_utils import is_xformers_available
4950
from diffusers.utils.torch_utils import is_compiled_module
5051

@@ -75,21 +76,7 @@ def save_model_card(
7576
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
7677
img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
7778

78-
yaml = f"""
79-
---
80-
license: creativeml-openrail-m
81-
base_model: {args.pretrained_model_name_or_path}
82-
datasets:
83-
- {args.dataset_name}
84-
tags:
85-
- stable-diffusion
86-
- stable-diffusion-diffusers
87-
- text-to-image
88-
- diffusers
89-
inference: true
90-
---
91-
"""
92-
model_card = f"""
79+
model_description = f"""
9380
# Text-to-image finetuning - {repo_id}
9481
9582
This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
@@ -132,10 +119,21 @@ def save_model_card(
132119
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
133120
"""
134121

135-
model_card += wandb_info
122+
model_description += wandb_info
136123

137-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
138-
f.write(yaml + model_card)
124+
model_card = load_or_create_model_card(
125+
repo_id_or_path=repo_id,
126+
from_training=True,
127+
license="creativeml-openrail-m",
128+
base_model=args.pretrained_model_name_or_path,
129+
model_description=model_description,
130+
inference=True,
131+
)
132+
133+
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers"]
134+
model_card = populate_model_card(model_card, tags=tags)
135+
136+
model_card.save(os.path.join(repo_folder, "README.md"))
139137

140138

141139
def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):

0 commit comments

Comments
 (0)