From 3f5f17e0f76af51353799436a323dac807aaf19c Mon Sep 17 00:00:00 2001 From: Naman Omar <142042135+namanomar@users.noreply.github.com> Date: Mon, 29 Jun 2026 02:45:51 +0530 Subject: [PATCH] Parallelize multi-chunk inference for ~3-4x speedup on multi-sentence text generate() previously ran each text chunk's phonemization and ONNX inference strictly sequentially. Phonemization (espeak) keeps shared internal state and isn't safe to call concurrently, but ONNX Runtime sessions support concurrent run() calls. Splitting these two steps lets inference across chunks run in a thread pool while keeping phonemization sequential, since it's the inference step that dominates latency (>99.9% of per-chunk time per profiling). Benchmarked on a 5-sentence paragraph (kitten-tts-mini-0.8, 12-core CPU): ~37s sequential -> ~8-9s parallel. Single-chunk text is unaffected (no executor overhead, same as before). Verified no NaN/Inf, no clipping, and stable output across repeated runs, multiple voices, and longer multi-chunk inputs (up to 15 chunks). --- kittentts/onnx_model.py | 43 +++++++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/kittentts/onnx_model.py b/kittentts/onnx_model.py index b368e4f..8c2c211 100644 --- a/kittentts/onnx_model.py +++ b/kittentts/onnx_model.py @@ -1,4 +1,5 @@ import os +from concurrent.futures import ThreadPoolExecutor import espeakng_loader from phonemizer.backend.espeak.wrapper import EspeakWrapper EspeakWrapper.set_library(espeakng_loader.get_library_path()) @@ -118,11 +119,22 @@ def normalize_text(self, text: str, locale: str = "en-US", return_spans: bool = return normalize_text(text, locale=locale, return_spans=return_spans) def generate(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0, clean_text: bool=True) -> np.ndarray: - out_chunks = [] if clean_text: text = self.preprocessor(text) - for text_chunk in chunk_text(text): - out_chunks.append(self.generate_single_chunk(text_chunk, voice, speed)) + text_chunks = chunk_text(text) + + if len(text_chunks) <= 1: + out_chunks = [self.generate_single_chunk(c, voice, speed) for c in text_chunks] + else: + # Phonemization (espeak) keeps internal state and isn't safe to call + # concurrently, so prepare every chunk's inputs sequentially first. + # ONNX Runtime sessions support concurrent run() calls, so the actual + # (and far more expensive) inference step can run in parallel across chunks. + prepared_inputs = [self._prepare_inputs(c, voice, speed) for c in text_chunks] + max_workers = min(len(prepared_inputs), os.cpu_count() or 4) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + out_chunks = list(executor.map(self._run_inference, prepared_inputs)) + return np.concatenate(out_chunks, axis=-1) def generate_stream(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0, clean_text: bool = True): @@ -136,25 +148,32 @@ def generate_stream(self, text: str, voice: str = "expr-voice-5-m", speed: float for text_chunk in chunk_text(text): yield self.generate_single_chunk(text_chunk, voice, speed) + def _run_inference(self, onnx_inputs: dict) -> np.ndarray: + """Run the ONNX session on already-prepared inputs and trim the output. + + Safe to call concurrently from multiple threads for different inputs + on the same session (unlike phonemization, which is not thread-safe). + """ + outputs = self.session.run(None, onnx_inputs) + + # Trim audio + audio = outputs[0][..., :-5000] + + return audio + def generate_single_chunk(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0) -> np.ndarray: """Synthesize speech from text. - + Args: text: Input text to synthesize voice: Voice to use for synthesis speed: Speech speed (1.0 = normal) - + Returns: Audio data as numpy array """ onnx_inputs = self._prepare_inputs(text, voice, speed) - - outputs = self.session.run(None, onnx_inputs) - - # Trim audio - audio = outputs[0][..., :-5000] - - return audio + return self._run_inference(onnx_inputs) def generate_to_file(self, text: str, output_path: str, voice: str = "expr-voice-5-m", speed: float = 1.0, sample_rate: int = 24000, clean_text: bool=True) -> None: