From 1331c7d56dcf3c7a1200740e871565df1680de24 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Thu, 10 Nov 2022 22:58:12 +0530
Subject: [PATCH 1/9] add lit app

---
 app.py           | 57 ++++++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt |  1 +
 2 files changed, 58 insertions(+)
 create mode 100644 app.py

diff --git a/app.py b/app.py
new file mode 100644
index 0000000..ff568e4
--- /dev/null
+++ b/app.py
@@ -0,0 +1,57 @@
+import lightning as L
+import os, wget, json
+import gradio as gr
+from lightning.app.components.serve import ServeGradio
+from Waveformer import TARGETS, Waveformer
+
+class ModelDemo(ServeGradio):
+    inputs = [gr.Audio(label="Input audio"), gr.CheckboxGroup(choices=TARGETS, label="Input target selection(s)")]
+    outputs = gr.Audio(label="Output audio")
+    examples = [["data/Sample.wav"]]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def build_model(self):
+        import torch
+
+        if not os.path.exists('default_config.json'):
+            config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
+            print("Downloading model configuration from %s:" % config_url)
+            wget.download(config_url)
+
+        if not os.path.exists('default_ckpt.pt'):
+            ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
+            print("\nDownloading the checkpoint from %s:" % ckpt_url)
+            wget.download(ckpt_url)
+
+        # Instantiate model
+        with open('default_config.json') as f:
+            params = json.load(f)
+        model = Waveformer(**params['model_params'])
+        model.load_state_dict(
+            torch.load('default_ckpt.pt', map_location=torch.device('cpu'))['model_state_dict'])
+        model.eval()
+        return model
+
+    def predict(self, audio, label_choices):
+        import torch, torchaudio
+        # Read input audio
+        fs, mixture = audio
+        mixture = torchaudio.functional.resample(torch.as_tensor(mixture, dtype=torch.float32), orig_freq=fs, new_freq=44100).numpy()
+        # if fs != 44100:
+        #     raise ValueError("Sampling rate must be 44100, but got %d" % fs)
+        mixture = torch.from_numpy(
+            mixture).unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)
+
+        # Construct the query vector
+        query = torch.zeros(1, len(TARGETS))
+        for t in label_choices:
+            query[0, TARGETS.index(t)] = 1.
+
+        with torch.no_grad():
+            output = (2.0 ** 15) * self.model(mixture, query)
+
+        return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
+
+app = L.LightningApp(ModelDemo())
diff --git a/requirements.txt b/requirements.txt
index a472039..81d400c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ seaborn
 ipykernel
 scaper
 wget
+gradio
\ No newline at end of file
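The `2.0 ** 15` factors in `predict()` above follow Gradio's numpy audio convention: `gr.Audio` typically hands over `(sample_rate, samples)` with 16-bit integer samples, while the model expects floats in [-1, 1], hence the divide on the way in and the multiply plus `.to(torch.short)` on the way out. A minimal sketch of that scaling round trip (standalone numpy, not code from the patch):

```python
import numpy as np

# int16 full scale is 2**15 = 32768; dividing maps samples into [-1.0, 1.0)
int16_samples = np.array([0, 16384, -32768], dtype=np.int16)
as_float = int16_samples.astype(np.float32) / (2.0**15)  # [0.0, 0.5, -1.0]

# The inverse scaling restores the original 16-bit values
restored = (as_float * (2.0**15)).astype(np.int16)
assert np.array_equal(restored, int16_samples)
```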
From 463842ebc244bc4382285b9b698144e4a0501f09 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Thu, 10 Nov 2022 23:13:24 +0530
Subject: [PATCH 2/9] update

---
 app.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/app.py b/app.py
index ff568e4..6f06caf 100644
--- a/app.py
+++ b/app.py
@@ -1,4 +1,5 @@
 import lightning as L
+import torch, torchaudio
 import os, wget, json
 import gradio as gr
 from lightning.app.components.serve import ServeGradio
@@ -8,7 +9,7 @@ class ModelDemo(ServeGradio):
     inputs = [gr.Audio(label="Input audio"), gr.CheckboxGroup(choices=TARGETS, label="Input target selection(s)")]
     outputs = gr.Audio(label="Output audio")
     examples = [["data/Sample.wav"]]
-
+    enable_queue: bool=False
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
@@ -35,12 +36,10 @@ def build_model(self):
         return model
 
     def predict(self, audio, label_choices):
-        import torch, torchaudio
         # Read input audio
         fs, mixture = audio
         mixture = torchaudio.functional.resample(torch.as_tensor(mixture, dtype=torch.float32), orig_freq=fs, new_freq=44100).numpy()
-        # if fs != 44100:
-        #     raise ValueError("Sampling rate must be 44100, but got %d" % fs)
+
         mixture = torch.from_numpy(
             mixture).unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)
 

From 043ef64716f5c2886e86566e9ff1b302ebf11eb3 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Thu, 10 Nov 2022 23:14:33 +0530
Subject: [PATCH 3/9] format

---
 app.py | 55 +++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 37 insertions(+), 18 deletions(-)

diff --git a/app.py b/app.py
index 6f06caf..74a95c7 100644
--- a/app.py
+++ b/app.py
@@ -1,56 +1,75 @@
-import lightning as L
-import torch, torchaudio
-import os, wget, json
+import json
+import os
+
 import gradio as gr
+import lightning as L
+import torch
+import torchaudio
+import wget
 from lightning.app.components.serve import ServeGradio
+
 from Waveformer import TARGETS, Waveformer
 
+
 class ModelDemo(ServeGradio):
-    inputs = [gr.Audio(label="Input audio"), gr.CheckboxGroup(choices=TARGETS, label="Input target selection(s)")]
+    inputs = [
+        gr.Audio(label="Input audio"),
+        gr.CheckboxGroup(choices=TARGETS, label="Input target selection(s)"),
+    ]
     outputs = gr.Audio(label="Output audio")
     examples = [["data/Sample.wav"]]
-    enable_queue: bool=False
+    enable_queue: bool = False
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def build_model(self):
         import torch
 
-        if not os.path.exists('default_config.json'):
-            config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
+        if not os.path.exists("default_config.json"):
+            config_url = (
+                "https://targetsound.cs.washington.edu/files/default_config.json"
+            )
             print("Downloading model configuration from %s:" % config_url)
             wget.download(config_url)
 
-        if not os.path.exists('default_ckpt.pt'):
-            ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
+        if not os.path.exists("default_ckpt.pt"):
+            ckpt_url = "https://targetsound.cs.washington.edu/files/default_ckpt.pt"
             print("\nDownloading the checkpoint from %s:" % ckpt_url)
             wget.download(ckpt_url)
 
         # Instantiate model
-        with open('default_config.json') as f:
+        with open("default_config.json") as f:
             params = json.load(f)
-        model = Waveformer(**params['model_params'])
+        model = Waveformer(**params["model_params"])
         model.load_state_dict(
-            torch.load('default_ckpt.pt', map_location=torch.device('cpu'))['model_state_dict'])
+            torch.load("default_ckpt.pt", map_location=torch.device("cpu"))[
+                "model_state_dict"
+            ]
+        )
         model.eval()
         return model
-    
+
     def predict(self, audio, label_choices):
         # Read input audio
         fs, mixture = audio
-        mixture = torchaudio.functional.resample(torch.as_tensor(mixture, dtype=torch.float32), orig_freq=fs, new_freq=44100).numpy()
+        mixture = torchaudio.functional.resample(
+            torch.as_tensor(mixture, dtype=torch.float32), orig_freq=fs, new_freq=44100
+        ).numpy()
 
-        mixture = torch.from_numpy(
-            mixture).unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)
+        mixture = torch.from_numpy(mixture).unsqueeze(0).unsqueeze(0).to(
+            torch.float
+        ) / (2.0**15)
 
         # Construct the query vector
         query = torch.zeros(1, len(TARGETS))
         for t in label_choices:
-            query[0, TARGETS.index(t)] = 1.
+            query[0, TARGETS.index(t)] = 1.0
 
         with torch.no_grad():
-            output = (2.0 ** 15) * self.model(mixture, query)
+            output = (2.0**15) * self.model(mixture, query)
 
         return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
 
+
 app = L.LightningApp(ModelDemo())
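After the reformat, the query construction in `predict()` is easier to read: it is a multi-hot encoding over the `TARGETS` label list. A standalone sketch of the same logic, using a three-label stand-in for the real `TARGETS` (which ships in `Waveformer.py`):

```python
import torch

# Stand-in labels; the real TARGETS list is defined in Waveformer.py.
TARGETS = ["Acoustic_guitar", "Applause", "Bark"]

def make_query(label_choices: list) -> torch.Tensor:
    # One slot per known target; 1.0 marks each label the user selected.
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices:
        query[0, TARGETS.index(t)] = 1.0
    return query

print(make_query(["Applause", "Bark"]))  # tensor([[0., 1., 1.]])
```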
From 5e3073509f21e243bc198a5ec7d44c6b0286c568 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Thu, 10 Nov 2022 23:38:35 +0530
Subject: [PATCH 4/9] inference mode

---
 README.md | 1 +
 app.py    | 8 +++++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index d640cd5..0b84979 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@
 This repository provides code for the Waveformer architecture proposed in the paper. Waveformer is a low-latency target sound extraction model implementing streaming inference -- the model processes a ~10 ms input audio chunk at each time step, while only looking at past chunks and no future chunks. On a Core i5 CPU using a single thread, real-time factors (RTFs) of different model configurations range from 0.66 to 0.94, with an end-to-end latency of less than 20 ms.
 
 [![Gradio demo](https://img.shields.io/badge/arxiv-abs-green)](https://arxiv.org/abs/2211.02250) [![Gradio demo](https://img.shields.io/badge/arxiv-pdf-green)](https://arxiv.org/pdf/2211.02250) [![Gradio demo](https://img.shields.io/badge/Gradio-app-blue)](https://huggingface.co/spaces/uwx/waveformer)
+[![App Gallery](https://bit.ly/3xTcccO)](https://01ghh2pnbdet9ex9sdqqsnpxwh.litng-ai-03.litng.ai/view)
 
diff --git a/app.py b/app.py
index 74a95c7..d4d2b68 100644
--- a/app.py
+++ b/app.py
@@ -8,7 +8,8 @@
 import wget
 from lightning.app.components.serve import ServeGradio
 
-from Waveformer import TARGETS, Waveformer
+from Waveformer import TARGETS
+from Waveformer import Waveformer as WaveformerModel
 
 
 class ModelDemo(ServeGradio):
@@ -21,7 +22,7 @@ class ModelDemo(ServeGradio):
     enable_queue: bool = False
 
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(cloud_compute=L.CloudCompute("cpu-medium"), **kwargs)
 
     def build_model(self):
         import torch
@@ -41,7 +42,7 @@ def build_model(self):
         # Instantiate model
         with open("default_config.json") as f:
             params = json.load(f)
-        model = Waveformer(**params["model_params"])
+        model = WaveformerModel(**params["model_params"])
         model.load_state_dict(
             torch.load("default_ckpt.pt", map_location=torch.device("cpu"))[
                 "model_state_dict"
@@ -50,6 +51,7 @@ def build_model(self):
         model.eval()
         return model
 
+    @torch.inference_mode()
     def predict(self, audio, label_choices):
         # Read input audio
         fs, mixture = audio
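Patch 4 stacks `@torch.inference_mode()` on `predict()`; unlike `torch.no_grad()`, inference mode also skips autograd's version-counter and view-tracking bookkeeping and marks results as inference tensors. A small self-contained illustration of the decorator form the patch uses:

```python
import torch

@torch.inference_mode()  # the same decorator the patch puts on predict()
def double(x: torch.Tensor) -> torch.Tensor:
    return 2 * x

x = torch.ones(3, requires_grad=True)
y = double(x)
print(y.requires_grad)   # False: no autograd graph was recorded
print(y.is_inference())  # True: y was created under inference mode
```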
From f1aa3545f8874308b8ffea5f358a63e961d2d71e Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Fri, 11 Nov 2022 12:15:44 +0530
Subject: [PATCH 5/9] update

---
 app.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/app.py b/app.py
index d4d2b68..9a7078a 100644
--- a/app.py
+++ b/app.py
@@ -15,7 +15,7 @@ class ModelDemo(ServeGradio):
     inputs = [
         gr.Audio(label="Input audio"),
-        gr.CheckboxGroup(choices=TARGETS, label="Input target selection(s)"),
+        gr.Checkbox(choices=TARGETS, label="Extract target sound"),
     ]
     outputs = gr.Audio(label="Output audio")
     examples = [["data/Sample.wav"]]
@@ -25,8 +25,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(cloud_compute=L.CloudCompute("cpu-medium"), **kwargs)
 
     def build_model(self):
-        import torch
-
         if not os.path.exists("default_config.json"):
             config_url = (
                 "https://targetsound.cs.washington.edu/files/default_config.json"
             )
@@ -41,10 +39,12 @@ def build_model(self):
         with open("default_config.json") as f:
             params = json.load(f)
         model = WaveformerModel(**params["model_params"])
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
         model.load_state_dict(
-            torch.load("default_ckpt.pt", map_location=torch.device("cpu"))[
-                "model_state_dict"
-            ]
+            torch.load("default_ckpt.pt", map_location=device)["model_state_dict"]
         )
         model.eval()
         return model
@@ -55,9 +55,10 @@ def build_model(self):
     def predict(self, audio, label_choices):
         # Read input audio
         fs, mixture = audio
-        mixture = torchaudio.functional.resample(
-            torch.as_tensor(mixture, dtype=torch.float32), orig_freq=fs, new_freq=44100
-        ).numpy()
+        if fs!=44100:
+            mixture = torchaudio.functional.resample(
+                torch.as_tensor(mixture, dtype=torch.float32), orig_freq=fs, new_freq=44100
+            ).numpy()
 
         mixture = torch.from_numpy(mixture).unsqueeze(0).unsqueeze(0).to(
             torch.float
@@ -68,7 +69,7 @@ def predict(self, audio, label_choices):
         for t in label_choices:
             query[0, TARGETS.index(t)] = 1.0
 
-        with torch.no_grad():
+        with torch.inference_mode():
             output = (2.0**15) * self.model(mixture, query)
 
         return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
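Patch 5 only resamples when the incoming rate differs from 44.1 kHz, but it still returns the original `fs` alongside samples that are now at 44100 Hz, so resampled output would play back at the wrong speed. A sketch of the same step with the output rate made explicit; this is a suggested reading rather than code from the series, and `MODEL_SR` is a name introduced here (the patches hard-code 44100 inline):

```python
import torch
import torchaudio

MODEL_SR = 44100  # hypothetical constant for the rate the model expects

def to_model_rate(fs: int, samples: torch.Tensor):
    # Resample only when needed, and report the rate the samples actually have.
    if fs != MODEL_SR:
        samples = torchaudio.functional.resample(samples, orig_freq=fs, new_freq=MODEL_SR)
    return MODEL_SR, samples
```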
From 4b0fba53fb0c8726553d666413c43fac4496e397 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Fri, 11 Nov 2022 12:18:50 +0530
Subject: [PATCH 6/9] enable MPS device

---
 app.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/app.py b/app.py
index 9a7078a..3631bf7 100644
--- a/app.py
+++ b/app.py
@@ -15,7 +15,7 @@ class ModelDemo(ServeGradio):
     inputs = [
         gr.Audio(label="Input audio"),
-        gr.Checkbox(choices=TARGETS, label="Extract target sound"),
+        gr.CheckboxGroup(choices=TARGETS, label="Extract target sound"),
     ]
     outputs = gr.Audio(label="Output audio")
     examples = [["data/Sample.wav"]]
@@ -43,8 +43,11 @@ def build_model(self):
         model = WaveformerModel(**params["model_params"])
         if torch.cuda.is_available():
             device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
         else:
             device = torch.device("cpu")
+        print(f"loading model on {device}")
         model.load_state_dict(
             torch.load("default_ckpt.pt", map_location=device)["model_state_dict"]
         )

From cdf2d98392c37013a8b2c8ce1a2191140b5cea4a Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Fri, 11 Nov 2022 12:19:46 +0530
Subject: [PATCH 7/9] mps

---
 app.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/app.py b/app.py
index 3631bf7..3f0c027 100644
--- a/app.py
+++ b/app.py
@@ -51,8 +51,7 @@ def build_model(self):
         model.load_state_dict(
             torch.load("default_ckpt.pt", map_location=device)["model_state_dict"]
         )
-        model.eval()
-        return model
+        return model.to(device).eval()
 
     @torch.inference_mode()
     def predict(self, audio, label_choices):

From 7d1fab86b21c0856e27eec668296bacb9e79b662 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Fri, 11 Nov 2022 12:34:36 +0530
Subject: [PATCH 8/9] remove mps

---
 app.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/app.py b/app.py
index 3f0c027..6929c53 100644
--- a/app.py
+++ b/app.py
@@ -43,8 +43,6 @@ def build_model(self):
         model = WaveformerModel(**params["model_params"])
         if torch.cuda.is_available():
             device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
         else:
             device = torch.device("cpu")
         print(f"loading model on {device}")
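Patches 6-8 experiment with Apple's MPS backend and then back it out, leaving a CUDA-or-CPU choice. In isolation, the fallback chain they tried looks like the sketch below; the commented branch is the one patch 8 removes, and `torch.backends.mps.is_available()` requires torch >= 1.12:

```python
import torch

def pick_device() -> torch.device:
    # Prefer CUDA when present, otherwise fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Tried in patch 6, removed again in patch 8:
    # if torch.backends.mps.is_available():
    #     return torch.device("mps")
    return torch.device("cpu")

print(pick_device())
```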
From 4cf630459997043da1ab071a8092e753d4e30dfb Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Fri, 11 Nov 2022 12:38:57 +0530
Subject: [PATCH 9/9] support device

---
 app.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/app.py b/app.py
index 6929c53..ab0edb0 100644
--- a/app.py
+++ b/app.py
@@ -23,6 +23,7 @@ class ModelDemo(ServeGradio):
 
     def __init__(self, *args, **kwargs):
         super().__init__(cloud_compute=L.CloudCompute("cpu-medium"), **kwargs)
+        self._device = None
 
     def build_model(self):
         if not os.path.exists("default_config.json"):
@@ -45,11 +46,12 @@ def build_model(self):
             device = torch.device("cuda")
         else:
             device = torch.device("cpu")
+        self._device = device
         print(f"loading model on {device}")
         model.load_state_dict(
-            torch.load("default_ckpt.pt", map_location=device)["model_state_dict"]
+            torch.load("default_ckpt.pt", map_location=self._device)["model_state_dict"]
        )
-        return model.to(device).eval()
+        return model.to(self._device).eval()
 
     @torch.inference_mode()
     def predict(self, audio, label_choices):
@@ -65,14 +67,14 @@ def predict(self, audio, label_choices):
         ) / (2.0**15)
 
         # Construct the query vector
-        query = torch.zeros(1, len(TARGETS))
+        query = torch.zeros(1, len(TARGETS)).to(self._device)
         for t in label_choices:
             query[0, TARGETS.index(t)] = 1.0
 
         with torch.inference_mode():
-            output = (2.0**15) * self.model(mixture, query)
+            output = (2.0**15) * self.model(mixture.to(self._device), query)
 
-        return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
+        return fs, output.squeeze(0).squeeze(0).to(torch.short).cpu().numpy()
 
 
 app = L.LightningApp(ModelDemo())
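With patch 9 applied, tensors are moved to the cached `self._device` before the forward pass and back to the CPU before handing int16 samples to Gradio. Below is a hypothetical smoke test of the finished component outside the UI; it assumes `Waveformer.py` and the checkpoint downloads are available in the working directory, and it pokes the private `_model` slot that `ServeGradio` normally populates when the work runs, so treat it as a sketch rather than a supported API:

```python
import numpy as np

from app import ModelDemo
from Waveformer import TARGETS

demo = ModelDemo()
demo._model = demo.build_model()  # assumption: ServeGradio keeps the built model on _model

fs = 44100
silence = np.zeros(fs, dtype=np.int16)  # one second of silence at the model rate
out_fs, out = demo.predict((fs, silence), [TARGETS[0]])
print(out_fs, out.shape, out.dtype)  # expected: 44100, a 1-D int16 array
```

To serve the app itself, the `lightning` package of that era exposed `lightning run app app.py` (with a `--cloud` flag for hosted deployment), which drives the `app = L.LightningApp(ModelDemo())` object at the bottom of the file.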