|
45 | 45 | from diffusers.optimization import get_scheduler |
46 | 46 | from diffusers.training_utils import EMAModel, compute_snr |
47 | 47 | from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid |
| 48 | +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
48 | 49 | from diffusers.utils.import_utils import is_xformers_available |
49 | 50 | from diffusers.utils.torch_utils import is_compiled_module |
50 | 51 |
|
@@ -75,21 +76,7 @@ def save_model_card( |
75 | 76 | image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) |
76 | 77 | img_str += "\n" |
77 | 78 |
|
78 | | - yaml = f""" |
79 | | ---- |
80 | | -license: creativeml-openrail-m |
81 | | -base_model: {args.pretrained_model_name_or_path} |
82 | | -datasets: |
83 | | -- {args.dataset_name} |
84 | | -tags: |
85 | | -- stable-diffusion |
86 | | -- stable-diffusion-diffusers |
87 | | -- text-to-image |
88 | | -- diffusers |
89 | | -inference: true |
90 | | ---- |
91 | | - """ |
92 | | - model_card = f""" |
| 79 | + model_description = f""" |
93 | 80 | # Text-to-image finetuning - {repo_id} |
94 | 81 |
|
95 | 82 | This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n |
@@ -132,10 +119,21 @@ def save_model_card( |
132 | 119 | More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). |
133 | 120 | """ |
134 | 121 |
|
135 | | - model_card += wandb_info |
| 122 | + model_description += wandb_info |
136 | 123 |
|
137 | | - with open(os.path.join(repo_folder, "README.md"), "w") as f: |
138 | | - f.write(yaml + model_card) |
| 124 | + model_card = load_or_create_model_card( |
| 125 | + repo_id_or_path=repo_id, |
| 126 | + from_training=True, |
| 127 | + license="creativeml-openrail-m", |
| 128 | + base_model=args.pretrained_model_name_or_path, |
| 129 | + model_description=model_description, |
| 130 | + inference=True, |
| 131 | + ) |
| 132 | + |
| 133 | + tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers"] |
| 134 | + model_card = populate_model_card(model_card, tags=tags) |
| 135 | + |
| 136 | + model_card.save(os.path.join(repo_folder, "README.md")) |
139 | 137 |
|
140 | 138 |
|
141 | 139 | def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): |
|
0 commit comments