Run ONNX inference for multiple text chunks in parallel (kittentts/onnx_model.py).

* Import ThreadPoolExecutor from concurrent.futures.
* generate(): when chunk_text() yields more than one chunk, build every
  chunk's ONNX inputs sequentially with _prepare_inputs(), then map
  _run_inference() over them in a thread pool sized
  min(number of chunks, os.cpu_count() or 4). With zero or one chunk it
  still goes through generate_single_chunk().
* Add _run_inference(): calls self.session.run(None, onnx_inputs) and drops
  the last 5000 elements along the final axis of the first output.
* generate_single_chunk(): delegate to _run_inference(); the blank docstring
  lines are re-added without trailing whitespace.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Indentation and the trailing whitespace on the removed blank lines are
inferred from the Python structure; confirm against the target file before
applying. The last hunk header declares 25/32 lines but only 24/31 are
present here, so its final context line appears to be missing.

diff --git a/kittentts/onnx_model.py b/kittentts/onnx_model.py
index b368e4f..8c2c211 100644
--- a/kittentts/onnx_model.py
+++ b/kittentts/onnx_model.py
@@ -1,4 +1,5 @@
 import os
+from concurrent.futures import ThreadPoolExecutor
 import espeakng_loader
 from phonemizer.backend.espeak.wrapper import EspeakWrapper
 EspeakWrapper.set_library(espeakng_loader.get_library_path())
@@ -118,11 +119,22 @@ def normalize_text(self, text: str, locale: str = "en-US", return_spans: bool =
         return normalize_text(text, locale=locale, return_spans=return_spans)
 
     def generate(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0, clean_text: bool=True) -> np.ndarray:
-        out_chunks = []
         if clean_text:
             text = self.preprocessor(text)
-        for text_chunk in chunk_text(text):
-            out_chunks.append(self.generate_single_chunk(text_chunk, voice, speed))
+        text_chunks = chunk_text(text)
+
+        if len(text_chunks) <= 1:
+            out_chunks = [self.generate_single_chunk(c, voice, speed) for c in text_chunks]
+        else:
+            # Phonemization (espeak) keeps internal state and isn't safe to call
+            # concurrently, so prepare every chunk's inputs sequentially first.
+            # ONNX Runtime sessions support concurrent run() calls, so the actual
+            # (and far more expensive) inference step can run in parallel across chunks.
+            prepared_inputs = [self._prepare_inputs(c, voice, speed) for c in text_chunks]
+            max_workers = min(len(prepared_inputs), os.cpu_count() or 4)
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                out_chunks = list(executor.map(self._run_inference, prepared_inputs))
+
         return np.concatenate(out_chunks, axis=-1)
 
     def generate_stream(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0, clean_text: bool = True):
@@ -136,25 +148,32 @@ def generate_stream(self, text: str, voice: str = "expr-voice-5-m", speed: float
         for text_chunk in chunk_text(text):
             yield self.generate_single_chunk(text_chunk, voice, speed)
 
+    def _run_inference(self, onnx_inputs: dict) -> np.ndarray:
+        """Run the ONNX session on already-prepared inputs and trim the output.
+
+        Safe to call concurrently from multiple threads for different inputs
+        on the same session (unlike phonemization, which is not thread-safe).
+        """
+        outputs = self.session.run(None, onnx_inputs)
+
+        # Trim audio
+        audio = outputs[0][..., :-5000]
+
+        return audio
+
     def generate_single_chunk(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0) -> np.ndarray:
         """Synthesize speech from text.
-        
+
         Args:
             text: Input text to synthesize
             voice: Voice to use for synthesis
             speed: Speech speed (1.0 = normal)
-            
+
         Returns:
             Audio data as numpy array
         """
         onnx_inputs = self._prepare_inputs(text, voice, speed)
-        
-        outputs = self.session.run(None, onnx_inputs)
-        
-        # Trim audio
-        audio = outputs[0][..., :-5000]
-        
-        return audio
+        return self._run_inference(onnx_inputs)
 
     def generate_to_file(self, text: str, output_path: str, voice: str = "expr-voice-5-m", speed: float = 1.0, sample_rate: int = 24000, clean_text: bool=True) -> None: