diff --git a/acestep/pipeline_ace_step.py b/acestep/pipeline_ace_step.py index 919d491f..bfa704dc 100644 --- a/acestep/pipeline_ace_step.py +++ b/acestep/pipeline_ace_step.py @@ -1466,7 +1466,11 @@ def __call__( save_path: str = None, batch_size: int = 1, debug: bool = False, + actual_seeds: list = None, + **kwargs, ): + if actual_seeds is not None: + manual_seeds = actual_seeds # decent middle ground fix start_time = time.time() diff --git a/acestep/ui/components.py b/acestep/ui/components.py index fc56b298..febc8e9d 100644 --- a/acestep/ui/components.py +++ b/acestep/ui/components.py @@ -72,22 +72,23 @@ def update_tags_from_preset(preset_name): def create_output_ui(task_name="Text2Music"): # For many consumer-grade GPU devices, only one batch can be run - output_audio1 = gr.Audio(type="filepath", label=f"{task_name} Generated Audio 1") + output_audio1 = gr.Audio(type="filepath", label=f"{task_name} Generated Audio 1", interactive=False) # output_audio2 = gr.Audio(type="filepath", label="Generated Audio 2") with gr.Accordion(f"{task_name} Parameters", open=False): input_params_json = gr.JSON(label=f"{task_name} Parameters") # outputs = [output_audio1, output_audio2] - outputs = [output_audio1] + outputs = output_audio1 return outputs, input_params_json -def dump_func(*args): +def dump_func(*args, **kwargs): + print() print(args) - return [] + print(kwargs) + return [None, {}] def create_text2music_ui( - gr, text2music_process_func, sample_data_func=None, load_data_func=None, @@ -102,7 +103,7 @@ def create_text2music_ui( os.makedirs(output_file_dir, exist_ok=True) json_files = [f for f in os.listdir(output_file_dir) if f.endswith('.json')] json_files.sort(reverse=True, key=lambda x: int(x.split('_')[1])) - output_files = gr.Dropdown(choices=json_files, label="Select previous generated input params", scale=9, interactive=True) + output_files = gr.Dropdown(choices=json_files, label="Select previous generated input params", scale=9, ) load_bnt = gr.Button("Load", variant="primary", scale=1) with gr.Row(): @@ -115,11 +116,10 @@ def create_text2music_ui( step=0.00001, value=-1, label="Audio Duration", - interactive=True, info="-1 means random duration (30 ~ 240).", scale=9, ) - format = gr.Dropdown(choices=["mp3", "ogg", "flac", "wav"], value="wav", label="Format") + format = gr.Dropdown(choices=["mp3", "ogg", "flac", "wav"], value="wav", label="Format") # FIXME: recommend rename to _format sample_bnt = gr.Button("Sample", variant="secondary", scale=1) # audio2audio @@ -143,7 +143,6 @@ def create_text2music_ui( value=0.5, elem_id="ref_audio_strength", visible=False, - interactive=True, ) def toggle_ref_audio_visibility(is_checked): @@ -198,7 +197,6 @@ def toggle_ref_audio_visibility(is_checked): step=1, value=60, label="Infer Steps", - interactive=True, ) guidance_scale = gr.Slider( minimum=0.0, @@ -206,7 +204,6 @@ def toggle_ref_audio_visibility(is_checked): step=0.1, value=15.0, label="Guidance Scale", - interactive=True, info="When guidance_scale_lyric > 1 and guidance_scale_text > 1, the guidance scale will not be applied.", ) guidance_scale_text = gr.Slider( @@ -215,7 +212,6 @@ def toggle_ref_audio_visibility(is_checked): step=0.1, value=0.0, label="Guidance Scale Text", - interactive=True, info="Guidance scale for text condition. It can only apply to cfg. 
set guidance_scale_text=5.0, guidance_scale_lyric=1.5 for start", ) guidance_scale_lyric = gr.Slider( @@ -224,7 +220,6 @@ def toggle_ref_audio_visibility(is_checked): step=0.1, value=0.0, label="Guidance Scale Lyric", - interactive=True, ) manual_seeds = gr.Textbox( @@ -271,7 +266,6 @@ def toggle_ref_audio_visibility(is_checked): step=0.1, value=10.0, label="Granularity Scale", - interactive=True, info="Granularity scale for the generation. Higher values can reduce artifacts", ) @@ -281,7 +275,6 @@ def toggle_ref_audio_visibility(is_checked): step=0.01, value=0.5, label="Guidance Interval", - interactive=True, info="Guidance interval for the generation. 0.5 means only apply guidance in the middle steps (0.25 * infer_steps to 0.75 * infer_steps)", ) guidance_interval_decay = gr.Slider( @@ -290,7 +283,6 @@ def toggle_ref_audio_visibility(is_checked): step=0.01, value=0.0, label="Guidance Interval Decay", - interactive=True, info="Guidance interval decay for the generation. Guidance scale will decay from guidance_scale to min_guidance_scale in the interval. 0.0 means no decay.", ) min_guidance_scale = gr.Slider( @@ -299,7 +291,6 @@ def toggle_ref_audio_visibility(is_checked): step=0.1, value=3.0, label="Min Guidance Scale", - interactive=True, info="Min guidance scale for guidance interval decay's end scale", ) oss_steps = gr.Textbox( @@ -313,698 +304,356 @@ def toggle_ref_audio_visibility(is_checked): with gr.Column(): outputs, input_params_json = create_output_ui() - with gr.Tab("retake"): - retake_variance = gr.Slider( - minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance" - ) - retake_seeds = gr.Textbox( - label="retake seeds (default None)", placeholder="", value=None - ) - retake_bnt = gr.Button("Retake", variant="primary") - retake_outputs, retake_input_params_json = create_output_ui("Retake") - - def retake_process_func(json_data, retake_variance, retake_seeds): - return text2music_process_func( - json_data["format"], - json_data["audio_duration"], - json_data["prompt"], - json_data["lyrics"], - json_data["infer_step"], - json_data["guidance_scale"], - json_data["scheduler_type"], - json_data["cfg_type"], - json_data["omega_scale"], - ", ".join(map(str, json_data["actual_seeds"])), - json_data["guidance_interval"], - json_data["guidance_interval_decay"], - json_data["min_guidance_scale"], - json_data["use_erg_tag"], - json_data["use_erg_lyric"], - json_data["use_erg_diffusion"], - ", ".join(map(str, json_data["oss_steps"])), - ( - json_data["guidance_scale_text"] - if "guidance_scale_text" in json_data - else 0.0 - ), - ( - json_data["guidance_scale_lyric"] - if "guidance_scale_lyric" in json_data - else 0.0 - ), - retake_seeds=retake_seeds, - retake_variance=retake_variance, - task="retake", - lora_name_or_path="none" if "lora_name_or_path" not in json_data else json_data["lora_name_or_path"], - lora_weight=1 if "lora_weight" not in json_data else json_data["lora_weight"] - ) - - retake_bnt.click( - fn=retake_process_func, - inputs=[ - input_params_json, - retake_variance, - retake_seeds, - ], - outputs=retake_outputs + [retake_input_params_json], - ) - with gr.Tab("repainting"): - retake_variance = gr.Slider( - minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance" - ) - retake_seeds = gr.Textbox( - label="repaint seeds (default None)", placeholder="", value=None - ) - repaint_start = gr.Slider( - minimum=0.0, - maximum=240.0, - step=0.01, - value=0.0, - label="Repaint Start Time", - interactive=True, - ) - repaint_end = gr.Slider( - minimum=0.0, - 
maximum=240.0, - step=0.01, - value=30.0, - label="Repaint End Time", - interactive=True, - ) - repaint_source = gr.Radio( - ["text2music", "last_repaint", "upload"], - value="text2music", - label="Repaint Source", - elem_id="repaint_source", - ) - repaint_source_audio_upload = gr.Audio( - label="Upload Audio", - type="filepath", - visible=False, - elem_id="repaint_source_audio_upload", - show_download_button=True, - ) - repaint_source.change( - fn=lambda x: gr.update( - visible=x == "upload", elem_id="repaint_source_audio_upload" - ), - inputs=[repaint_source], - outputs=[repaint_source_audio_upload], - ) + for i in range(2): + gr.Markdown('# ' + chr(0x200e)) # ToDo: find better way to space these out - repaint_bnt = gr.Button("Repaint", variant="primary") - repaint_outputs, repaint_input_params_json = create_output_ui("Repaint") - - def repaint_process_func( - text2music_json_data, - repaint_json_data, - retake_variance, - retake_seeds, - repaint_start, - repaint_end, - repaint_source, - repaint_source_audio_upload, - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - ): - if repaint_source == "upload": - src_audio_path = repaint_source_audio_upload - audio_duration = librosa.get_duration(filename=src_audio_path) - json_data = {"audio_duration": audio_duration} - elif repaint_source == "text2music": - json_data = text2music_json_data - src_audio_path = json_data["audio_path"] - elif repaint_source == "last_repaint": - json_data = repaint_json_data - src_audio_path = json_data["audio_path"] - - return text2music_process_func( - format.value, - json_data["audio_duration"], - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - retake_seeds=retake_seeds, - retake_variance=retake_variance, - task="repaint", - repaint_start=repaint_start, - repaint_end=repaint_end, - src_audio_path=src_audio_path, - lora_name_or_path="none" if "lora_name_or_path" not in json_data else json_data["lora_name_or_path"], - lora_weight=1 if "lora_weight" not in json_data else json_data["lora_weight"] + retake_variance = gr.Slider( + minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance" + ) + retake_seeds = gr.Textbox( + label="retake seeds (default None)", placeholder="", value=None + ) + with gr.Tabs() as text2music_tabs: + def toggle_variance_interact(event_data: gr.SelectData): + not_edit = event_data.value != 'edit' + return [ + gr.update(interactive=not_edit), + gr.update(interactive=not_edit), + ] + text2music_tabs.select( + fn=toggle_variance_interact, + outputs=[retake_seeds, retake_variance] + ) + with gr.Tab("retake"): + retake_bnt = gr.Button("Retake", variant="primary") + retake_outputs, retake_input_params_json = create_output_ui("Retake") + + with gr.Tab("repainting"): + repaint_start = gr.Slider( + minimum=0.0, + maximum=240.0, + step=0.01, + value=0.0, + label="Repaint Start Time", + ) + repaint_end = gr.Slider( + minimum=0.0, + maximum=240.0, + step=0.01, + value=30.0, + label="Repaint End Time", + ) + repaint_source = gr.Radio( + ["text2music", "last_repaint", "upload"], + value="text2music", + 
label="Repaint Source", + elem_id="repaint_source", ) - repaint_bnt.click( - fn=repaint_process_func, - inputs=[ - input_params_json, - repaint_input_params_json, - retake_variance, - retake_seeds, - repaint_start, - repaint_end, - repaint_source, - repaint_source_audio_upload, - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - ], - outputs=repaint_outputs + [repaint_input_params_json], - ) - with gr.Tab("edit"): - edit_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4) - edit_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13) - retake_seeds = gr.Textbox( - label="edit seeds (default None)", placeholder="", value=None - ) - - edit_type = gr.Radio( - ["only_lyrics", "remix"], - value="only_lyrics", - label="Edit Type", - elem_id="edit_type", - info="`only_lyrics` will keep the whole song the same except lyrics difference. Make your diffrence smaller, e.g. one lyrc line change.\nremix can change the song melody and genre", - ) - edit_n_min = gr.Slider( - minimum=0.0, - maximum=1.0, - step=0.01, - value=0.6, - label="edit_n_min", - interactive=True, - ) - edit_n_max = gr.Slider( - minimum=0.0, - maximum=1.0, - step=0.01, - value=1.0, - label="edit_n_max", - interactive=True, - ) + repaint_source_audio_upload = gr.Audio( + label="Upload Audio", + type="filepath", + visible=False, + elem_id="repaint_source_audio_upload", + show_download_button=True, + ) + repaint_source.change( + fn=lambda x: gr.update( + visible=x == "upload", elem_id="repaint_source_audio_upload" + ), + inputs=[repaint_source], + outputs=[repaint_source_audio_upload], + ) - def edit_type_change_func(edit_type): - if edit_type == "only_lyrics": - n_min = 0.6 - n_max = 1.0 - elif edit_type == "remix": - n_min = 0.2 - n_max = 0.4 - return n_min, n_max - - edit_type.change( - edit_type_change_func, - inputs=[edit_type], - outputs=[edit_n_min, edit_n_max], - ) + repaint_bnt = gr.Button("Repaint", variant="primary") + repaint_outputs, repaint_input_params_json = create_output_ui("Repaint") - edit_source = gr.Radio( - ["text2music", "last_edit", "upload"], - value="text2music", - label="Edit Source", - elem_id="edit_source", - ) - edit_source_audio_upload = gr.Audio( - label="Upload Audio", - type="filepath", - visible=False, - elem_id="edit_source_audio_upload", - show_download_button=True, - ) - edit_source.change( - fn=lambda x: gr.update( - visible=x == "upload", elem_id="edit_source_audio_upload" - ), - inputs=[edit_source], - outputs=[edit_source_audio_upload], - ) + with gr.Tab("edit"): + edit_target_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4) + edit_target_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13) - edit_bnt = gr.Button("Edit", variant="primary") - edit_outputs, edit_input_params_json = create_output_ui("Edit") - - def edit_process_func( - text2music_json_data, - edit_input_params_json, - edit_source, - edit_source_audio_upload, - prompt, - lyrics, - edit_prompt, - edit_lyrics, - edit_n_min, - edit_n_max, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - retake_seeds, - ): - if edit_source == 
"upload": - src_audio_path = edit_source_audio_upload - audio_duration = librosa.get_duration(filename=src_audio_path) - json_data = {"audio_duration": audio_duration} - elif edit_source == "text2music": - json_data = text2music_json_data - src_audio_path = json_data["audio_path"] - elif edit_source == "last_edit": - json_data = edit_input_params_json - src_audio_path = json_data["audio_path"] - - if not edit_prompt: - edit_prompt = prompt - if not edit_lyrics: - edit_lyrics = lyrics - - return text2music_process_func( - format.value, - json_data["audio_duration"], - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - task="edit", - src_audio_path=src_audio_path, - edit_target_prompt=edit_prompt, - edit_target_lyrics=edit_lyrics, - edit_n_min=edit_n_min, - edit_n_max=edit_n_max, - retake_seeds=retake_seeds, - lora_name_or_path="none" if "lora_name_or_path" not in json_data else json_data["lora_name_or_path"], - lora_weight=1 if "lora_weight" not in json_data else json_data["lora_weight"] + edit_type = gr.Radio( + ["only_lyrics", "remix"], + value="only_lyrics", + label="Edit Type", + elem_id="edit_type", + info="`only_lyrics` will keep the whole song the same except lyrics difference. Make your diffrence smaller, e.g. one lyrc line change.\nremix can change the song melody and genre", + ) + edit_n_min = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=0.6, + label="edit_n_min", + ) + edit_n_max = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=1.0, + label="edit_n_max", ) - edit_bnt.click( - fn=edit_process_func, - inputs=[ - input_params_json, - edit_input_params_json, - edit_source, - edit_source_audio_upload, - prompt, - lyrics, - edit_prompt, - edit_lyrics, - edit_n_min, - edit_n_max, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - retake_seeds, - ], - outputs=edit_outputs + [edit_input_params_json], - ) - with gr.Tab("extend"): - extend_seeds = gr.Textbox( - label="extend seeds (default None)", placeholder="", value=None - ) - left_extend_length = gr.Slider( - minimum=0.0, - maximum=240.0, - step=0.01, - value=0.0, - label="Left Extend Length", - interactive=True, - ) - right_extend_length = gr.Slider( - minimum=0.0, - maximum=240.0, - step=0.01, - value=30.0, - label="Right Extend Length", - interactive=True, - ) - extend_source = gr.Radio( - ["text2music", "last_extend", "upload"], - value="text2music", - label="Extend Source", - elem_id="extend_source", - ) - - extend_source_audio_upload = gr.Audio( - label="Upload Audio", - type="filepath", - visible=False, - elem_id="extend_source_audio_upload", - show_download_button=True, - ) - extend_source.change( - fn=lambda x: gr.update( - visible=x == "upload", elem_id="extend_source_audio_upload" - ), - inputs=[extend_source], - outputs=[extend_source_audio_upload], - ) + def edit_type_change_func(edit_type): + if edit_type == "only_lyrics": + n_min = 0.6 + n_max = 1.0 + elif edit_type == "remix": + n_min = 0.2 + n_max = 0.4 + return n_min, n_max + + edit_type.change( + edit_type_change_func, + inputs=[edit_type], + outputs=[edit_n_min, 
edit_n_max], + ) - extend_bnt = gr.Button("Extend", variant="primary") - extend_outputs, extend_input_params_json = create_output_ui("Extend") - - def extend_process_func( - text2music_json_data, - extend_input_params_json, - extend_seeds, - left_extend_length, - right_extend_length, - extend_source, - extend_source_audio_upload, - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - ): - if extend_source == "upload": - src_audio_path = extend_source_audio_upload - # get audio duration - audio_duration = librosa.get_duration(filename=src_audio_path) - json_data = {"audio_duration": audio_duration} - elif extend_source == "text2music": - json_data = text2music_json_data - src_audio_path = json_data["audio_path"] - elif extend_source == "last_extend": - json_data = extend_input_params_json - src_audio_path = json_data["audio_path"] - - repaint_start = -left_extend_length - repaint_end = json_data["audio_duration"] + right_extend_length - return text2music_process_func( - format.value, - json_data["audio_duration"], - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - retake_seeds=extend_seeds, - retake_variance=1.0, - task="extend", - repaint_start=repaint_start, - repaint_end=repaint_end, - src_audio_path=src_audio_path, - lora_name_or_path=( - "none" - if "lora_name_or_path" not in json_data - else json_data["lora_name_or_path"] - ), - lora_weight=( - 1 - if "lora_weight" not in json_data - else json_data["lora_weight"] + edit_source = gr.Radio( + ["text2music", "last_edit", "upload"], + value="text2music", + label="Edit Source", + elem_id="edit_source", + ) + edit_source_audio_upload = gr.Audio( + label="Upload Audio", + type="filepath", + visible=False, + elem_id="edit_source_audio_upload", + show_download_button=True, + ) + edit_source.change( + fn=lambda x: gr.update( + visible=x == "upload", elem_id="edit_source_audio_upload" ), + inputs=[edit_source], + outputs=[edit_source_audio_upload], ) - extend_bnt.click( - fn=extend_process_func, - inputs=[ - input_params_json, - extend_input_params_json, - extend_seeds, - left_extend_length, - right_extend_length, - extend_source, - extend_source_audio_upload, - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - ], - outputs=extend_outputs + [extend_input_params_json], - ) + edit_bnt = gr.Button("Edit", variant="primary") + edit_outputs, edit_input_params_json = create_output_ui("Edit") - def json2output(json_data): - return ( - json_data["audio_duration"], - json_data["prompt"], - json_data["lyrics"], - json_data["infer_step"], - json_data["guidance_scale"], - json_data["scheduler_type"], - json_data["cfg_type"], - json_data["omega_scale"], - ", ".join(map(str, json_data["actual_seeds"])), - json_data["guidance_interval"], - json_data["guidance_interval_decay"], - json_data["min_guidance_scale"], - json_data["use_erg_tag"], 
- json_data["use_erg_lyric"], - json_data["use_erg_diffusion"], - ", ".join(map(str, json_data["oss_steps"])), - ( - json_data["guidance_scale_text"] - if "guidance_scale_text" in json_data - else 0.0 - ), - ( - json_data["guidance_scale_lyric"] - if "guidance_scale_lyric" in json_data - else 0.0 - ), - ( - json_data["audio2audio_enable"] - if "audio2audio_enable" in json_data - else False - ), - ( - json_data["ref_audio_strength"] - if "ref_audio_strength" in json_data - else 0.5 - ), - ( - json_data["ref_audio_input"] - if "ref_audio_input" in json_data - else None - ), - ) + with gr.Tab("extend"): + extend_seeds = gr.Textbox( + label="extend seeds (default None)", placeholder="", value=None + ) + left_extend_length = gr.Slider( + minimum=0.0, + maximum=240.0, + step=0.01, + value=0.0, + label="Left Extend Length", + ) + right_extend_length = gr.Slider( + minimum=0.0, + maximum=240.0, + step=0.01, + value=30.0, + label="Right Extend Length", + ) + extend_source = gr.Radio( + ["text2music", "last_extend", "upload"], + value="text2music", + label="Extend Source", + elem_id="extend_source", + ) - def sample_data(lora_name_or_path_): - json_data = sample_data_func(lora_name_or_path_) - return json2output(json_data) - - sample_bnt.click( - sample_data, - inputs=[lora_name_or_path], - outputs=[ - audio_duration, - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - audio2audio_enable, - ref_audio_strength, - ref_audio_input, - ], - ) + extend_source_audio_upload = gr.Audio( + label="Upload Audio", + type="filepath", + visible=False, + elem_id="extend_source_audio_upload", + show_download_button=True, + ) + extend_source.change( + fn=lambda x: gr.update( + visible=x == "upload", elem_id="extend_source_audio_upload" + ), + inputs=[extend_source], + outputs=[extend_source_audio_upload], + ) - def load_data(json_file): - if isinstance(output_file_dir, str): - json_file = os.path.join(output_file_dir, json_file) - json_data = load_data_func(json_file) - return json2output(json_data) - - load_bnt.click( - fn=load_data, - inputs=[output_files], - outputs=[ - audio_duration, - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - audio2audio_enable, - ref_audio_strength, - ref_audio_input, - ], - ) + extend_bnt = gr.Button("Extend", variant="primary") + extend_outputs, extend_input_params_json = create_output_ui("Extend") + + all_gradio_component_names = [] + all_gradio_components = [] + l = locals() + for k in l: + var = l[k] + if isinstance(var, gr.components.Component): + all_gradio_components.append(var) + all_gradio_component_names.append(k) + + def get_kwargs(args): + assert len(args) == len(all_gradio_component_names), 'name list length != argument list length' + kwargs = {} + for i in range(len(args)): + kwargs[all_gradio_component_names[i]] = args[i] + return kwargs + + text2music_outputs = [ + audio_duration, + prompt, + lyrics, + infer_step, + guidance_scale, + scheduler_type, + cfg_type, + omega_scale, + manual_seeds, + guidance_interval, + guidance_interval_decay, + min_guidance_scale, + use_erg_tag, + 
use_erg_lyric, + use_erg_diffusion, + oss_steps, + guidance_scale_text, + guidance_scale_lyric, + audio2audio_enable, + ref_audio_strength, + ref_audio_input, + + lora_name_or_path, + lora_weight, + #retake_variance, + #retake_seeds, + ] + l = locals() + lk = list(l.keys()) + lv = list(l.values()) + text2music_output_names = [ lk[lv.index(thing)] for thing in text2music_outputs ] + + def _get_args(kwargs): + args = [] + for name, component in zip(text2music_output_names, text2music_outputs): + v = kwargs.get(name, None) + if isinstance(component, gr.components.Textbox) and isinstance(v, list): + v = str(v)[1:-1] # remove brackets + args.append(v) + return args + + def select_input(kwargs, edit_source, upload_path, last_json): + params = {} + if kwargs[edit_source] == "upload": + params['src_audio_path'] = kwargs[upload_path] + params['audio_duration'] = librosa.get_duration(path=kwargs[upload_path]) + elif kwargs[edit_source] == "text2music": + params.update(kwargs['input_params_json']) + params['src_audio_path'] = params['audio_path'] + else: + params.update(kwargs[last_json]) + params['src_audio_path'] = params['audio_path'] + + kwargs['src_audio_path'] = params['src_audio_path'] + kwargs['audio_duration'] = params['audio_duration'] + return kwargs + + def sample_data(*args): + kwargs = get_kwargs(args) + json_data = sample_data_func(kwargs['lora_name_or_path']) + json_data['manual_seeds'] = json_data['actual_seeds'] + return _get_args(json_data) + + def load_data(*args): + kwargs = get_kwargs(args) + if isinstance(output_file_dir, str): + json_file = os.path.join(output_file_dir, kwargs['output_files']) + json_data = load_data_func(json_file) + json_data['manual_seeds'] = json_data['actual_seeds'] + return _get_args(json_data) + + def generate_process_function(*args): + kwargs = get_kwargs(args) + kwargs['task'] = 'text2music' + return text2music_process_func(**kwargs) + + def retake_process_func(*args): + kwargs = get_kwargs(args) + kwargs['manual_seeds'] = kwargs['input_params_json']['actual_seeds'] + kwargs['task'] = 'retake' + return text2music_process_func(**kwargs) + + def edit_process_func(*args): + kwargs = get_kwargs(args) + kwargs = select_input(kwargs, 'edit_source', 'edit_source_audio_upload', 'edit_input_params_json') + + if not kwargs.get('edit_target_prompt', ''): + kwargs['edit_target_prompt'] = kwargs['prompt'] + if not kwargs.get('edit_target_lyrics', ''): + kwargs['edit_target_lyrics'] = kwargs['lyrics'] + + kwargs['task'] = 'edit' + return text2music_process_func(**kwargs) + + def extend_process_func(*args): + kwargs = get_kwargs(args) + kwargs = select_input(kwargs, 'extend_source', 'extend_source_audio_upload', 'extend_input_params_json') + + kwargs['repaint_start'] = -kwargs['left_extend_length'] + kwargs['repaint_end'] = kwargs['audio_duration'] + kwargs['right_extend_length'] + kwargs['audio_duration'] += kwargs['right_extend_length'] + kwargs['left_extend_length'] + kwargs['task'] = 'extend' + return text2music_process_func(**kwargs) + + def repaint_process_func(*args): + kwargs = get_kwargs(args) + kwargs = select_input(kwargs, 'repaint_source', 'repaint_source_audio_upload', 'repaint_input_params_json') + + kwargs['task'] = 'repaint' + return text2music_process_func(**kwargs) text2music_bnt.click( - fn=text2music_process_func, - inputs=[ - format, - audio_duration, - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, 
- use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - audio2audio_enable, - ref_audio_strength, - ref_audio_input, - lora_name_or_path, - lora_weight - ], - outputs=outputs + [input_params_json], + generate_process_function, + inputs=all_gradio_components, + outputs=[outputs, input_params_json] + ) + retake_bnt.click( + retake_process_func, + inputs=all_gradio_components, + outputs=[retake_outputs, retake_input_params_json] + ) + repaint_bnt.click( + repaint_process_func, + inputs=all_gradio_components, + outputs=[repaint_outputs, repaint_input_params_json] + ) + edit_bnt.click( + edit_process_func, + inputs=all_gradio_components, + outputs=[edit_outputs, edit_input_params_json] + ) + extend_bnt.click( + extend_process_func, + inputs=all_gradio_components, + outputs=[extend_outputs, extend_input_params_json] + ) + sample_bnt.click( + sample_data, + inputs=all_gradio_components, + outputs=text2music_outputs + ) + load_bnt.click( + load_data, + inputs=all_gradio_components, + outputs=text2music_outputs ) - def create_main_demo_ui( text2music_process_func=dump_func, sample_data_func=dump_func, load_data_func=dump_func, ): - with gr.Blocks( - title="ACE-Step Model 1.0 DEMO", - ) as demo: + with gr.Blocks(title="ACE-Step Model 1.0 DEMO") as demo: gr.Markdown( """ -
ACE-Step: A Step Towards Music Generation Foundation Model - """ + ACE-Step: A Step Towards Music Generation Foundation Model
+ """ ) with gr.Tab("text2music"): create_text2music_ui( - gr=gr, text2music_process_func=text2music_process_func, sample_data_func=sample_data_func, load_data_func=load_data_func, diff --git a/infer-api.py b/infer-api.py index 591fb3eb..b0f7cd6a 100644 --- a/infer-api.py +++ b/infer-api.py @@ -13,7 +13,7 @@ class ACEStepInput(BaseModel): bf16: bool = True torch_compile: bool = False device_id: int = 0 - output_path: Optional[str] = None + save_path: Optional[str] = None audio_duration: float prompt: str lyrics: str @@ -22,7 +22,7 @@ class ACEStepInput(BaseModel): scheduler_type: str cfg_type: str omega_scale: float - actual_seeds: List[int] + manual_seeds: List[int] guidance_interval: float guidance_interval_decay: float min_guidance_scale: float @@ -58,35 +58,12 @@ async def generate_audio(input_data: ACEStepInput): ) # Prepare parameters - params = ( - input_data.audio_duration, - input_data.prompt, - input_data.lyrics, - input_data.infer_step, - input_data.guidance_scale, - input_data.scheduler_type, - input_data.cfg_type, - input_data.omega_scale, - ", ".join(map(str, input_data.actual_seeds)), - input_data.guidance_interval, - input_data.guidance_interval_decay, - input_data.min_guidance_scale, - input_data.use_erg_tag, - input_data.use_erg_lyric, - input_data.use_erg_diffusion, - ", ".join(map(str, input_data.oss_steps)), - input_data.guidance_scale_text, - input_data.guidance_scale_lyric, - ) # Generate output path if not provided - output_path = input_data.output_path or f"output_{uuid.uuid4().hex}.wav" + input_data.save_path = input_data.save_path or f"output_{uuid.uuid4().hex}.wav" # Run pipeline - model_demo( - *params, - save_path=output_path - ) + model_demo(**input_data) return ACEStepOutput( status="success", diff --git a/infer.py b/infer.py index 6ee2ca0c..aebb3ceb 100644 --- a/infer.py +++ b/infer.py @@ -4,34 +4,6 @@ from acestep.pipeline_ace_step import ACEStepPipeline from acestep.data_sampler import DataSampler - -def sample_data(json_data): - return ( - json_data["audio_duration"], - json_data["prompt"], - json_data["lyrics"], - json_data["infer_step"], - json_data["guidance_scale"], - json_data["scheduler_type"], - json_data["cfg_type"], - json_data["omega_scale"], - ", ".join(map(str, json_data["actual_seeds"])), - json_data["guidance_interval"], - json_data["guidance_interval_decay"], - json_data["min_guidance_scale"], - json_data["use_erg_tag"], - json_data["use_erg_lyric"], - json_data["use_erg_diffusion"], - ", ".join(map(str, json_data["oss_steps"])), - json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0, - ( - json_data["guidance_scale_lyric"] - if "guidance_scale_lyric" in json_data - else 0.0 - ), - ) - - @click.command() @click.option( "--checkpoint_path", type=str, default="", help="Path to the checkpoint directory" @@ -63,51 +35,10 @@ def main(checkpoint_path, bf16, torch_compile, cpu_offload, overlapped_decode, d data_sampler = DataSampler() json_data = data_sampler.sample() - json_data = sample_data(json_data) - print(json_data) - ( - audio_duration, - prompt, - lyrics, - infer_step, - guidance_scale, - scheduler_type, - cfg_type, - omega_scale, - manual_seeds, - guidance_interval, - guidance_interval_decay, - min_guidance_scale, - use_erg_tag, - use_erg_lyric, - use_erg_diffusion, - oss_steps, - guidance_scale_text, - guidance_scale_lyric, - ) = json_data + print(json_data) - model_demo( - audio_duration=audio_duration, - prompt=prompt, - lyrics=lyrics, - infer_step=infer_step, - guidance_scale=guidance_scale, - 
scheduler_type=scheduler_type, - cfg_type=cfg_type, - omega_scale=omega_scale, - manual_seeds=manual_seeds, - guidance_interval=guidance_interval, - guidance_interval_decay=guidance_interval_decay, - min_guidance_scale=min_guidance_scale, - use_erg_tag=use_erg_tag, - use_erg_lyric=use_erg_lyric, - use_erg_diffusion=use_erg_diffusion, - oss_steps=oss_steps, - guidance_scale_text=guidance_scale_text, - guidance_scale_lyric=guidance_scale_lyric, - save_path=output_path, - ) + model_demo(save_path=output_path, **json_data) if __name__ == "__main__":
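For reference, a minimal standalone sketch of the two kwargs conventions this patch relies on (the field names below are an illustrative subset of `ACEStepInput`, not the full schema): a Pydantic model has to be converted to a plain dict before `**`-unpacking, and the pipeline-side `actual_seeds` alias plus the `**kwargs` catch-all is what lets a saved parameter JSON be replayed directly:

```python
from typing import List, Optional
from pydantic import BaseModel


class DemoInput(BaseModel):  # illustrative subset of ACEStepInput
    prompt: str
    manual_seeds: List[int] = []
    save_path: Optional[str] = None


def pipeline(prompt=None, manual_seeds=None, actual_seeds=None, save_path=None, **kwargs):
    # mirrors the pipeline_ace_step.py change: accept `actual_seeds`
    # (the key stored in the output params JSON) as an alias for
    # `manual_seeds`, and swallow any unrelated keys via **kwargs
    if actual_seeds is not None:
        manual_seeds = actual_seeds
    print(prompt, manual_seeds, save_path)


inp = DemoInput(prompt="synthwave, 120 bpm", manual_seeds=[42])
# pipeline(**inp) would raise TypeError -- a BaseModel is not a mapping --
# hence the `input_data.dict()` unpack above (model_dump() on pydantic v2)
pipeline(**inp.dict())

# replaying a saved parameter JSON works thanks to the alias + **kwargs:
saved_json = {"prompt": "piano ballad", "actual_seeds": [7], "audio_path": "out.wav"}
pipeline(**saved_json)
```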