5858 convert_unet_state_dict_to_peft ,
5959 is_wandb_available ,
6060)
61+ from diffusers .utils .hub_utils import load_or_create_model_card , populate_model_card
6162from diffusers .utils .import_utils import is_xformers_available
6263from diffusers .utils .torch_utils import is_compiled_module
6364
7071
7172def save_model_card (
7273 repo_id : str ,
73- images = None ,
74- base_model = str ,
75- dataset_name = str ,
76- train_text_encoder = False ,
77- repo_folder = None ,
78- vae_path = None ,
74+ images : list = None ,
75+ base_model : str = None ,
76+ dataset_name : str = None ,
77+ train_text_encoder : bool = False ,
78+ repo_folder : str = None ,
79+ vae_path : str = None ,
7980):
8081 img_str = ""
81- for i , image in enumerate (images ):
82- image .save (os .path .join (repo_folder , f"image_{ i } .png" ))
83- img_str += f"\n "
84-
85- yaml = f"""
86- ---
87- license: creativeml-openrail-m
88- base_model: { base_model }
89- dataset: { dataset_name }
90- tags:
91- - stable-diffusion-xl
92- - stable-diffusion-xl-diffusers
93- - text-to-image
94- - diffusers
95- - lora
96- inference: true
97- ---
98- """
99- model_card = f"""
82+ if images is not None :
83+ for i , image in enumerate (images ):
84+ image .save (os .path .join (repo_folder , f"image_{ i } .png" ))
85+ img_str += f"\n "
86+
87+ model_description = f"""
10088# LoRA text2image fine-tuning - { repo_id }
10189
10290These are LoRA adaption weights for { base_model } . The weights were fine-tuned on the { dataset_name } dataset. You can find some example images in the following. \n
@@ -106,8 +94,19 @@ def save_model_card(
10694
10795Special VAE used for training: { vae_path } .
10896"""
109- with open (os .path .join (repo_folder , "README.md" ), "w" ) as f :
110- f .write (yaml + model_card )
97+ model_card = load_or_create_model_card (
98+ repo_id_or_path = repo_id ,
99+ from_training = True ,
100+ license = "creativeml-openrail-m" ,
101+ base_model = base_model ,
102+ model_description = model_description ,
103+ inference = True ,
104+ )
105+
106+ tags = ["stable-diffusion-xl" , "stable-diffusion-xl-diffusers" , "text-to-image" , "diffusers" , "lora" ]
107+ model_card = populate_model_card (model_card , tags = tags )
108+
109+ model_card .save (os .path .join (repo_folder , "README.md" ))
111110
112111
113112def import_model_class_from_model_name_or_path (
0 commit comments