Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions kittentts/onnx_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from concurrent.futures import ThreadPoolExecutor
import espeakng_loader
from phonemizer.backend.espeak.wrapper import EspeakWrapper
EspeakWrapper.set_library(espeakng_loader.get_library_path())
Expand Down Expand Up @@ -118,11 +119,22 @@ def normalize_text(self, text: str, locale: str = "en-US", return_spans: bool =
return normalize_text(text, locale=locale, return_spans=return_spans)

def generate(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0, clean_text: bool=True) -> np.ndarray:
    """Synthesize speech for ``text``, splitting it into chunks as needed.

    Args:
        text: Input text to synthesize.
        voice: Voice to use for synthesis.
        speed: Speech speed (1.0 = normal).
        clean_text: When true, ``text`` is passed through ``self.preprocessor``
            before it is chunked.

    Returns:
        The per-chunk audio arrays concatenated along the last axis.

    Raises:
        ValueError: Raised by ``np.concatenate`` if chunking yields no chunks.
    """
    if clean_text:
        text = self.preprocessor(text)
    # Materialise the chunks so len() works even if chunk_text is lazy.
    text_chunks = list(chunk_text(text))

    if len(text_chunks) <= 1:
        # Nothing to parallelise; avoid the thread-pool overhead.
        out_chunks = [self.generate_single_chunk(c, voice, speed) for c in text_chunks]
    else:
        # Phonemization (espeak) keeps internal state and isn't safe to call
        # concurrently, so prepare every chunk's inputs sequentially first.
        # ONNX Runtime sessions support concurrent run() calls, so the actual
        # (and far more expensive) inference step can run in parallel across chunks.
        prepared_inputs = [self._prepare_inputs(c, voice, speed) for c in text_chunks]
        max_workers = min(len(prepared_inputs), os.cpu_count() or 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # executor.map yields results in input order, so chunk order is kept.
            out_chunks = list(executor.map(self._run_inference, prepared_inputs))

    return np.concatenate(out_chunks, axis=-1)

def generate_stream(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0, clean_text: bool = True):
Expand All @@ -136,25 +148,32 @@ def generate_stream(self, text: str, voice: str = "expr-voice-5-m", speed: float
for text_chunk in chunk_text(text):
yield self.generate_single_chunk(text_chunk, voice, speed)

def _run_inference(self, onnx_inputs: dict) -> np.ndarray:
"""Run the ONNX session on already-prepared inputs and trim the output.

Safe to call concurrently from multiple threads for different inputs
on the same session (unlike phonemization, which is not thread-safe).
"""
outputs = self.session.run(None, onnx_inputs)

# Trim audio
audio = outputs[0][..., :-5000]

return audio

def generate_single_chunk(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0) -> np.ndarray:
    """Synthesize speech from a single chunk of text.

    Args:
        text: Input text to synthesize
        voice: Voice to use for synthesis
        speed: Speech speed (1.0 = normal)

    Returns:
        Audio data as numpy array
    """
    onnx_inputs = self._prepare_inputs(text, voice, speed)

    # Delegate to the shared helper so the session call and the audio
    # trimming live in exactly one place.
    return self._run_inference(onnx_inputs)

def generate_to_file(self, text: str, output_path: str, voice: str = "expr-voice-5-m",
speed: float = 1.0, sample_rate: int = 24000, clean_text: bool=True) -> None:
Expand Down