Skip to content

Commit 06a042c

Browse files
authored
[Model Card] standardize T2I Lora model card (#6940)
standardize model card t2i-lora
1 parent 8772496 commit 06a042c

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from diffusers.optimization import get_scheduler
4646
from diffusers.training_utils import cast_training_params, compute_snr
4747
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
48+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
4849
from diffusers.utils.import_utils import is_xformers_available
4950
from diffusers.utils.torch_utils import is_compiled_module
5051

@@ -61,26 +62,31 @@ def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str,
6162
image.save(os.path.join(repo_folder, f"image_{i}.png"))
6263
img_str += f"![img_{i}](./image_{i}.png)\n"
6364

64-
yaml = f"""
65-
---
66-
license: creativeml-openrail-m
67-
base_model: {base_model}
68-
tags:
69-
- stable-diffusion
70-
- stable-diffusion-diffusers
71-
- text-to-image
72-
- diffusers
73-
- lora
74-
inference: true
75-
---
76-
"""
77-
model_card = f"""
65+
model_description = f"""
7866
# LoRA text2image fine-tuning - {repo_id}
7967
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
8068
{img_str}
8169
"""
82-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
83-
f.write(yaml + model_card)
70+
71+
model_card = load_or_create_model_card(
72+
repo_id_or_path=repo_id,
73+
from_training=True,
74+
license="creativeml-openrail-m",
75+
base_model=base_model,
76+
model_description=model_description,
77+
inference=True,
78+
)
79+
80+
tags = [
81+
"stable-diffusion",
82+
"stable-diffusion-diffusers",
83+
"text-to-image",
84+
"diffusers",
85+
"lora",
86+
]
87+
model_card = populate_model_card(model_card, tags=tags)
88+
89+
model_card.save(os.path.join(repo_folder, "README.md"))
8490

8591

8692
def parse_args():

0 commit comments

Comments
 (0)