Skip to content

Commit e1bdcc7

Browse files
authored
[Model Card] standardize T2I Sdxl Lora model card (#6944)
* standardize model card template t2i-lora-sdxl * type annotations
1 parent 84905ca commit e1bdcc7

File tree

1 file changed

+26
-27
lines changed

1 file changed

+26
-27
lines changed

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
convert_unet_state_dict_to_peft,
5959
is_wandb_available,
6060
)
61+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
6162
from diffusers.utils.import_utils import is_xformers_available
6263
from diffusers.utils.torch_utils import is_compiled_module
6364

@@ -70,33 +71,20 @@
7071

7172
def save_model_card(
7273
repo_id: str,
73-
images=None,
74-
base_model=str,
75-
dataset_name=str,
76-
train_text_encoder=False,
77-
repo_folder=None,
78-
vae_path=None,
74+
images: list = None,
75+
base_model: str = None,
76+
dataset_name: str = None,
77+
train_text_encoder: bool = False,
78+
repo_folder: str = None,
79+
vae_path: str = None,
7980
):
8081
img_str = ""
81-
for i, image in enumerate(images):
82-
image.save(os.path.join(repo_folder, f"image_{i}.png"))
83-
img_str += f"![img_{i}](./image_{i}.png)\n"
84-
85-
yaml = f"""
86-
---
87-
license: creativeml-openrail-m
88-
base_model: {base_model}
89-
dataset: {dataset_name}
90-
tags:
91-
- stable-diffusion-xl
92-
- stable-diffusion-xl-diffusers
93-
- text-to-image
94-
- diffusers
95-
- lora
96-
inference: true
97-
---
98-
"""
99-
model_card = f"""
82+
if images is not None:
83+
for i, image in enumerate(images):
84+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
85+
img_str += f"![img_{i}](./image_{i}.png)\n"
86+
87+
model_description = f"""
10088
# LoRA text2image fine-tuning - {repo_id}
10189
10290
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
@@ -106,8 +94,19 @@ def save_model_card(
10694
10795
Special VAE used for training: {vae_path}.
10896
"""
109-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
110-
f.write(yaml + model_card)
97+
model_card = load_or_create_model_card(
98+
repo_id_or_path=repo_id,
99+
from_training=True,
100+
license="creativeml-openrail-m",
101+
base_model=base_model,
102+
model_description=model_description,
103+
inference=True,
104+
)
105+
106+
tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "lora"]
107+
model_card = populate_model_card(model_card, tags=tags)
108+
109+
model_card.save(os.path.join(repo_folder, "README.md"))
111110

112111

113112
def import_model_class_from_model_name_or_path(

0 commit comments

Comments
 (0)