diff --git a/backend/backend.proto b/backend/backend.proto index 2a575426e4d3..01c5b63a7b5c 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -18,6 +18,18 @@ service Backend { rpc GenerateVideo(GenerateVideoRequest) returns (Result) {} rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} rpc AudioTranscriptionStream(TranscriptRequest) returns (stream TranscriptStreamResponse) {} + // AudioTranscriptionLive is the bidirectional live-microphone ASR RPC. The + // first message MUST carry a Config; subsequent messages carry Audio frames + // (mono float PCM at config.sample_rate, 16 kHz default). After a + // successful open the backend replies with a single ready ack + // (TranscriptLiveResponse{ready:true}); backends or models without + // cache-aware streaming support return UNIMPLEMENTED instead. Newly + // finalized text streams back as deltas; eou=true marks the model's + // end-of-utterance token. One stream spans many utterances (the decoder + // resets itself after each EOU). Closing the send side finalizes: the + // backend flushes the decoder tail and emits a terminal message carrying + // final_result. A second Config mid-stream resets the decode session. + rpc AudioTranscriptionLive(stream TranscriptLiveRequest) returns (stream TranscriptLiveResponse) {} rpc TTS(TTSRequest) returns (Result) {} rpc TTSStream(TTSRequest) returns (stream Reply) {} rpc SoundGeneration(SoundGenerationRequest) returns (Result) {} @@ -479,6 +491,10 @@ message TranscriptResult { string text = 2; string language = 3; float duration = 4; + // True when the decode ended on the model's end-of-utterance special token + // (/, emitted by cache-aware streaming models such as + // parakeet_realtime_eou_120m-v1). The marker itself is stripped from text. + bool eou = 5; } message TranscriptStreamResponse { @@ -486,6 +502,34 @@ message TranscriptStreamResponse { TranscriptResult final_result = 2; } +// === AudioTranscriptionLive messages ===================================== + +message TranscriptLiveRequest { + oneof payload { + TranscriptLiveConfig config = 1; + TranscriptLiveAudio audio = 2; + } +} + +message TranscriptLiveConfig { + string language = 1; // "" => model default + int32 sample_rate = 2; // 0 => 16000; backends may reject others + map params = 3; // backend-specific tuning +} + +message TranscriptLiveAudio { + repeated float pcm = 1; // mono PCM in [-1,1] at config.sample_rate +} + +message TranscriptLiveResponse { + bool ready = 1; // open ack: sent once, before any delta + string delta = 2; // newly-finalized text since previous response + bool eou = 3; // fired during this feed (the user yielded the turn) + repeated TranscriptWord words = 4; // words finalized by this feed (stream-relative ns) + TranscriptResult final_result = 5; // terminal message only, after the send side closes + bool eob = 6; // fired: a backchannel ("uh-huh") ended — NOT a turn boundary +} + message TranscriptWord { int64 start = 1; int64 end = 2; diff --git a/backend/go/parakeet-cpp/Makefile b/backend/go/parakeet-cpp/Makefile index 9a781d634480..ea2c80243f53 100644 --- a/backend/go/parakeet-cpp/Makefile +++ b/backend/go/parakeet-cpp/Makefile @@ -15,6 +15,10 @@ # That's what the L0 smoke test uses. The default target below does the # proper clone-at-pin + cmake build so CI doesn't need a side-checkout. +# ABI v5: incremental StreamingMel (live feeds no longer recompute the full mel +# per call, which fell behind real time and delayed by seconds on long +# turns) plus the / split (eou_out bitmask + JSON "eob" field) so +# backchannels are not mistaken for turn boundaries. PARAKEET_VERSION?=db755a78d39f789bb7d4e3935158a9e8105dbe36 PARAKEET_REPO?=https://github.com/mudler/parakeet.cpp diff --git a/backend/go/parakeet-cpp/goparakeetcpp.go b/backend/go/parakeet-cpp/goparakeetcpp.go index e87409255465..6be8e0b81069 100644 --- a/backend/go/parakeet-cpp/goparakeetcpp.go +++ b/backend/go/parakeet-cpp/goparakeetcpp.go @@ -103,12 +103,13 @@ type transcriptJSON struct { // {"text":"...","eou":0,"eob":0,"frame_sec":0.080000, // "words":[{"w":"...","start":0.480,"end":0.640,"conf":0.9100}, ...]} // -// "text" is the newly-finalized text since the last call; "eou" is 1 when an -// (end of utterance) fired this feed and "eob" is 1 when an -// (backchannel) fired. ABI v4 conflated the two into "eou"; v5 split them, so -// we read both and treat either as an utterance boundary for segmentation. -// "words" are the words finalized this call with absolute (stream-relative) -// start/end seconds. +// "text" is the newly-finalized text since the last call. Under ABI v5 "eou" +// is 1 iff an fired this feed (the user yielded the turn) and "eob" 1 +// iff an fired (a backchannel like "uh-huh" ended — NOT a turn +// boundary). A v4 library has no "eob" field and its "eou" conflates both +// tokens: Eob stays 0 and Eou keeps the old any-event meaning. "words" are +// the words finalized this call with absolute (stream-relative) start/end +// seconds. type streamFeedJSON struct { Text string `json:"text"` Eou int `json:"eou"` @@ -364,7 +365,7 @@ var segmentSeparators = []rune{'.', '?', '!'} // the caller requested word granularity; token ids populate each segment's // Tokens by time-window membership. Shared by the batched and direct paths. func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gapFrames int) pb.TranscriptResult { - text := strings.TrimSpace(doc.Text) + text, eou := stripEouMarker(strings.TrimSpace(doc.Text)) // Frame-unit gap threshold -> seconds (NeMo segment_gap_threshold). 0 = off. gapSeconds := 0.0 @@ -383,6 +384,7 @@ func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gap return pb.TranscriptResult{ Text: text, Segments: []*pb.TranscriptSegment{{Id: 0, Text: text}}, + Eou: eou, } } @@ -409,7 +411,25 @@ func transcriptResultFromDoc(doc transcriptJSON, opts *pb.TranscriptRequest, gap } segments = append(segments, seg) } - return pb.TranscriptResult{Text: text, Segments: segments} + return pb.TranscriptResult{Text: text, Segments: segments, Eou: eou} +} + +// stripEouMarker removes a trailing literal / from offline-decode +// text and reports whether the decode ended on an end-of-UTTERANCE token. The +// realtime EOU model's offline decode keeps the special token in the +// detokenized text (the streaming path strips it and surfaces it as flags +// instead); user-visible transcripts must never carry either marker, but only +// may confirm the semantic_vad retranscribe cross-check — a decode +// ending on means the last thing heard was a backchannel, not the user +// yielding the turn. +func stripEouMarker(text string) (string, bool) { + if strings.HasSuffix(text, "") { + return strings.TrimSpace(strings.TrimSuffix(text, "")), true + } + if strings.HasSuffix(text, "") { + return strings.TrimSpace(strings.TrimSuffix(text, "")), false + } + return text, false } // splitWordsIntoSegments groups words into segments exactly as NeMo's @@ -487,9 +507,7 @@ type streamSegmenter struct { func (s *streamSegmenter) add(doc streamFeedJSON) { s.cur = append(s.cur, doc.Words...) - // Close the segment on either turn signal: (end of utterance) or - // (backchannel). ABI v4 reported both via "eou"; v5 split them, so we - // OR them here to keep the v4 segmentation boundaries. + // Both and reset the decoder, so both close a segment. if doc.Eou != 0 || doc.Eob != 0 { s.flush() } @@ -535,6 +553,107 @@ func secondsToNanos(sec float64) int64 { return int64(sec * 1e9) } +// Per-C-call engine serialization for the streaming paths. +// +// Every individual C call (begin / feed / finalize / free) takes engineMu and +// re-checks ctxPtr under the lock; the lock is NEVER held across a stream's +// lifetime. This is safe because each parakeet.cpp call builds its own ggml +// graph and all streaming caches live in the session object, not the ctx — +// the only ctx-shared mutable state is last_error, which is why it is read +// under the same lock as the failing call. Holding the lock per call (rather +// than per stream, as this file previously did) keeps a long-lived live +// session from starving batched unary transcription and vice versa. +// +// A stream must not outlive its ctx (C-API contract). Free() takes engineMu +// and zeroes ctxPtr, so a racing per-call helper returns ModelNotLoaded +// instead of feeding a freed engine; streamFree of an orphaned session only +// runs the session destructor, which does not touch the ctx. + +// streamBegin opens a cache-aware streaming session. A 0 stream with nil +// error means the loaded model is not a streaming model. +func (p *ParakeetCpp) streamBegin(lang string) (uintptr, error) { + p.engineMu.Lock() + defer p.engineMu.Unlock() + if p.ctxPtr == 0 { + return 0, grpcerrors.ModelNotLoaded("parakeet-cpp") + } + if CppStreamBeginLang != nil { + return CppStreamBeginLang(p.ctxPtr, lang), nil + } + return CppStreamBegin(p.ctxPtr), nil +} + +func (p *ParakeetCpp) streamFree(stream uintptr) { + if stream == 0 { + return + } + p.engineMu.Lock() + defer p.engineMu.Unlock() + CppStreamFree(stream) +} + +// streamFeedText runs one text-mode feed (or the finalize flush when +// finalize is true) under engineMu, returning the newly-finalized delta and +// whether an / fired during the call. +func (p *ParakeetCpp) streamFeedText(stream uintptr, pcm []float32, finalize bool) (delta string, eou, eob bool, err error) { + p.engineMu.Lock() + defer p.engineMu.Unlock() + if p.ctxPtr == 0 { + return "", false, false, grpcerrors.ModelNotLoaded("parakeet-cpp") + } + var ret uintptr + var events int32 + if finalize { + ret = CppStreamFinalize(stream) + } else { + ret = CppStreamFeed(stream, pcm, int32(len(pcm)), unsafe.Pointer(&events)) + } + if ret == 0 { + // last_error is ctx-shared: read it under the same lock as the call. + msg := CppLastError(p.ctxPtr) + if msg == "" { + msg = "unknown error" + } + return "", false, false, fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg) + } + delta = goStringFromCPtr(ret) + CppFreeString(ret) + // ABI v5: eou_out is a bitmask (bit 0 = , bit 1 = ). A v4 + // library sets 0/1 for either token, which the bit-0 test reads as the + // old conflated eou — the EOB distinction simply isn't available there. + return delta, events&1 != 0, events&2 != 0, nil +} + +// streamFeedDoc runs one ABI v4 JSON feed (or finalize) under engineMu and +// returns the parsed {text,eou,frame_sec,words} document. +func (p *ParakeetCpp) streamFeedDoc(stream uintptr, pcm []float32, finalize bool) (streamFeedJSON, error) { + p.engineMu.Lock() + defer p.engineMu.Unlock() + if p.ctxPtr == 0 { + return streamFeedJSON{}, grpcerrors.ModelNotLoaded("parakeet-cpp") + } + var ret uintptr + if finalize { + ret = CppStreamFinalizeJSON(stream) + } else { + ret = CppStreamFeedJSON(stream, pcm, int32(len(pcm))) + } + if ret == 0 { + msg := CppLastError(p.ctxPtr) + if msg == "" { + msg = "unknown error" + } + return streamFeedJSON{}, fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg) + } + raw := goStringFromCPtr(ret) + CppFreeString(ret) + var doc streamFeedJSON + if err := json.Unmarshal([]byte(raw), &doc); err != nil { + return streamFeedJSON{}, fmt.Errorf("parakeet-cpp: decode stream json: %w", err) + } + return doc, nil +} + // AudioTranscriptionStream drives the cache-aware streaming RNN-T over the // audio at opts.Dst: it decodes the file to 16 kHz mono PCM, feeds it in // chunks to parakeet_capi_stream_feed, and emits each newly-finalized text @@ -560,11 +679,9 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra return status.Error(codes.Canceled, "transcription cancelled") } - var stream uintptr - if CppStreamBeginLang != nil { - stream = CppStreamBeginLang(p.ctxPtr, opts.GetLanguage()) - } else { - stream = CppStreamBegin(p.ctxPtr) + stream, err := p.streamBegin(opts.GetLanguage()) + if err != nil { + return err } if stream == 0 { // Not a cache-aware streaming model: run a normal offline @@ -579,25 +696,16 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra results <- &pb.TranscriptStreamResponse{FinalResult: &res} return nil } - defer CppStreamFree(stream) - // The C engine is a single shared context: a streaming session and a batched - // unary dispatch must never touch it at once, so hold engineMu for the whole - // stream. This lock is intentionally taken AFTER the non-streaming fallback - // above returns: that fallback goes through AudioTranscription -> the batcher - // -> runBatch, which itself acquires engineMu, so locking here first would - // deadlock. Do not hoist this lock above the fallback. - p.engineMu.Lock() - defer p.engineMu.Unlock() + defer p.streamFree(stream) data, duration, err := decodeWavMono16k(opts.Dst) if err != nil { return err } - // ABI v4: when the streaming JSON entry points are present, drive them so the - // per-utterance segments carry per-word start/end timestamps. Falls through to - // the text-only loop below against an older libparakeet.so. Runs under the - // engineMu already held above. + // ABI v4: when the streaming JSON entry points are present, drive them so + // the per-utterance segments carry per-word start/end timestamps. Falls + // through to the text-only loop below against an older libparakeet.so. if CppStreamFeedJSON != nil { return p.streamJSON(ctx, stream, data, duration, results) } @@ -607,6 +715,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra segText strings.Builder segments []*pb.TranscriptSegment segID int32 + finalEou bool ) flushSegment := func() { @@ -619,25 +728,28 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra segID++ } - // emitDelta consumes the malloc'd char* returned by feed/finalize: frees - // it, accumulates the text, and sends a delta when non-empty. A 0 return - // is an error (vs the "" empty-but-non-NULL no-new-text case). - emitDelta := func(ret uintptr) error { - if ret == 0 { - msg := CppLastError(p.ctxPtr) - if msg == "" { - msg = "unknown error" - } - return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg) + feed := func(chunk []float32, finalize bool) error { + delta, eou, eob, err := p.streamFeedText(stream, chunk, finalize) + if err != nil { + return err } - delta := goStringFromCPtr(ret) - CppFreeString(ret) - if delta == "" { - return nil + if delta != "" { + full.WriteString(delta) + segText.WriteString(delta) + results <- &pb.TranscriptStreamResponse{Delta: delta} + } + // finalEou tracks whether the decode ENDED on an utterance boundary: + // an re-arms it; trailing text or a backchannel clears + // it. Both tokens reset the decoder, so both close a segment. + switch { + case eou: + finalEou = true + case eob || delta != "": + finalEou = false + } + if eou || eob { + flushSegment() } - full.WriteString(delta) - segText.WriteString(delta) - results <- &pb.TranscriptStreamResponse{Delta: delta} return nil } @@ -646,20 +758,13 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra return status.Error(codes.Canceled, "transcription cancelled") } end := min(off+streamChunkSamples, len(data)) - chunk := data[off:end] - - var eou int32 - ret := CppStreamFeed(stream, chunk, int32(len(chunk)), unsafe.Pointer(&eou)) - if err := emitDelta(ret); err != nil { + if err := feed(data[off:end], false); err != nil { return err } - if eou != 0 { - flushSegment() - } } // Flush the streaming tail (final encoder chunk). - if err := emitDelta(CppStreamFinalize(stream)); err != nil { + if err := feed(nil, true); err != nil { return err } flushSegment() @@ -673,6 +778,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra Text: text, Segments: segments, Duration: duration, + Eou: finalEou, }, } return nil @@ -682,35 +788,33 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra // feed/finalize returns a {text,eou,eob,frame_sec,words} document. The // newly-finalized text is emitted as a delta (unchanged streaming contract) // while words are accumulated into per-utterance segments (closed on or -// ) so the closing FinalResult carries timestamped segments. Runs under -// engineMu (already held by the caller). +// ) so the closing FinalResult carries timestamped segments. Each C call +// takes engineMu individually (via streamFeedDoc); the lock is not held across +// the stream's lifetime. func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []float32, duration float32, results chan *pb.TranscriptStreamResponse) error { var ( - full strings.Builder - seg streamSegmenter + full strings.Builder + seg streamSegmenter + finalEou bool ) - // consume frees the malloc'd char* (a 0 return is an error), parses the JSON, - // emits the delta, and routes words through the segmenter. - consume := func(ret uintptr) error { - if ret == 0 { - msg := CppLastError(p.ctxPtr) - if msg == "" { - msg = "unknown error" - } - return fmt.Errorf("parakeet-cpp: stream feed/finalize failed: %s", msg) - } - raw := goStringFromCPtr(ret) - CppFreeString(ret) - var doc streamFeedJSON - if err := json.Unmarshal([]byte(raw), &doc); err != nil { - return fmt.Errorf("parakeet-cpp: decode stream json: %w", err) + feed := func(chunk []float32, finalize bool) error { + doc, err := p.streamFeedDoc(stream, chunk, finalize) + if err != nil { + return err } if doc.Text != "" { full.WriteString(doc.Text) results <- &pb.TranscriptStreamResponse{Delta: doc.Text} } seg.add(doc) + // finalEou tracks whether the decode ENDED on an utterance boundary: + // an re-arms it; trailing output or a backchannel clears it. + if doc.Eou != 0 { + finalEou = true + } else if doc.Eob != 0 || doc.Text != "" || len(doc.Words) > 0 { + finalEou = false + } return nil } @@ -719,12 +823,11 @@ func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []flo return status.Error(codes.Canceled, "transcription cancelled") } end := min(off+streamChunkSamples, len(data)) - chunk := data[off:end] - if err := consume(CppStreamFeedJSON(stream, chunk, int32(len(chunk)))); err != nil { + if err := feed(data[off:end], false); err != nil { return err } } - if err := consume(CppStreamFinalizeJSON(stream)); err != nil { + if err := feed(nil, true); err != nil { return err } seg.flush() // close any trailing utterance that never saw an EOU @@ -739,6 +842,7 @@ func (p *ParakeetCpp) streamJSON(ctx context.Context, stream uintptr, data []flo Text: text, Segments: segments, Duration: duration, + Eou: finalEou, }, } return nil @@ -803,6 +907,10 @@ func (p *ParakeetCpp) Free() error { close(p.batStop) p.batStop = nil } + // engineMu so an in-flight streaming call (which locks per C call and + // re-checks ctxPtr under the lock) can never feed into a freed ctx. + p.engineMu.Lock() + defer p.engineMu.Unlock() if p.ctxPtr != 0 { CppFree(p.ctxPtr) p.ctxPtr = 0 diff --git a/backend/go/parakeet-cpp/live.go b/backend/go/parakeet-cpp/live.go new file mode 100644 index 000000000000..0aca9ee4b24f --- /dev/null +++ b/backend/go/parakeet-cpp/live.go @@ -0,0 +1,226 @@ +package main + +import ( + "strings" + "time" + + "github.com/mudler/LocalAI/pkg/grpc/grpcerrors" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/xlog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// liveSampleRate is the only PCM rate the parakeet C streaming API accepts. +const liveSampleRate = 16000 + +// AudioTranscriptionLive drives one cache-aware streaming session over audio +// fed incrementally by the caller (the realtime API's semantic_vad turn +// detection). Contract: +// +// - the first request must carry a Config; a Config mid-stream resets the +// decode session (free + begin) and drops accumulated transcript state; +// - a Ready ack is sent right after a successful stream_begin so callers +// can degrade synchronously when the model has no streaming support +// (LiveTranscriptionUnsupported, codes.Unimplemented); +// - every feed that produced output is forwarded as {delta, eou, words}; +// the / flag is the model's own utterance boundary and the +// decoder auto-resets after it, so one session spans many utterances; +// - closing the send side finalizes: the held-back tail chunk is flushed +// (the last ~2 encoder frames of words only appear here) and a terminal +// FinalResult carries the full transcript, segments, and whether the +// decode ended on an utterance boundary. +// +// Engine access is serialized per C call (streamBegin/streamFeed*/streamFree +// take engineMu internally), never for the session lifetime — unary +// transcription keeps flowing between feeds. +func (p *ParakeetCpp) AudioTranscriptionLive(in <-chan *pb.TranscriptLiveRequest, out chan<- *pb.TranscriptLiveResponse) error { + defer close(out) + + if p.ctxPtr == 0 { + return grpcerrors.ModelNotLoaded("parakeet-cpp") + } + + first, ok := <-in + if !ok { + return nil // caller closed without sending anything + } + cfg := first.GetConfig() + if cfg == nil { + return status.Error(codes.InvalidArgument, "parakeet-cpp: first live message must carry a config") + } + if err := validateLiveConfig(cfg); err != nil { + return err + } + + stream, err := p.streamBegin(cfg.GetLanguage()) + if err != nil { + return err + } + if stream == 0 { + return grpcerrors.LiveTranscriptionUnsupported("parakeet-cpp", + "loaded model is not a cache-aware streaming model") + } + // stream is reassigned on a mid-stream Config reset; free whatever is + // current when the RPC unwinds. + defer func() { p.streamFree(stream) }() + + out <- &pb.TranscriptLiveResponse{Ready: true} + + var ( + full strings.Builder + seg streamSegmenter + finalEou bool + fedSecs float64 + + // behindSec accumulates how far decode wall time has fallen behind + // the audio it was fed. A live caller feeds in real time, so a + // persistent positive backlog means every downstream signal — + // including the the turn detector waits on — arrives that many + // seconds late. Warned once per session; reset by a Config reset. + behindSec float64 + behindWarned bool + ) + + // process forwards one feed document and folds it into the final-result + // accumulators. finalEou tracks whether the decode ENDED on an utterance + // boundary: an re-arms it; later output or a backchannel + // clears it. + process := func(doc streamFeedJSON) { + if doc.Text != "" { + full.WriteString(doc.Text) + } + seg.add(doc) + if doc.Eou != 0 { + finalEou = true + } else if doc.Eob != 0 || doc.Text != "" || len(doc.Words) > 0 { + finalEou = false + } + if doc.Text != "" || doc.Eou != 0 || doc.Eob != 0 || len(doc.Words) > 0 { + out <- &pb.TranscriptLiveResponse{ + Delta: doc.Text, + Eou: doc.Eou != 0, + Eob: doc.Eob != 0, + Words: liveWordsToProto(doc.Words), + } + } + } + + feed := func(pcm []float32, finalize bool) error { + if CppStreamFeedJSON != nil { + doc, err := p.streamFeedDoc(stream, pcm, finalize) + if err != nil { + return err + } + process(doc) + return nil + } + delta, eou, eob, err := p.streamFeedText(stream, pcm, finalize) + if err != nil { + return err + } + doc := streamFeedJSON{Text: delta} + if eou { + doc.Eou = 1 + } + if eob { + doc.Eob = 1 + } + process(doc) + return nil + } + + for req := range in { + switch payload := req.GetPayload().(type) { + case *pb.TranscriptLiveRequest_Config: + if err := validateLiveConfig(payload.Config); err != nil { + return err + } + // Reset: a fresh decode session, dropping accumulated state. + p.streamFree(stream) + stream, err = p.streamBegin(payload.Config.GetLanguage()) + if err != nil { + return err + } + if stream == 0 { + return grpcerrors.LiveTranscriptionUnsupported("parakeet-cpp", + "loaded model is not a cache-aware streaming model") + } + full.Reset() + seg = streamSegmenter{} + finalEou = false + fedSecs = 0 + case *pb.TranscriptLiveRequest_Audio: + pcm := payload.Audio.GetPcm() + audioSec := float64(len(pcm)) / liveSampleRate + fedSecs += audioSec + start := time.Now() + // Slice large feeds so each engineMu hold stays short and unary + // requests interleave fairly; the C session buffers internally. + for off := 0; off < len(pcm); off += streamChunkSamples { + end := min(off+streamChunkSamples, len(pcm)) + if err := feed(pcm[off:end], false); err != nil { + return err + } + } + wallSec := time.Since(start).Seconds() + behindSec += wallSec - audioSec + if behindSec < 0 { + behindSec = 0 + } + xlog.Debug("parakeet-cpp: live feed", + "audio_ms", int(audioSec*1000), "wall_ms", int(wallSec*1000), + "behind_ms", int(behindSec*1000), "fed_s", fedSecs) + if behindSec > 1 && !behindWarned { + behindWarned = true + xlog.Warn("parakeet-cpp: live decode is falling behind real time; "+ + "end-of-utterance signals will arrive late", + "behind_s", behindSec, "fed_s", fedSecs) + } + } + } + + // Send side closed: flush the streaming tail and emit the final result. + if err := feed(nil, true); err != nil { + return err + } + seg.flush() // close a trailing utterance that never saw an EOU + + text := strings.TrimSpace(full.String()) + segments := seg.segments() + if len(segments) == 0 && text != "" { + segments = append(segments, &pb.TranscriptSegment{Id: 0, Text: text}) + } + out <- &pb.TranscriptLiveResponse{ + FinalResult: &pb.TranscriptResult{ + Text: text, + Segments: segments, + Duration: float32(fedSecs), + Eou: finalEou, + }, + } + return nil +} + +func validateLiveConfig(cfg *pb.TranscriptLiveConfig) error { + if sr := cfg.GetSampleRate(); sr != 0 && sr != liveSampleRate { + return status.Errorf(codes.InvalidArgument, + "parakeet-cpp: unsupported live sample_rate %d (only %d)", sr, liveSampleRate) + } + return nil +} + +func liveWordsToProto(words []transcriptWord) []*pb.TranscriptWord { + if len(words) == 0 { + return nil + } + out := make([]*pb.TranscriptWord, len(words)) + for i, w := range words { + out[i] = &pb.TranscriptWord{ + Start: secondsToNanos(w.Start), + End: secondsToNanos(w.End), + Text: w.W, + } + } + return out +} diff --git a/backend/go/parakeet-cpp/live_test.go b/backend/go/parakeet-cpp/live_test.go new file mode 100644 index 000000000000..1d51d87a3ca7 --- /dev/null +++ b/backend/go/parakeet-cpp/live_test.go @@ -0,0 +1,420 @@ +package main + +import ( + "sync" + "time" + "unsafe" + + "github.com/mudler/LocalAI/pkg/grpc/grpcerrors" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// The live-RPC specs drive AudioTranscriptionLive entirely against stubbed +// Cpp* package vars (the same seam batcher_test.go uses), so they run +// without libparakeet.so. + +// liveCstrPool hands out NUL-terminated C-style strings backed by Go memory +// and keeps them alive for the duration of a spec (goStringFromCPtr reads +// through the raw pointer; Go's GC must not collect the backing array while +// a stub's return value is in flight). +type liveCstrPool struct { + mu sync.Mutex + bufs [][]byte +} + +func (p *liveCstrPool) cstr(s string) uintptr { + p.mu.Lock() + defer p.mu.Unlock() + b := append([]byte(s), 0) + p.bufs = append(p.bufs, b) + return uintptr(unsafe.Pointer(&b[0])) +} + +// liveStubs swaps every C entry point the live path touches and returns a +// restore func for AfterEach. +func liveStubs() (restore func()) { + savedBegin, savedBeginLang := CppStreamBegin, CppStreamBeginLang + savedFeed, savedFeedJSON := CppStreamFeed, CppStreamFeedJSON + savedFinalize, savedFinalizeJSON := CppStreamFinalize, CppStreamFinalizeJSON + savedFree, savedLastError := CppStreamFree, CppLastError + savedFreeString := CppFreeString + return func() { + CppStreamBegin, CppStreamBeginLang = savedBegin, savedBeginLang + CppStreamFeed, CppStreamFeedJSON = savedFeed, savedFeedJSON + CppStreamFinalize, CppStreamFinalizeJSON = savedFinalize, savedFinalizeJSON + CppStreamFree, CppLastError = savedFree, savedLastError + CppFreeString = savedFreeString + } +} + +// runLive starts the RPC on its own goroutine and returns the request +// channel plus a collector for everything the backend emitted. +func runLive(p *ParakeetCpp) (chan *pb.TranscriptLiveRequest, chan *pb.TranscriptLiveResponse, chan error) { + in := make(chan *pb.TranscriptLiveRequest) + out := make(chan *pb.TranscriptLiveResponse, 32) + errCh := make(chan error, 1) + go func() { errCh <- p.AudioTranscriptionLive(in, out) }() + return in, out, errCh +} + +func liveConfig(lang string) *pb.TranscriptLiveRequest { + return &pb.TranscriptLiveRequest{ + Payload: &pb.TranscriptLiveRequest_Config{Config: &pb.TranscriptLiveConfig{Language: lang}}, + } +} + +func liveAudio(pcm []float32) *pb.TranscriptLiveRequest { + return &pb.TranscriptLiveRequest{ + Payload: &pb.TranscriptLiveRequest_Audio{Audio: &pb.TranscriptLiveAudio{Pcm: pcm}}, + } +} + +func collectLive(out chan *pb.TranscriptLiveResponse) []*pb.TranscriptLiveResponse { + var got []*pb.TranscriptLiveResponse + for r := range out { + got = append(got, r) + } + return got +} + +var _ = Describe("AudioTranscriptionLive (stubbed C API)", func() { + var ( + pool *liveCstrPool + restore func() + p *ParakeetCpp + ) + + BeforeEach(func() { + pool = &liveCstrPool{} + restore = liveStubs() + p = &ParakeetCpp{ctxPtr: 1} + + CppStreamBeginLang = nil + CppStreamBegin = func(ctx uintptr) uintptr { return 7 } + CppStreamFree = func(s uintptr) {} + CppFreeString = func(s uintptr) {} + CppLastError = func(ctx uintptr) string { return "stub error" } + CppStreamFeed = nil + CppStreamFeedJSON = nil + CppStreamFinalize = nil + CppStreamFinalizeJSON = nil + }) + + AfterEach(func() { restore() }) + + It("rejects a stream whose first message is not a config", func() { + in, out, errCh := runLive(p) + in <- liveAudio([]float32{0.1}) + close(in) + + err := <-errCh + Expect(status.Code(err)).To(Equal(codes.InvalidArgument)) + Expect(collectLive(out)).To(BeEmpty()) + }) + + It("rejects a non-16k sample rate", func() { + in, _, errCh := runLive(p) + in <- &pb.TranscriptLiveRequest{ + Payload: &pb.TranscriptLiveRequest_Config{Config: &pb.TranscriptLiveConfig{SampleRate: 8000}}, + } + close(in) + Expect(status.Code(<-errCh)).To(Equal(codes.InvalidArgument)) + }) + + It("returns the typed Unimplemented signal for non-streaming models, before any ack", func() { + CppStreamBegin = func(ctx uintptr) uintptr { return 0 } + + in, out, errCh := runLive(p) + in <- liveConfig("") + close(in) + + err := <-errCh + Expect(grpcerrors.IsLiveTranscriptionUnsupported(err)).To(BeTrue()) + Expect(collectLive(out)).To(BeEmpty()) + }) + + It("streams deltas, eou flags and words on the JSON path and finalizes on close", func() { + var freed []uintptr + CppStreamFree = func(s uintptr) { freed = append(freed, s) } + feeds := 0 + CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr { + feeds++ + switch feeds { + case 1: + return pool.cstr(`{"text":"hello ","eou":0,"frame_sec":0.08,` + + `"words":[{"w":"hello","start":0.1,"end":0.4,"conf":0.9}]}`) + default: + return pool.cstr(`{"text":"world","eou":1,"frame_sec":0.08,` + + `"words":[{"w":"world","start":0.5,"end":0.8,"conf":0.9}]}`) + } + } + CppStreamFinalizeJSON = func(s uintptr) uintptr { + return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`) + } + + in, out, errCh := runLive(p) + in <- liveConfig("en") + in <- liveAudio(make([]float32, 100)) + in <- liveAudio(make([]float32, 200)) + close(in) + Expect(<-errCh).NotTo(HaveOccurred()) + + got := collectLive(out) + Expect(got).To(HaveLen(4)) // ready, two deltas, final + + Expect(got[0].Ready).To(BeTrue()) + + Expect(got[1].Delta).To(Equal("hello ")) + Expect(got[1].Eou).To(BeFalse()) + Expect(got[1].Words).To(HaveLen(1)) + Expect(got[1].Words[0].Text).To(Equal("hello")) + + Expect(got[2].Delta).To(Equal("world")) + Expect(got[2].Eou).To(BeTrue()) + + final := got[3].FinalResult + Expect(final).NotTo(BeNil()) + Expect(final.Text).To(Equal("hello world")) + Expect(final.Eou).To(BeTrue(), "decode ended on the EOU boundary") + Expect(final.Segments).To(HaveLen(1)) + Expect(final.Segments[0].Text).To(Equal("hello world")) + Expect(final.Duration).To(BeNumerically("~", 300.0/16000.0, 1e-6)) + + Expect(freed).To(Equal([]uintptr{7})) + }) + + It("falls back to the text feed (eou out-param) when the JSON entry points are absent", func() { + feeds := 0 + CppStreamFeed = func(s uintptr, pcm []float32, n int32, eouOut unsafe.Pointer) uintptr { + feeds++ + if feeds == 2 { + *(*int32)(eouOut) = 1 + return pool.cstr("done") + } + return pool.cstr("first ") + } + CppStreamFinalize = func(s uintptr) uintptr { return pool.cstr("") } + + in, out, errCh := runLive(p) + in <- liveConfig("") + in <- liveAudio(make([]float32, 10)) + in <- liveAudio(make([]float32, 10)) + close(in) + Expect(<-errCh).NotTo(HaveOccurred()) + + got := collectLive(out) + Expect(got).To(HaveLen(4)) + Expect(got[1].Delta).To(Equal("first ")) + Expect(got[1].Eou).To(BeFalse()) + Expect(got[2].Delta).To(Equal("done")) + Expect(got[2].Eou).To(BeTrue()) + Expect(got[3].FinalResult.Text).To(Equal("first done")) + Expect(got[3].FinalResult.Eou).To(BeTrue()) + }) + + It("forwards as eob — a backchannel, never an eou (ABI v5 JSON)", func() { + feeds := 0 + CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr { + feeds++ + if feeds == 1 { + return pool.cstr(`{"text":"uh-huh","eou":0,"eob":1,"frame_sec":0.08,` + + `"words":[{"w":"uh-huh","start":0.1,"end":0.3,"conf":0.9}]}`) + } + return pool.cstr(`{"text":"the turn","eou":1,"eob":0,"frame_sec":0.08,` + + `"words":[{"w":"the","start":0.5,"end":0.6,"conf":0.9},{"w":"turn","start":0.6,"end":0.8,"conf":0.9}]}`) + } + CppStreamFinalizeJSON = func(s uintptr) uintptr { + return pool.cstr(`{"text":"","eou":0,"eob":0,"frame_sec":0.08,"words":[]}`) + } + + in, out, errCh := runLive(p) + in <- liveConfig("") + in <- liveAudio(make([]float32, 10)) + in <- liveAudio(make([]float32, 10)) + close(in) + Expect(<-errCh).NotTo(HaveOccurred()) + + got := collectLive(out) + Expect(got).To(HaveLen(4)) + Expect(got[1].Eob).To(BeTrue()) + Expect(got[1].Eou).To(BeFalse(), "a backchannel must not masquerade as a turn boundary") + Expect(got[2].Eou).To(BeTrue()) + final := got[3].FinalResult + Expect(final.Eou).To(BeTrue(), "decode ended on the utterance boundary, not the backchannel") + Expect(final.Segments).To(HaveLen(2), "both and reset the decoder and close a segment") + }) + + It("maps the v5 eou_out bitmask on the text path (bit0 , bit1 )", func() { + feeds := 0 + CppStreamFeed = func(s uintptr, pcm []float32, n int32, eouOut unsafe.Pointer) uintptr { + feeds++ + if feeds == 1 { + *(*int32)(eouOut) = 2 // only + return pool.cstr("uh-huh") + } + *(*int32)(eouOut) = 1 // + return pool.cstr(" done") + } + CppStreamFinalize = func(s uintptr) uintptr { return pool.cstr("") } + + in, out, errCh := runLive(p) + in <- liveConfig("") + in <- liveAudio(make([]float32, 10)) + in <- liveAudio(make([]float32, 10)) + close(in) + Expect(<-errCh).NotTo(HaveOccurred()) + + got := collectLive(out) + Expect(got).To(HaveLen(4)) + Expect(got[1].Eob).To(BeTrue()) + Expect(got[1].Eou).To(BeFalse()) + Expect(got[2].Eou).To(BeTrue()) + Expect(got[2].Eob).To(BeFalse()) + Expect(got[3].FinalResult.Eou).To(BeTrue()) + }) + + It("clears the final eou flag when trailing text arrives after the last EOU", func() { + feeds := 0 + CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr { + feeds++ + if feeds == 1 { + return pool.cstr(`{"text":"turn one","eou":1,"frame_sec":0.08,"words":[]}`) + } + return pool.cstr(`{"text":" and more","eou":0,"frame_sec":0.08,"words":[]}`) + } + CppStreamFinalizeJSON = func(s uintptr) uintptr { + return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`) + } + + in, out, errCh := runLive(p) + in <- liveConfig("") + in <- liveAudio(make([]float32, 10)) + in <- liveAudio(make([]float32, 10)) + close(in) + Expect(<-errCh).NotTo(HaveOccurred()) + + got := collectLive(out) + final := got[len(got)-1].FinalResult + Expect(final.Text).To(Equal("turn one and more")) + Expect(final.Eou).To(BeFalse(), "trailing speech without an EOU is an open utterance") + }) + + It("resets the decode session on a mid-stream config", func() { + var begun, freed int + CppStreamBegin = func(ctx uintptr) uintptr { begun++; return uintptr(10 + begun) } + CppStreamFree = func(s uintptr) { freed++ } + CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr { + return pool.cstr(`{"text":"x","eou":0,"frame_sec":0.08,"words":[]}`) + } + CppStreamFinalizeJSON = func(s uintptr) uintptr { + return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`) + } + + in, out, errCh := runLive(p) + in <- liveConfig("") + in <- liveAudio(make([]float32, 10)) + in <- liveConfig("") // reset + in <- liveAudio(make([]float32, 10)) + close(in) + Expect(<-errCh).NotTo(HaveOccurred()) + + got := collectLive(out) + final := got[len(got)-1].FinalResult + Expect(final.Text).To(Equal("x"), "pre-reset transcript dropped") + Expect(begun).To(Equal(2)) + Expect(freed).To(Equal(2), "old session freed on reset, new one on unwind") + }) + + It("does not hold engineMu between feeds (unary work interleaves with a live session)", func() { + CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr { + return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`) + } + CppStreamFinalizeJSON = func(s uintptr) uintptr { + return pool.cstr(`{"text":"","eou":0,"frame_sec":0.08,"words":[]}`) + } + + in, out, errCh := runLive(p) + in <- liveConfig("") + in <- liveAudio(make([]float32, 10)) + + // The session is open and idle between feeds: the engine lock must be + // acquirable, which is what lets batched unary transcription proceed + // mid-session. Under stream-lifetime locking this probe would block + // until the stream ended and the Eventually would time out. + locked := make(chan struct{}) + go func() { + p.engineMu.Lock() + p.engineMu.Unlock() //nolint:staticcheck // probe: acquire-release proves availability + close(locked) + }() + Eventually(locked, time.Second).Should(BeClosed()) + + close(in) + Expect(<-errCh).NotTo(HaveOccurred()) + collectLive(out) + }) + + It("errors out and reads last_error under the lock when a feed fails", func() { + CppStreamFeedJSON = func(s uintptr, pcm []float32, n int32) uintptr { return 0 } + + in, out, errCh := runLive(p) + in <- liveConfig("") + in <- liveAudio(make([]float32, 10)) + + err := <-errCh + Expect(err).To(MatchError(ContainSubstring("stub error"))) + got := collectLive(out) + Expect(got).To(HaveLen(1)) // just the ready ack + close(in) + }) +}) + +var _ = Describe("stripEouMarker", func() { + It("strips a trailing and reports it", func() { + text, eou := stripEouMarker("it is certainly very like the old portrait") + Expect(text).To(Equal("it is certainly very like the old portrait")) + Expect(eou).To(BeTrue()) + }) + + It("strips a trailing WITHOUT reporting an utterance end", func() { + // A decode ending on a backchannel must not confirm the + // retranscribe gate — the user was acknowledging, not yielding. + text, eou := stripEouMarker("uh-huh") + Expect(text).To(Equal("uh-huh")) + Expect(eou).To(BeFalse()) + }) + + It("leaves marker-free text alone", func() { + text, eou := stripEouMarker("plain transcript") + Expect(text).To(Equal("plain transcript")) + Expect(eou).To(BeFalse()) + }) + + It("does not strip a marker in the middle of the text", func() { + text, eou := stripEouMarker("ab") + Expect(text).To(Equal("ab")) + Expect(eou).To(BeFalse()) + }) +}) + +var _ = Describe("transcriptResultFromDoc EOU handling", func() { + It("strips the offline marker from text and sets the result flag", func() { + doc := transcriptJSON{Text: "the old portrait"} + res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0) + Expect(res.Text).To(Equal("the old portrait")) + Expect(res.Eou).To(BeTrue()) + Expect(res.Segments).To(HaveLen(1)) + Expect(res.Segments[0].Text).To(Equal("the old portrait")) + }) + + It("reports eou=false for marker-free decodes", func() { + doc := transcriptJSON{Text: "no marker here"} + res := transcriptResultFromDoc(doc, &pb.TranscriptRequest{}, 0) + Expect(res.Text).To(Equal("no marker here")) + Expect(res.Eou).To(BeFalse()) + }) +}) diff --git a/core/application/application.go b/core/application/application.go index 9bbf26bb8bd7..56320773bc8a 100644 --- a/core/application/application.go +++ b/core/application/application.go @@ -103,6 +103,11 @@ func newApplication(appConfig *config.ApplicationConfig) *Application { mcpTools.CloseMCPSessions(modelName) }) + // Record a model_load backend trace for every real backend load, so the + // Traces UI shows which backend runtime served each model and how long + // the load took. Load failures are traced by the modality wrappers. + ml.SetLoadObserver(corebackend.ModelLoadTraceObserver(appConfig)) + app := &Application{ backendLoader: config.NewModelConfigLoader(appConfig.SystemState.Model.ModelsPath), modelLoader: ml, diff --git a/core/backend/model_load_trace_test.go b/core/backend/model_load_trace_test.go new file mode 100644 index 000000000000..1cce5da2637a --- /dev/null +++ b/core/backend/model_load_trace_test.go @@ -0,0 +1,72 @@ +package backend_test + +import ( + "errors" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/trace" + "github.com/mudler/LocalAI/pkg/model" +) + +// ModelLoadTraceObserver is what makes successful loads visible on the +// Traces page: one model_load row per real backend load, carrying the +// resolved backend runtime. Failures must NOT be recorded here — the +// modality wrappers own those — and the observer must respect the runtime +// tracing toggle. +var _ = Describe("ModelLoadTraceObserver", func() { + var appConfig *config.ApplicationConfig + + successEvent := model.BackendLoadEvent{ + ModelID: "parakeet-cpp-realtime_eou_120m-v1", + ModelName: "realtime_eou_120m.gguf", + Backend: "parakeet-cpp", + BackendURI: "/backends/intel-sycl-f16-parakeet-cpp-development/run.sh", + Duration: 1500 * time.Millisecond, + } + + BeforeEach(func() { + appConfig = &config.ApplicationConfig{ + EnableTracing: true, + TracingMaxItems: 64, + } + trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes) + trace.ClearBackendTraces() + }) + + It("records a model_load trace with the backend runtime on success", func() { + backend.ModelLoadTraceObserver(appConfig)(successEvent) + + Eventually(trace.GetBackendTraces).Should(HaveLen(1)) + got := trace.GetBackendTraces()[0] + Expect(got.Type).To(Equal(trace.BackendTraceModelLoad)) + Expect(got.Summary).To(Equal("Model loaded")) + Expect(got.ModelName).To(Equal("parakeet-cpp-realtime_eou_120m-v1")) + Expect(got.Backend).To(Equal("parakeet-cpp")) + Expect(got.Duration).To(Equal(1500 * time.Millisecond)) + Expect(got.Data["backend_runtime"]).To(Equal("/backends/intel-sycl-f16-parakeet-cpp-development/run.sh")) + Expect(got.Data["model_file"]).To(Equal("realtime_eou_120m.gguf")) + Expect(got.Error).To(BeEmpty()) + }) + + It("skips failed loads — the modality wrappers trace those with request context", func() { + failed := successEvent + failed.Err = errors.New("grpc service not ready") + + backend.ModelLoadTraceObserver(appConfig)(failed) + + Consistently(trace.GetBackendTraces, "100ms", "20ms").Should(BeEmpty()) + }) + + It("records nothing when tracing is disabled", func() { + appConfig.EnableTracing = false + + backend.ModelLoadTraceObserver(appConfig)(successEvent) + + Consistently(trace.GetBackendTraces, "100ms", "20ms").Should(BeEmpty()) + }) +}) diff --git a/core/backend/options.go b/core/backend/options.go index 528c10e525a6..9ae22dd2231f 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -19,6 +19,39 @@ import ( "github.com/mudler/xlog" ) +// ModelLoadTraceObserver returns the ModelLoader load observer that records +// a model_load backend trace for every successful real load (backend process +// spawn + LoadModel RPC; cache hits never reach the observer). Failures are +// deliberately skipped here: the modality wrappers already record them via +// recordModelLoadFailure with request context, and the backend auto-discovery +// scan probes several backends before one succeeds — tracing every probe +// failure would bury the buffer in noise. +// +// The traced data includes the resolved backend runtime (the installed +// backend's launcher path, which names the variant directory) — that is what +// identifies WHICH build served the load. A stale installed backend is +// invisible in the model config but obvious here. +func ModelLoadTraceObserver(appConfig *config.ApplicationConfig) func(model.BackendLoadEvent) { + return func(ev model.BackendLoadEvent) { + if ev.Err != nil || !appConfig.EnableTracing { + return + } + trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes) + trace.RecordBackendTrace(trace.BackendTrace{ + Timestamp: time.Now(), + Duration: ev.Duration, + Type: trace.BackendTraceModelLoad, + ModelName: ev.ModelID, + Backend: ev.Backend, + Summary: "Model loaded", + Data: map[string]any{ + "model_file": ev.ModelName, + "backend_runtime": ev.BackendURI, + }, + }) + } +} + // recordModelLoadFailure records a backend trace when model loading fails. func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, backend string, err error, data map[string]any) { if !appConfig.EnableTracing { diff --git a/core/backend/transcript.go b/core/backend/transcript.go index e6da923cc60c..211269160750 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -181,6 +181,7 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR Text: r.Text, Language: r.Language, Duration: float64(r.Duration), + Eou: r.Eou, } for _, s := range r.Segments { diff --git a/core/backend/transcript_live.go b/core/backend/transcript_live.go new file mode 100644 index 000000000000..f43138dfe1d6 --- /dev/null +++ b/core/backend/transcript_live.go @@ -0,0 +1,300 @@ +package backend + +import ( + "context" + "errors" + "fmt" + "io" + "maps" + "sync" + "time" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/trace" + grpcPkg "github.com/mudler/LocalAI/pkg/grpc" + "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/sound" + "github.com/mudler/xlog" +) + +// LiveTranscriptionEvent is one streamed event from a live (bidirectional) +// transcription session. Delta/Eou/Eob/Words arrive as the user speaks; Final +// is set exactly once, on the terminal event after Close flushes the decode +// tail. Eou means the model judged the user yielded the turn; Eob means a +// backchannel ("uh-huh") ended — callers must NOT treat Eob as a turn +// boundary. +type LiveTranscriptionEvent struct { + Delta string + Eou bool + Eob bool + Words []schema.TranscriptionWord + Final *schema.TranscriptionResult +} + +// LiveTranscriptionSession is a handle on an open live transcription stream. +// Feed pushes 16 kHz mono float PCM; Close signals end-of-audio, waits for +// the backend's terminal Final event to be delivered, and releases the +// stream. +type LiveTranscriptionSession interface { + Feed(pcm []float32) error + Close() error +} + +// liveCloseDrainTimeout bounds how long Close waits for the backend to flush +// the decode tail before force-cancelling the stream. Finalize is one short +// engine call; seconds here means the backend is wedged. +const liveCloseDrainTimeout = 10 * time.Second + +type liveTranscriptionSession struct { + stream grpcPkg.AudioTranscriptionLiveClient + cancel context.CancelFunc + recvDone chan struct{} + recvErr error // written by the recv goroutine before recvDone closes + closeOnce sync.Once + closeErr error + trace *liveTraceState // nil when tracing was disabled at open +} + +func (s *liveTranscriptionSession) Feed(pcm []float32) error { + s.trace.addPCM(pcm) + return s.stream.Send(&proto.TranscriptLiveRequest{ + Payload: &proto.TranscriptLiveRequest_Audio{Audio: &proto.TranscriptLiveAudio{Pcm: pcm}}, + }) +} + +func (s *liveTranscriptionSession) Close() error { + s.closeOnce.Do(func() { + err := s.stream.CloseSend() + select { + case <-s.recvDone: + case <-time.After(liveCloseDrainTimeout): + xlog.Warn("live transcription: backend did not finalize in time; cancelling stream") + s.cancel() + <-s.recvDone + } + s.cancel() + if err == nil { + err = s.recvErr + } + s.closeErr = err + s.trace.record(err) + }) + return s.closeErr +} + +// liveSampleRate is the PCM rate of a live transcription session, fixed by +// the session config sent in ModelTranscriptionLive. +const liveSampleRate = 16000 + +// liveTraceState accumulates what the per-turn backend trace needs while a +// live session runs: a bounded copy of the fed PCM for the audio snippet, +// the decode outputs, and timing. One trace is recorded at Close — the live +// path never touches the unary transcription wrapper, so without this a +// streaming-only pipeline produced no transcription traces at all. Feed and +// the recv goroutine run concurrently; mu guards the accumulators. +type liveTraceState struct { + appConfig *config.ApplicationConfig + modelName string + backend string + language string + started time.Time + + mu sync.Mutex + pcm []byte // first trace.MaxSnippetSeconds of fed audio, int16 LE + fedSamples int // ALL samples fed, beyond the snippet cap + deltaEvents int + eouEvents int + eobEvents int + finalText string + finalEou bool +} + +func newLiveTraceState(modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, language string) *liveTraceState { + if !appConfig.EnableTracing { + return nil + } + trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes) + return &liveTraceState{ + appConfig: appConfig, + modelName: modelConfig.Name, + backend: modelConfig.Backend, + language: language, + started: time.Now(), + } +} + +func (ts *liveTraceState) addPCM(pcm []float32) { + if ts == nil { + return + } + ts.mu.Lock() + defer ts.mu.Unlock() + ts.fedSamples += len(pcm) + maxBytes := trace.MaxSnippetSeconds * liveSampleRate * 2 + if room := (maxBytes - len(ts.pcm)) / 2; room > 0 { + if len(pcm) > room { + pcm = pcm[:room] + } + ts.pcm = append(ts.pcm, sound.Float32sToInt16LEBytes(pcm)...) + } +} + +func (ts *liveTraceState) observe(ev LiveTranscriptionEvent) { + if ts == nil { + return + } + ts.mu.Lock() + defer ts.mu.Unlock() + if ev.Delta != "" { + ts.deltaEvents++ + } + if ev.Eou { + ts.eouEvents++ + } + if ev.Eob { + ts.eobEvents++ + } + if ev.Final != nil { + ts.finalText = ev.Final.Text + ts.finalEou = ev.Final.Eou + } +} + +func (ts *liveTraceState) record(closeErr error) { + if ts == nil || !ts.appConfig.EnableTracing { + return + } + ts.mu.Lock() + data := map[string]any{ + "source": "live_stream", + "language": ts.language, + "result_text": ts.finalText, + "eou": ts.finalEou, + "eou_events": ts.eouEvents, + "eob_events": ts.eobEvents, + "delta_events": ts.deltaEvents, + } + if snippet := trace.AudioSnippetFromPCM(ts.pcm, liveSampleRate, ts.fedSamples*2, ts.appConfig.TracingMaxBodyBytes); snippet != nil { + maps.Copy(data, snippet) + } + summary := "live -> " + ts.finalText + ts.mu.Unlock() + + bt := trace.BackendTrace{ + Timestamp: ts.started, + Duration: time.Since(ts.started), + Type: trace.BackendTraceTranscription, + ModelName: ts.modelName, + Backend: ts.backend, + Summary: trace.TruncateString(summary, 200), + Data: data, + } + if closeErr != nil { + bt.Error = closeErr.Error() + } + trace.RecordBackendTrace(bt) +} + +// ModelTranscriptionLive loads the transcription backend, opens the +// bidirectional AudioTranscriptionLive RPC, sends the session config, and +// BLOCKS until the backend's ready ack. A grpcerrors. +// IsLiveTranscriptionUnsupported error means the backend (or the loaded +// model) cannot do live transcription and the caller should degrade to the +// unary/file path. After a successful return, onEvent is invoked from a +// background goroutine — in order, one event at a time — for every response +// the backend streams, ending with the Final event triggered by Close. +func ModelTranscriptionLive(ctx context.Context, language string, + ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, + onEvent func(LiveTranscriptionEvent)) (LiveTranscriptionSession, error) { + + transcriptionModel, err := loadTranscriptionModel(ctx, ml, modelConfig, appConfig) + if err != nil { + return nil, err + } + + // The derived cancel out-lives this call inside the session: Close uses + // it to unwind the stream (and, in embed mode, the server-side recv + // pump, which only stops on send-close or context cancellation). + streamCtx, cancel := context.WithCancel(ctx) + stream, err := transcriptionModel.AudioTranscriptionLive(streamCtx) + if err != nil { + cancel() + return nil, err + } + + fail := func(err error) (LiveTranscriptionSession, error) { + _ = stream.CloseSend() + cancel() + return nil, err + } + + if err := stream.Send(&proto.TranscriptLiveRequest{ + Payload: &proto.TranscriptLiveRequest_Config{Config: &proto.TranscriptLiveConfig{ + Language: language, + SampleRate: liveSampleRate, + }}, + }); err != nil { + return fail(err) + } + + // Ready-ack contract: the backend answers a successful open with a + // {ready:true} response before any transcript data; unsupported + // backends surface Unimplemented here instead. + ack, err := stream.Recv() + if err != nil { + return fail(err) + } + if !ack.GetReady() { + return fail(fmt.Errorf("live transcription: backend %q broke the ready-ack contract (first response carried data)", modelConfig.Backend)) + } + + s := &liveTranscriptionSession{ + stream: stream, + cancel: cancel, + recvDone: make(chan struct{}), + trace: newLiveTraceState(modelConfig, appConfig, language), + } + + go func() { + defer close(s.recvDone) + for { + resp, err := stream.Recv() + if err != nil { + if !errors.Is(err, io.EOF) && streamCtx.Err() == nil { + xlog.Warn("live transcription stream ended unexpectedly", "error", err) + s.recvErr = err + } + return + } + ev := liveEventFromProto(resp) + if ev.Delta == "" && !ev.Eou && !ev.Eob && len(ev.Words) == 0 && ev.Final == nil { + continue // duplicate ready ack / keep-alive: nothing to deliver + } + s.trace.observe(ev) + onEvent(ev) + } + }() + + return s, nil +} + +func liveEventFromProto(r *proto.TranscriptLiveResponse) LiveTranscriptionEvent { + ev := LiveTranscriptionEvent{ + Delta: r.GetDelta(), + Eou: r.GetEou(), + Eob: r.GetEob(), + } + for _, w := range r.GetWords() { + ev.Words = append(ev.Words, schema.TranscriptionWord{ + Start: time.Duration(w.Start), + End: time.Duration(w.End), + Text: w.Text, + }) + } + if r.GetFinalResult() != nil { + ev.Final = transcriptResultFromProto(r.GetFinalResult()) + } + return ev +} diff --git a/core/backend/transcript_live_internal_test.go b/core/backend/transcript_live_internal_test.go new file mode 100644 index 000000000000..f6330aee4c03 --- /dev/null +++ b/core/backend/transcript_live_internal_test.go @@ -0,0 +1,160 @@ +package backend + +import ( + "errors" + "time" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/trace" + "github.com/mudler/LocalAI/pkg/grpc/proto" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("liveEventFromProto", func() { + It("maps deltas, eou flags and words (ns -> duration)", func() { + ev := liveEventFromProto(&proto.TranscriptLiveResponse{ + Delta: "hello ", + Eou: true, + Words: []*proto.TranscriptWord{ + {Start: int64(100 * time.Millisecond), End: int64(400 * time.Millisecond), Text: "hello"}, + }, + }) + Expect(ev.Delta).To(Equal("hello ")) + Expect(ev.Eou).To(BeTrue()) + Expect(ev.Words).To(HaveLen(1)) + Expect(ev.Words[0].Text).To(Equal("hello")) + Expect(ev.Words[0].Start).To(Equal(100 * time.Millisecond)) + Expect(ev.Words[0].End).To(Equal(400 * time.Millisecond)) + Expect(ev.Final).To(BeNil()) + }) + + It("maps the terminal final result including the eou flag", func() { + ev := liveEventFromProto(&proto.TranscriptLiveResponse{ + FinalResult: &proto.TranscriptResult{ + Text: "hello world", + Duration: 1.5, + Eou: true, + Segments: []*proto.TranscriptSegment{{Id: 0, Text: "hello world"}}, + }, + }) + Expect(ev.Final).NotTo(BeNil()) + Expect(ev.Final.Text).To(Equal("hello world")) + Expect(ev.Final.Duration).To(BeNumerically("~", 1.5, 1e-6)) + Expect(ev.Final.Eou).To(BeTrue()) + Expect(ev.Final.Segments).To(HaveLen(1)) + }) + + It("yields an empty event for a bare ready ack (filtered by the recv loop)", func() { + ev := liveEventFromProto(&proto.TranscriptLiveResponse{Ready: true}) + Expect(ev.Delta).To(BeEmpty()) + Expect(ev.Eou).To(BeFalse()) + Expect(ev.Words).To(BeEmpty()) + Expect(ev.Final).To(BeNil()) + }) + + It("maps the eob backchannel flag separately from eou", func() { + ev := liveEventFromProto(&proto.TranscriptLiveResponse{Delta: "uh-huh", Eob: true}) + Expect(ev.Eob).To(BeTrue()) + Expect(ev.Eou).To(BeFalse()) + }) +}) + +// liveTraceState is what makes streaming-only pipelines visible on the +// Traces page: without it a semantic_vad session with retranscribe off +// produced no transcription trace at all. One trace per session (= one per +// realtime turn), recorded at Close. +var _ = Describe("liveTraceState", func() { + var appConfig *config.ApplicationConfig + + BeforeEach(func() { + appConfig = &config.ApplicationConfig{ + EnableTracing: true, + TracingMaxItems: 64, + } + trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes) + trace.ClearBackendTraces() + }) + + modelCfg := func() config.ModelConfig { + cfg := config.ModelConfig{Backend: "parakeet-cpp"} + cfg.Name = "parakeet-live" + return cfg + } + + It("is disabled (nil) when tracing is off, and nil receivers are no-ops", func() { + appConfig.EnableTracing = false + ts := newLiveTraceState(modelCfg(), appConfig, "en") + Expect(ts).To(BeNil()) + + // The session calls these unconditionally; nil must be safe. + ts.addPCM([]float32{0.5}) + ts.observe(LiveTranscriptionEvent{Eou: true}) + ts.record(nil) + Consistently(trace.GetBackendTraces, "100ms", "20ms").Should(BeEmpty()) + }) + + It("records one transcription trace with text, eou flag and audio snippet at Close", func() { + ts := newLiveTraceState(modelCfg(), appConfig, "en") + Expect(ts).NotTo(BeNil()) + + // One second of a loud-ish constant tone so the snippet has signal. + pcm := make([]float32, liveSampleRate) + for i := range pcm { + pcm[i] = 0.25 + } + ts.addPCM(pcm) + ts.observe(LiveTranscriptionEvent{Delta: "hello "}) + ts.observe(LiveTranscriptionEvent{Delta: "world", Eou: true}) + ts.observe(LiveTranscriptionEvent{Final: &schema.TranscriptionResult{Text: "hello world", Eou: true}}) + + ts.record(nil) + + Eventually(trace.GetBackendTraces).Should(HaveLen(1)) + got := trace.GetBackendTraces()[0] + Expect(got.Type).To(Equal(trace.BackendTraceTranscription)) + Expect(got.ModelName).To(Equal("parakeet-live")) + Expect(got.Backend).To(Equal("parakeet-cpp")) + Expect(got.Summary).To(ContainSubstring("hello world")) + Expect(got.Data["source"]).To(Equal("live_stream")) + Expect(got.Data["result_text"]).To(Equal("hello world")) + Expect(got.Data["eou"]).To(Equal(true)) + Expect(got.Data["eou_events"]).To(Equal(1)) + Expect(got.Data["delta_events"]).To(Equal(2)) + Expect(got.Data["audio_duration_s"]).To(BeNumerically("~", 1.0, 0.01)) + Expect(got.Data["audio_wav_base64"]).NotTo(BeEmpty()) + Expect(got.Error).To(BeEmpty()) + }) + + It("caps the stored snippet but keeps counting the full fed duration", func() { + ts := newLiveTraceState(modelCfg(), appConfig, "") + + // Feed past the snippet cap in two chunks (cap + one extra second). + ts.addPCM(make([]float32, trace.MaxSnippetSeconds*liveSampleRate)) + ts.addPCM(make([]float32, liveSampleRate)) + + Expect(len(ts.pcm)).To(Equal(trace.MaxSnippetSeconds * liveSampleRate * 2)) + Expect(ts.fedSamples).To(Equal((trace.MaxSnippetSeconds + 1) * liveSampleRate)) + + ts.record(nil) + Eventually(trace.GetBackendTraces).Should(HaveLen(1)) + got := trace.GetBackendTraces()[0] + Expect(got.Data["audio_duration_s"]).To(BeNumerically("~", float64(trace.MaxSnippetSeconds+1), 0.01)) + Expect(got.Data["audio_snippet_s"]).To(BeNumerically("~", float64(trace.MaxSnippetSeconds), 0.01)) + }) + + It("clamps out-of-range float samples instead of wrapping", func() { + ts := newLiveTraceState(modelCfg(), appConfig, "") + ts.addPCM([]float32{2.0, -2.0}) + Expect(ts.pcm).To(Equal([]byte{0xff, 0x7f, 0x00, 0x80})) // 32767, -32768 + }) + + It("stamps the close error on the trace", func() { + ts := newLiveTraceState(modelCfg(), appConfig, "") + ts.record(errors.New("stream torn down")) + + Eventually(trace.GetBackendTraces).Should(HaveLen(1)) + Expect(trace.GetBackendTraces()[0].Error).To(Equal("stream torn down")) + }) +}) diff --git a/core/config/meta/registry.go b/core/config/meta/registry.go index 3476076e1442..b8200cd4115b 100644 --- a/core/config/meta/registry.go +++ b/core/config/meta/registry.go @@ -567,6 +567,38 @@ func DefaultRegistry() map[string]FieldMetaOverride { Advanced: true, Order: 83, }, + "pipeline.turn_detection.type": { + Section: "pipeline", + Label: "Turn Detection", + Description: "Default turn-detection mode for realtime sessions on this pipeline. server_vad commits after a fixed silence window; semantic_vad lets the transcription model's end-of-utterance token drive a dynamic window (fast commit after the token, long eagerness fallback without it). semantic_vad requires a streaming-EOU transcription model (e.g. parakeet-cpp-realtime_eou_120m-v1) and degrades to silence-only otherwise. Clients can override per session via session.update.", + Component: "select", + Options: []FieldOption{ + {Value: "", Label: "Default (server_vad)"}, + {Value: "server_vad", Label: "server_vad (silence-based)"}, + {Value: "semantic_vad", Label: "semantic_vad (end-of-utterance token)"}, + }, + Order: 87, + }, + "pipeline.turn_detection.eagerness": { + Section: "pipeline", + Label: "Eagerness", + Description: "semantic_vad fallback silence window used when no end-of-utterance token was seen: low waits 8s, medium/auto 4s, high 2s.", + Component: "select", + Options: []FieldOption{ + {Value: "", Label: "Default (auto)"}, + {Value: "low", Label: "low (8s)"}, + {Value: "medium", Label: "medium (4s)"}, + {Value: "high", Label: "high (2s)"}, + }, + Order: 88, + }, + "pipeline.turn_detection.retranscribe": { + Section: "pipeline", + Label: "Retranscribe on Commit", + Description: "Cross-check every semantic_vad commit with an offline decode of the buffered turn: commit only proceeds when the batch decode also ends in the end-of-utterance token, and its transcript is used. Logs a streamed-vs-batch comparison — useful to gauge streaming/batch alignment — at the cost of one extra decode per turn.", + Component: "toggle", + Order: 89, + }, // --- Functions --- "function.grammar.parallel_calls": { diff --git a/core/config/model_config.go b/core/config/model_config.go index 8886ddfd5a5d..8e4538d92b03 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -650,6 +650,12 @@ type Pipeline struct { // VoiceRecognition gates the pipeline behind speaker verification. Nil // (block absent) means no gate, preserving existing behavior. VoiceRecognition *PipelineVoiceRecognition `yaml:"voice_recognition,omitempty" json:"voice_recognition,omitempty"` + + // TurnDetection sets the server-side default turn-detection mode for + // realtime sessions on this pipeline, so clients need no session.update + // to benefit. A client session.update still overrides type and eagerness + // per session; retranscribe is server-side only. Unset keeps server_vad. + TurnDetection PipelineTurnDetection `yaml:"turn_detection,omitempty" json:"turn_detection,omitempty"` } // PipelineCompaction configures summarize-then-drop for a realtime pipeline. @@ -934,6 +940,38 @@ func (v PipelineVoiceRecognition) Validate(registryAvailable bool) error { return nil } +// @Description PipelineTurnDetection sets realtime turn-detection defaults. +type PipelineTurnDetection struct { + // Type selects the default turn_detection mode for sessions on this + // pipeline: "server_vad" (silence-based) or "semantic_vad" (the + // transcription model's end-of-utterance token drives a dynamic silence + // window; needs a streaming-EOU transcription model such as + // parakeet_realtime_eou_120m-v1, degrades to silence-only otherwise). + Type string `yaml:"type,omitempty" json:"type,omitempty"` + // Eagerness is the semantic_vad fallback when no end-of-utterance token + // was seen: low waits 8s of silence, medium/auto 4s, high 2s. + Eagerness string `yaml:"eagerness,omitempty" json:"eagerness,omitempty"` + // Retranscribe (semantic_vad only) cross-checks every EOU-triggered + // commit with an offline decode of the buffered turn: the commit only + // proceeds when the batch decode also ends in the end-of-utterance token, + // and its transcript is the one used. The streamed and batch transcripts + // are compared in the logs — a diagnostic for streaming/batch alignment + // at the cost of one extra decode per turn. + Retranscribe *bool `yaml:"retranscribe,omitempty" json:"retranscribe,omitempty"` +} + +// TurnDetectionSemantic reports whether this pipeline defaults sessions to +// semantic (EOU-driven) turn detection. +func (p Pipeline) TurnDetectionSemantic() bool { + return strings.EqualFold(strings.TrimSpace(p.TurnDetection.Type), "semantic_vad") +} + +// TurnDetectionRetranscribe reports whether semantic_vad commits should be +// cross-checked (and transcribed) by an offline decode of the buffered turn. +func (p Pipeline) TurnDetectionRetranscribe() bool { + return p.TurnDetection.Retranscribe != nil && *p.TurnDetection.Retranscribe +} + // @Description File configuration for model downloads type File struct { Filename string `yaml:"filename,omitempty" json:"filename,omitempty"` diff --git a/core/config/pipeline_turn_detection_test.go b/core/config/pipeline_turn_detection_test.go new file mode 100644 index 000000000000..d2b11a115c65 --- /dev/null +++ b/core/config/pipeline_turn_detection_test.go @@ -0,0 +1,61 @@ +package config + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gopkg.in/yaml.v3" +) + +// pipeline.turn_detection sets the server-side default turn-detection mode +// for realtime sessions. Unset keeps server_vad, so existing configs are +// unaffected; retranscribe is opt-in. +var _ = Describe("Pipeline turn_detection config", func() { + It("defaults to non-semantic with retranscribe off when unset", func() { + var p Pipeline + Expect(p.TurnDetectionSemantic()).To(BeFalse()) + Expect(p.TurnDetectionRetranscribe()).To(BeFalse()) + }) + + It("parses the nested turn_detection block from YAML", func() { + var c ModelConfig + err := yaml.Unmarshal([]byte(` +name: gpt-realtime +pipeline: + transcription: parakeet-cpp-realtime_eou_120m-v1 + turn_detection: + type: semantic_vad + eagerness: high + retranscribe: true +`), &c) + Expect(err).ToNot(HaveOccurred()) + Expect(c.Pipeline.TurnDetectionSemantic()).To(BeTrue()) + Expect(c.Pipeline.TurnDetection.Eagerness).To(Equal("high")) + Expect(c.Pipeline.TurnDetectionRetranscribe()).To(BeTrue()) + }) + + It("treats server_vad and unknown types as non-semantic", func() { + var p Pipeline + p.TurnDetection.Type = "server_vad" + Expect(p.TurnDetectionSemantic()).To(BeFalse()) + p.TurnDetection.Type = "something_else" + Expect(p.TurnDetectionSemantic()).To(BeFalse()) + }) + + It("matches semantic_vad case-insensitively with surrounding space", func() { + var p Pipeline + p.TurnDetection.Type = " Semantic_VAD " + Expect(p.TurnDetectionSemantic()).To(BeTrue()) + }) + + It("treats an explicit retranscribe false as off", func() { + var c ModelConfig + err := yaml.Unmarshal([]byte(` +pipeline: + turn_detection: + type: semantic_vad + retranscribe: false +`), &c) + Expect(err).ToNot(HaveOccurred()) + Expect(c.Pipeline.TurnDetectionRetranscribe()).To(BeFalse()) + }) +}) diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index d4d6a0ac40de..835c5b708859 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -288,6 +288,12 @@ type Model interface { // sound-event tags. topK caps the number of returned tags (0 = backend // default), threshold drops tags below the given score (0 = keep all). SoundDetection(ctx context.Context, audio string, topK int, threshold float32) (*schema.SoundClassificationResult, error) + // TranscribeLive opens a live (bidirectional) transcription session on the + // pipeline's transcription backend, used by semantic_vad turn detection; + // onEvent fires from a background goroutine for every delta/EOU/final + // event. Backends without live support fail with an error satisfying + // grpcerrors.IsLiveTranscriptionUnsupported. + TranscribeLive(ctx context.Context, language string, onEvent func(backend.LiveTranscriptionEvent)) (backend.LiveTranscriptionSession, error) PredictConfig() *config.ModelConfig } @@ -513,14 +519,10 @@ func runRealtimeSession(application *application.Application, t Transport, model // input_audio_buffer.commit. There is no transcription stage in that case. soundOnly := cfg.Pipeline.SoundDetection != "" && cfg.Pipeline.Transcription == "" && cfg.Pipeline.LLM == "" - turnDetection := &types.TurnDetectionUnion{ - ServerVad: &types.ServerVad{ - Threshold: 0.5, - PrefixPaddingMs: 300, - SilenceDurationMs: 500, - CreateResponse: true, - }, - } + // defaultTurnDetection seeds server_vad by default, or semantic_vad when the + // pipeline opts in (turn_detection.type: semantic_vad); clients can still + // override per session via session.update. + turnDetection := defaultTurnDetection(cfg) inputAudioTranscription := &types.AudioTranscription{Model: sttModel} if soundOnly { turnDetection = nil // turn_detection none: no VAD @@ -655,7 +657,7 @@ func runRealtimeSession(application *application.Application, t Transport, model vadServerStarted := false toggleVAD := func() { - if session.TurnDetection != nil && session.TurnDetection.ServerVad != nil && !vadServerStarted { + if turnDetectionActive(session.TurnDetection) && !vadServerStarted { xlog.Debug("Starting VAD goroutine...") done = make(chan struct{}) wg.Go(func() { @@ -663,7 +665,7 @@ func runRealtimeSession(application *application.Application, t Transport, model handleVAD(session, conversation, t, done) }) vadServerStarted = true - } else if (session.TurnDetection == nil || session.TurnDetection.ServerVad == nil) && vadServerStarted { + } else if !turnDetectionActive(session.TurnDetection) && vadServerStarted { xlog.Debug("Stopping VAD goroutine...") close(done) vadServerStarted = false @@ -811,11 +813,11 @@ func runRealtimeSession(application *application.Application, t Transport, model xlog.Debug("recv", "message", string(msg)) sessionLock.Lock() - isServerVAD := session.TurnDetection != nil && session.TurnDetection.ServerVad != nil + autoTurnDetection := turnDetectionActive(session.TurnDetection) sessionLock.Unlock() // TODO: At the least need to check locking and timer state in the VAD Go routine before allowing this - if isServerVAD { + if autoTurnDetection { sendNotImplemented(t, "input_audio_buffer.commit in conjunction with VAD") continue } @@ -1285,8 +1287,38 @@ func decodeOpusLoop(session *Session, opusBackend grpc.Backend, done chan struct } } +// noSpeechHoldbackSec is how much of the tail of an inspected, segment-free +// buffer survives the periodic no-speech clear. It must cover the VAD's +// onset-detection latency: a word can already be underway in the newest part +// of the window without silero having crossed its threshold yet, and clearing +// it cuts the start of the utterance the next tick will detect. +const noSpeechHoldbackSec = 0.5 + +// dropInspectedPrefix removes the head of the audio buffer that a VAD tick +// inspected (the first inspected bytes), keeping the newest holdbackBytes of +// that window plus everything appended while the tick ran — audio the VAD +// never saw. When something is dropped the result is a fresh copy, never a +// sub-slice, so later appends can't scribble on memory shared with the old +// backing array; when nothing is dropped buf is returned unchanged. +func dropInspectedPrefix(buf []byte, inspected, holdbackBytes int) []byte { + cut := inspected - holdbackBytes + if cut <= 0 { + return buf + } + if cut > len(buf) { + cut = len(buf) + } + return append([]byte(nil), buf[cut:]...) +} + // handleVAD is a goroutine that listens for audio data from the client, -// runs VAD on the audio data, and commits utterances to the conversation +// runs VAD on the audio data, and commits utterances to the conversation. +// +// With turn_detection.type == "semantic_vad" (sv != nil below) the silero +// loop is augmented by a live transcription stream: the buffer's new audio +// is fed to the transcription model every tick and its end-of-utterance +// token switches the commit threshold between a short post-EOU window and +// the long eagerness fallback. The server_vad path is untouched. func handleVAD(session *Session, conv *Conversation, t Transport, done chan struct{}) { vadContext, cancel := context.WithCancel(context.Background()) go func() { @@ -1299,6 +1331,9 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru silenceThreshold = float64(session.TurnDetection.ServerVad.SilenceDurationMs) / 1000 } + lts := newLiveTurnState(session, t) + defer lts.discardTurn() + speechStarted := false startTime := time.Now() @@ -1310,6 +1345,23 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru case <-done: return case <-ticker.C: + // Semantic mode is re-read each tick: session.update can switch + // turn-detection modes (and the retranscribe gate) mid-session. + sessionLock.Lock() + var sv *types.RealtimeSessionSemanticVad + if session.TurnDetection != nil { + sv = session.TurnDetection.SemanticVad + } + retranscribe := sv != nil && session.ModelConfig != nil && + session.ModelConfig.Pipeline.TurnDetectionRetranscribe() + sessionLock.Unlock() + + // session.update switched semantic -> server mid-turn: drop the + // orphaned live stream. + if sv == nil && lts.open() { + lts.discardTurn() + } + session.AudioBufferLock.Lock() allAudio := make([]byte, len(session.InputAudioBuffer)) copy(allAudio, session.InputAudioBuffer) @@ -1323,6 +1375,13 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru // Resample from InputSampleRate to 16kHz aints = sound.ResampleInt16(aints, session.InputSampleRate, localSampleRate) + audioLength := float64(len(aints)) / localSampleRate + + if sv != nil && lts.open() { + lts.feedNewAudio(aints) + lts.drainEvents(audioLength) + } + segments, err := runVAD(vadContext, session, aints) if err != nil { if err.Error() == "unexpected speech end" { @@ -1334,19 +1393,36 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru continue } - audioLength := float64(len(aints)) / localSampleRate - - // TODO: When resetting the buffer we should retain a small postfix + // NOTE: the no-speech clear and the min-buffer gate above stay on + // the short silenceThreshold even in semantic mode — the eagerness + // fallback applies only to the end-of-speech commit decision, or a + // low eagerness would delay speech_started/barge-in by seconds. if len(segments) == 0 && audioLength > silenceThreshold { + // "No segments" is not "no speech": silero (threshold 0.5) + // crosses up to a few hundred ms into a soft word onset, so + // the newest audio in the inspected window may be the start + // of a word the next tick will recognize — and more audio + // arrived while this tick ran. Keep both; drop only the + // older, confirmed-silent head, or utterance onsets get cut. + holdback := int(noSpeechHoldbackSec*float64(session.InputSampleRate)) * 2 session.AudioBufferLock.Lock() - session.InputAudioBuffer = nil + session.InputAudioBuffer = dropInspectedPrefix(session.InputAudioBuffer, len(allAudio), holdback) session.AudioBufferLock.Unlock() + if sv != nil { + lts.discardTurn() + } continue } else if len(segments) == 0 { continue } + // Speech began: start the turn's live stream and feed it the + // buffered prefix (including this tick's audio). + if sv != nil && !lts.open() && lts.openTurn(vadContext) { + lts.feedNewAudio(aints) + } + if !speechStarted { // Barge-in: cancel any in-flight response so we stop // sending audio and don't keep the interrupted reply in history. @@ -1361,16 +1437,70 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru speechStarted = true } + if sv != nil { + // Drain again: events produced by THIS tick's feed have + // usually arrived by the time runVAD returns, and leaving + // them for the next tick adds 300ms to every EOU-triggered + // commit. + lts.drainEvents(audioLength) + } + // Segment still in progress when audio ended segEndTime := segments[len(segments)-1].End if segEndTime == 0 { continue } - if float32(audioLength)-segEndTime > float32(silenceThreshold) { + threshold := silenceThreshold + eouPending := false + if sv != nil { + eouPending = lts.eouPending(segments) + threshold = lts.thresholdSec(eouPending, sv) + } + + if float32(audioLength)-segEndTime > float32(threshold) { + if sv != nil { + trigger, eouLag := lts.commitTrigger(eouPending, float64(segEndTime)) + xlog.Info("semantic_vad: committing turn", + "trigger", trigger, + "speech_end_s", segEndTime, + "eou_lag_s", eouLag, + "silence_s", audioLength-float64(segEndTime), + "audio_s", audioLength) + } + // Retranscribe gate (semantic mode, EOU-triggered commits + // only): cross-check the streamed EOU with an offline decode + // of the buffered turn before committing. Runs synchronously + // on the tick — the engine would serialize a concurrent feed + // against it anyway. Timeout-triggered commits skip the gate. + var gated *schema.TranscriptionResult + if retranscribe && eouPending { + batch, gerr := transcribeUtterance(vadContext, sound.Int16toBytesLE(aints), session) + switch { + case gerr != nil: + xlog.Warn("semantic_vad: retranscribe gate failed; committing via the file path", "error", gerr) + case !batch.Eou: + xlog.Info("semantic_vad: batch decode did not confirm the streamed EOU; continuing to listen", + "streamed", lts.previewText(), "batch", batch.Text) + // The batch decode rejected the streamed EOU as a false + // positive: consume the recorded EOU so the next tick + // falls back to the eagerness window instead of + // re-triggering on the same token. + lts.eouAtSec = 0 + continue + default: + xlog.Info("semantic_vad: batch decode confirmed the streamed EOU", + "streamed", lts.previewText(), "batch", batch.Text) + gated = batch + } + } + xlog.Debug("Detected end of speech segment") session.AudioBufferLock.Lock() - session.InputAudioBuffer = nil + // Keep audio appended while this tick ran — it belongs to + // the next turn (in any mode: nil-ing it dropped the onset + // of an utterance started right after a commit). + session.InputAudioBuffer = dropInspectedPrefix(session.InputAudioBuffer, len(allAudio), 0) session.AudioBufferLock.Unlock() sendEvent(t, types.InputAudioBufferSpeechStoppedEvent{ @@ -1381,20 +1511,39 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru }) speechStarted = false + // The committed item id must match the id the live caption + // deltas were streamed under, so the client's completed + // event replaces the partial text instead of duplicating it. + turnItemID := lts.itemID + if turnItemID == "" { + turnItemID = generateItemID() + } + sendEvent(t, types.InputAudioBufferCommittedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, - ItemID: generateItemID(), + ItemID: turnItemID, PreviousItemID: "TODO", }) + // Finalize the turn's live stream (flushes the decode tail). + // In retranscribe mode the batch decode is the authoritative + // transcript, so the streamed one is dropped. + var live *liveUtterance + if sv != nil { + ut := lts.finishTurn(audioLength) + if !retranscribe { + live = ut + } + } + abytes := sound.Int16toBytesLE(aints) // TODO: Remove prefix silence that is is over TurnDetectionParams.PrefixPaddingMs respCtx, respDone := session.startResponse(vadContext) go func() { defer close(respDone) - commitUtterance(respCtx, abytes, session, conv, t) + commitUtteranceWithTranscript(respCtx, abytes, live, gated, turnItemID, session, conv, t) }() } } @@ -1402,6 +1551,19 @@ func handleVAD(session *Session, conv *Conversation, t Transport, done chan stru } func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, t Transport) { + commitUtteranceWithTranscript(ctx, utt, nil, nil, "", session, conv, t) +} + +// commitUtteranceWithTranscript commits one user turn. live carries the +// transcript semantic_vad's live stream already produced (its caption deltas +// were streamed to the client during the turn, so only the completed event +// is emitted here); gated carries the retranscribe gate's batch decode (the +// authoritative transcript in that mode). With neither — server_vad, manual +// commits, semantic degrade, or a live stream that heard nothing — the audio +// is written to a temp WAV and transcribed via the file path as before. +// itemID is the turn's conversation item id ("" mints a fresh one); it must +// match the id any live deltas were sent under. +func commitUtteranceWithTranscript(ctx context.Context, utt []byte, live *liveUtterance, gated *schema.TranscriptionResult, itemID string, session *Session, conv *Conversation, t Transport) { if len(utt) == 0 { return } @@ -1466,14 +1628,37 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co } // TODO: If we have a real any-to-any model then transcription is optional + + // The turn's live captions (semantic_vad) already streamed under this + // itemID; the completed event below reuses it so the client replaces the + // partial text. server_vad / manual commits arrive with no itemID, so mint + // one here. + if itemID == "" { + itemID = generateItemID() + } + var transcript string switch { + case gated != nil: + // semantic_vad retranscribe gate: the batch decode is authoritative. + transcript = gated.Text + if err := emitPrecomputedTranscription(t, itemID, nil, transcript); err != nil { + sendError(t, "transcription_failed", err.Error(), "", "event_TODO") + return + } + case live != nil && live.Text != "": + // The caption deltas already streamed during the turn under this + // itemID; the completed event replaces the partial text client-side. + transcript = live.Text + if err := emitPrecomputedTranscription(t, itemID, nil, transcript); err != nil { + sendError(t, "transcription_failed", err.Error(), "", "event_TODO") + return + } case session.InputAudioTranscription != nil: // emitTranscription streams transcript deltas when // pipeline.streaming.transcription is set, otherwise emits a single // completed event; either way it returns the final transcript text. - var err error - transcript, err = emitTranscription(ctx, t, session, generateItemID(), f.Name()) + transcript, err = emitTranscription(ctx, t, session, itemID, f.Name()) if err != nil { // Drain the gate goroutine before returning so its in-flight read of // the temp WAV finishes before the deferred os.Remove fires. @@ -1642,6 +1827,56 @@ func writeWindowWAV(pcm []byte, sampleRate int) (string, error) { return f.Name(), nil } +// writeUtteranceWAV persists raw 16 kHz mono PCM to a temp WAV for the +// file-based transcription paths. The caller must invoke cleanup. +func writeUtteranceWAV(utt []byte) (string, func(), error) { + f, err := os.CreateTemp("", "realtime-audio-chunk-*.wav") + if err != nil { + return "", nil, err + } + cleanup := func() { + _ = f.Close() + _ = os.Remove(f.Name()) + } + xlog.Debug("Writing to file", "file", f.Name()) + + hdr := laudio.NewWAVHeader(uint32(len(utt))) + if err := hdr.Write(f); err != nil { + cleanup() + return "", nil, err + } + if _, err := f.Write(utt); err != nil { + cleanup() + return "", nil, err + } + _ = f.Sync() + return f.Name(), cleanup, nil +} + +// transcribeUtterance runs one offline (unary) decode of the buffered turn — +// the semantic_vad retranscribe gate. The result's Eou flag reports whether +// the batch decode also ended on the end-of-utterance token. +func transcribeUtterance(ctx context.Context, utt []byte, session *Session) (*schema.TranscriptionResult, error) { + path, cleanup, err := writeUtteranceWAV(utt) + if err != nil { + return nil, err + } + defer cleanup() + + language, prompt := "", "" + if cfg := session.InputAudioTranscription; cfg != nil { + language, prompt = cfg.Language, cfg.Prompt + } + tr, err := session.ModelInterface.Transcribe(ctx, path, language, false, false, prompt) + if err != nil { + return nil, err + } + if tr == nil { + return nil, fmt.Errorf("transcribe result is nil") + } + return tr, nil +} + func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADSegment, error) { soundIntBuffer := &audio.IntBuffer{ Format: &audio.Format{SampleRate: localSampleRate, NumChannels: 1}, diff --git a/core/http/endpoints/openai/realtime_doubles_test.go b/core/http/endpoints/openai/realtime_doubles_test.go index 10e608c17dbd..6dc1c6ca5796 100644 --- a/core/http/endpoints/openai/realtime_doubles_test.go +++ b/core/http/endpoints/openai/realtime_doubles_test.go @@ -74,6 +74,16 @@ type fakeModel struct { transcribeDeltas []string transcribeFinal *schema.TranscriptionResult + transcribeErr error + + // TranscribeLive scripting: liveErr makes the open fail (degrade path); + // liveEvents are delivered to onEvent synchronously at open; + // liveCloseEvents are delivered during Close (the finalize flush). + liveErr error + liveEvents []backend.LiveTranscriptionEvent + liveCloseEvents []backend.LiveTranscriptionEvent + liveOpened int + liveSession *fakeLiveSession // soundDetectionResult/soundDetectionErr drive the SoundDetection double so // the sound-event path can be exercised deterministically. @@ -97,7 +107,7 @@ func (m *fakeModel) VAD(context.Context, *schema.VADRequest) (*schema.VADRespons } func (m *fakeModel) Transcribe(context.Context, string, string, bool, bool, string) (*schema.TranscriptionResult, error) { - return m.transcribeFinal, nil + return m.transcribeFinal, m.transcribeErr } func (m *fakeModel) SoundDetection(context.Context, string, int, float32) (*schema.SoundClassificationResult, error) { @@ -150,4 +160,43 @@ func (m *fakeModel) TranscribeStream(_ context.Context, _, _ string, _, _ bool, return m.transcribeFinal, nil } +func (m *fakeModel) TranscribeLive(_ context.Context, _ string, onEvent func(backend.LiveTranscriptionEvent)) (backend.LiveTranscriptionSession, error) { + if m.liveErr != nil { + return nil, m.liveErr + } + m.liveOpened++ + for _, ev := range m.liveEvents { + onEvent(ev) + } + m.liveSession = &fakeLiveSession{onEvent: onEvent, closeEvents: m.liveCloseEvents} + return m.liveSession, nil +} + func (m *fakeModel) PredictConfig() *config.ModelConfig { return m.cfg } + +// fakeLiveSession records what semantic_vad fed and closed; closeEvents are +// replayed through onEvent during Close, mimicking the backend's finalize +// flush (trailing delta + Final) landing before Close returns. +type fakeLiveSession struct { + onEvent func(backend.LiveTranscriptionEvent) + closeEvents []backend.LiveTranscriptionEvent + fed [][]float32 + feedErr error + closed int +} + +func (s *fakeLiveSession) Feed(pcm []float32) error { + if s.feedErr != nil { + return s.feedErr + } + s.fed = append(s.fed, append([]float32(nil), pcm...)) + return nil +} + +func (s *fakeLiveSession) Close() error { + s.closed++ + for _, ev := range s.closeEvents { + s.onEvent(ev) + } + return nil +} diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index 6843a521d9dc..935fc79abd8a 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -102,6 +102,10 @@ func (m *transcriptOnlyModel) TranscribeStream(ctx context.Context, audio, langu return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta) } +func (m *transcriptOnlyModel) TranscribeLive(ctx context.Context, language string, onEvent func(backend.LiveTranscriptionEvent)) (backend.LiveTranscriptionSession, error) { + return backend.ModelTranscriptionLive(ctx, language, m.modelLoader, *m.TranscriptionConfig, m.appConfig, onEvent) +} + func (m *transcriptOnlyModel) PredictConfig() *config.ModelConfig { return nil } @@ -348,6 +352,10 @@ func (m *wrappedModel) TranscribeStream(ctx context.Context, audio, language str return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta) } +func (m *wrappedModel) TranscribeLive(ctx context.Context, language string, onEvent func(backend.LiveTranscriptionEvent)) (backend.LiveTranscriptionSession, error) { + return backend.ModelTranscriptionLive(ctx, language, m.modelLoader, *m.TranscriptionConfig, m.appConfig, onEvent) +} + func (m *wrappedModel) PredictConfig() *config.ModelConfig { return m.LLMConfig } diff --git a/core/http/endpoints/openai/realtime_semantic_vad.go b/core/http/endpoints/openai/realtime_semantic_vad.go new file mode 100644 index 000000000000..105bd416ecac --- /dev/null +++ b/core/http/endpoints/openai/realtime_semantic_vad.go @@ -0,0 +1,345 @@ +package openai + +import ( + "context" + "strings" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/xlog" +) + +// Semantic (EOU-driven) turn detection. +// +// With turn_detection.type == "semantic_vad", the transcription model is fed +// the microphone audio live while the user speaks and its end-of-utterance +// token turns the silence window dynamic: an immediate commit once the +// token fires (the model judged the user finished and expects a reply), the +// much longer eagerness fallback when it does not (mid-thought pause). The +// silero VAD stays in charge of speech_started/barge-in and the actual +// silence measurement, so a spurious EOU mid-speech cannot cut the user off +// — the commit still requires real silence. + +const ( + // semanticEouSilenceSec is the extra silence required to commit once the + // end-of-utterance token has fired. Zero: the token already trails the + // audio by the encoder chunk schedule plus a VAD tick (~0.3-0.9s), and + // the commit check only runs after silero closes the speech segment — + // which itself takes real silence — so any window on top is pure added + // response delay. + semanticEouSilenceSec = 0.0 + + // liveEventsBuffer sizes the recv-callback → VAD-tick handoff channel. + // Events arrive at a few per second and the ticker drains every 300ms; + // a full channel means the loop is wedged, and dropping (with a warning) + // beats blocking the backend's recv goroutine. + liveEventsBuffer = 64 +) + +// eagernessMaxSilenceSec maps the OpenAI semantic_vad eagerness to the +// fallback silence window used when no end-of-utterance token was seen: +// low waits longest, high responds fastest, auto/empty equals medium — +// the same 8s/4s/2s max timeouts OpenAI documents. +func eagernessMaxSilenceSec(eagerness string) float64 { + switch strings.ToLower(strings.TrimSpace(eagerness)) { + case "low": + return 8 + case "high": + return 2 + default: // "medium", "auto", "" + return 4 + } +} + +// liveUtterance is one committed turn's transcript as produced by the live +// stream. Its delta events were already streamed to the client as they +// arrived (keyed by the turn's item id), so only the final text travels here. +type liveUtterance struct { + Text string +} + +// liveTurnState is handleVAD's per-session live-ASR companion for +// semantic_vad. One live stream is opened per user turn (begun when the VAD +// first reports speech, finalized at commit) — the underlying decode session +// grows with fed audio, so per-turn streams keep it bounded. All fields are +// owned by the handleVAD goroutine; the backend's recv callback only writes +// into the buffered events channel. +type liveTurnState struct { + session *Session + transport Transport // live caption deltas are sent here as they drain + events chan backend.LiveTranscriptionEvent + + live backend.LiveTranscriptionSession // nil between turns + unavailable bool // sticky: backend can't do live ASR, degrade for the session + + fed16k int // 16k samples of the current buffer already fed + // eouAtSec is the audio time of the most recent EOU this turn (0 = none). + // It is a recorded fact: set when an EOU drains and never toggled off + // mid-turn. Whether it still governs the trailing silence is derived + // purely by eouPending() from this plus the live VAD segments. + eouAtSec float64 + parts []string // deltas accumulated for the current turn + finalText string // authoritative full-turn text from the Final event + itemID string // the turn's conversation item id, allocated at openTurn + deltasSent bool // at least one caption delta reached the client this turn +} + +func newLiveTurnState(session *Session, transport Transport) *liveTurnState { + return &liveTurnState{ + session: session, + transport: transport, + events: make(chan backend.LiveTranscriptionEvent, liveEventsBuffer), + } +} + +func (l *liveTurnState) open() bool { return l.live != nil } + +// openTurn starts the turn's live stream. A failure (most commonly the +// backend's typed "live transcription unsupported" signal) degrades the +// whole session to silence-only detection — warned once, then sticky. +func (l *liveTurnState) openTurn(ctx context.Context) bool { + if l.live != nil { + return true + } + if l.unavailable { + return false + } + language := "" + if l.session.InputAudioTranscription != nil { + language = l.session.InputAudioTranscription.Language + } + live, err := l.session.ModelInterface.TranscribeLive(ctx, language, func(ev backend.LiveTranscriptionEvent) { + select { + case l.events <- ev: + default: + xlog.Warn("semantic_vad: live transcription event dropped (event channel full)") + } + }) + if err != nil { + l.unavailable = true + xlog.Warn("semantic_vad: live transcription unavailable; degrading to silence-only turn detection", + "error", err) + return false + } + l.resetTurn() + l.live = live + // The item id is allocated when the turn STARTS so caption deltas can + // stream to the client while the user is still speaking; the committed + // event and the final transcript reuse it, replacing the partial text. + l.itemID = generateItemID() + return true +} + +// feedNewAudio pushes the not-yet-fed tail of the resampled buffer to the +// live stream. The final sample is held back: ResampleInt16 is prefix-stable +// except for its last output sample, so excluding it keeps successive +// whole-buffer resamples bit-identical over the fed range. +func (l *liveTurnState) feedNewAudio(aints16k []int16) { + if l.live == nil { + return + } + end := len(aints16k) - 1 + if end <= l.fed16k { + return + } + if err := l.live.Feed(int16sToFloat32(aints16k[l.fed16k:end])); err != nil { + xlog.Warn("semantic_vad: live feed failed; degrading to silence-only turn detection", "error", err) + l.discardTurn() + l.unavailable = true + return + } + l.fed16k = end +} + +// drainEvents folds everything the live stream produced since the last tick +// into the turn state. audioSec (the current buffer length in seconds) marks +// WHEN an EOU was observed, so later VAD segments can distinguish speech +// that resumed after it. +func (l *liveTurnState) drainEvents(audioSec float64) { + for { + select { + case ev := <-l.events: + if ev.Delta != "" { + l.parts = append(l.parts, ev.Delta) + // Live captions: forward the delta immediately under the + // turn's item id — the browser shows text while the user + // is still speaking; the completed event at commit + // replaces it with the authoritative transcript. + if l.transport != nil && l.itemID != "" { + sendEvent(l.transport, types.ConversationItemInputAudioTranscriptionDeltaEvent{ + ServerEventBase: types.ServerEventBase{EventID: "event_TODO"}, + ItemID: l.itemID, + ContentIndex: 0, + Delta: ev.Delta, + }) + l.deltasSent = true + } + } + if ev.Eou { + // Record the position; do not flip a flag. Whether this EOU + // still applies to the trailing silence is decided later by + // eouPending(), purely from this and the live VAD segments. + l.eouAtSec = audioSec + xlog.Debug("semantic_vad: EOU token observed", "audio_s", audioSec) + } + if ev.Eob { + // A backchannel ended ("uh-huh") — the user is still + // listening, not yielding the turn. Deliberately NOT a + // commit trigger. + xlog.Debug("semantic_vad: EOB (backchannel) observed", "audio_s", audioSec) + } + if ev.Final != nil && strings.TrimSpace(ev.Final.Text) != "" { + l.finalText = ev.Final.Text + } + default: + return + } + } +} + +// eouPending reports whether the recorded EOU still applies to the current +// trailing silence. It is a pure function of the recorded EOU position and the +// VAD's live view — there is no stored boolean that can fall out of sync. +// +// An EOU stops applying only once the user has STARTED a new utterance after +// it (a segment whose start is past the EOU): that is genuine resumed speech, +// so the earlier yield no longer holds. An in-progress segment whose speech +// began BEFORE the EOU is NOT resumed speech — it is just silero still padding +// before it closes the segment, which is the normal state at the instant the +// (predictive) EOU fires. Treating that as resumed speech was the bug that +// cleared the flag on the very tick the token arrived, dropping almost every +// EOU to the eagerness timeout. +func (l *liveTurnState) eouPending(segments []schema.VADSegment) bool { + if l.eouAtSec == 0 || len(segments) == 0 { + return false + } + last := segments[len(segments)-1] + return float64(last.Start) <= l.eouAtSec +} + +// thresholdSec is the dynamic commit threshold: zero once the model said +// the utterance is over (any VAD-confirmed silence commits), the eagerness +// fallback otherwise. +func (l *liveTurnState) thresholdSec(eouPending bool, sv *types.RealtimeSessionSemanticVad) float64 { + if eouPending { + return semanticEouSilenceSec + } + return eagernessMaxSilenceSec(sv.Eagerness) +} + +// commitTrigger describes how a commit decision was reached, for the per-turn +// timing log: "eou" with the token's lag behind the VAD's speech end, or +// "timeout" when the eagerness fallback elapsed without one. The lag is the +// number the user needs to tell a slow EOU emission apart from loop overhead. +func (l *liveTurnState) commitTrigger(eouPending bool, speechEndSec float64) (trigger string, eouLagSec float64) { + if !eouPending { + return "timeout", 0 + } + return "eou", l.eouAtSec - speechEndSec +} + +// finishTurn finalizes the live stream (flushing the decode tail — the last +// ~2 encoder frames of text only appear here), folds the terminal events in, +// and returns the turn's transcript. Returns nil when the stream never +// produced text (the VAD triggered on something the model heard nothing in). +func (l *liveTurnState) finishTurn(audioSec float64) *liveUtterance { + if l.live == nil { + return nil + } + if err := l.live.Close(); err != nil { + xlog.Warn("semantic_vad: live transcription finalize failed", "error", err) + } + l.live = nil + l.drainEvents(audioSec) + + text := strings.TrimSpace(l.finalText) + if text == "" { + text = l.previewText() + } + ut := &liveUtterance{Text: text} + l.resetTurn() + if ut.Text == "" { + return nil + } + return ut +} + +// discardTurn drops the current turn (no-speech buffer clear, feed failure, +// session teardown): the stream is closed and its transcript thrown away. +// Any caption deltas already shown for it are retracted via the failed +// event, so the client doesn't keep a stuck partial entry. +func (l *liveTurnState) discardTurn() { + if l.live != nil { + _ = l.live.Close() + l.live = nil + } + l.drainEvents(0) + if l.deltasSent && l.transport != nil && l.itemID != "" { + sendEvent(l.transport, types.ConversationItemInputAudioTranscriptionFailedEvent{ + ServerEventBase: types.ServerEventBase{EventID: "event_TODO"}, + ItemID: l.itemID, + ContentIndex: 0, + Error: types.Error{ + Type: "transcription_discarded", + Message: "turn discarded before commit", + }, + }) + } + l.resetTurn() +} + +func (l *liveTurnState) resetTurn() { + l.fed16k = 0 + l.eouAtSec = 0 + l.parts = nil + l.finalText = "" + l.itemID = "" + l.deltasSent = false +} + +// previewText is the turn's transcript so far (for the retranscribe +// comparison log and as the fallback when no Final event arrived). +func (l *liveTurnState) previewText() string { + return strings.TrimSpace(strings.Join(l.parts, "")) +} + +// int16sToFloat32 converts PCM to the [-1,1] float form the live stream +// feeds the model (the same scaling runVAD's go-audio conversion applies). +func int16sToFloat32(samples []int16) []float32 { + out := make([]float32, len(samples)) + for i, s := range samples { + out[i] = float32(s) / 32768.0 + } + return out +} + +// turnDetectionActive reports whether the session has any automatic turn +// detection (server or semantic VAD) that should run the handleVAD loop. +func turnDetectionActive(td *types.TurnDetectionUnion) bool { + return td != nil && (td.ServerVad != nil || td.SemanticVad != nil) +} + +// defaultTurnDetection seeds a new session's turn detection from the +// pipeline's server-side default: semantic_vad pipelines start sessions in +// semantic mode (clients can still override via session.update); everything +// else keeps the historical server_vad defaults. +func defaultTurnDetection(cfg *config.ModelConfig) *types.TurnDetectionUnion { + if cfg != nil && cfg.Pipeline.TurnDetectionSemantic() { + return &types.TurnDetectionUnion{ + SemanticVad: &types.RealtimeSessionSemanticVad{ + CreateResponse: true, + Eagerness: cfg.Pipeline.TurnDetection.Eagerness, + }, + } + } + return &types.TurnDetectionUnion{ + ServerVad: &types.ServerVad{ + Threshold: 0.5, + PrefixPaddingMs: 300, + SilenceDurationMs: 500, + CreateResponse: true, + }, + } +} diff --git a/core/http/endpoints/openai/realtime_semantic_vad_test.go b/core/http/endpoints/openai/realtime_semantic_vad_test.go new file mode 100644 index 000000000000..92abdb2d5fa7 --- /dev/null +++ b/core/http/endpoints/openai/realtime_semantic_vad_test.go @@ -0,0 +1,414 @@ +package openai + +import ( + "context" + "errors" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" + "github.com/mudler/LocalAI/core/schema" +) + +var _ = Describe("eagernessMaxSilenceSec", func() { + DescribeTable("maps eagerness to the no-EOU fallback window", + func(eagerness string, want float64) { + Expect(eagernessMaxSilenceSec(eagerness)).To(Equal(want)) + }, + Entry("low", "low", 8.0), + Entry("medium", "medium", 4.0), + Entry("high", "high", 2.0), + Entry("auto equals medium", "auto", 4.0), + Entry("empty equals medium", "", 4.0), + Entry("case and space insensitive", " High ", 2.0), + Entry("unknown equals medium", "frantic", 4.0), + ) +}) + +var _ = Describe("turnDetectionActive", func() { + It("is active for server and semantic VAD, inactive otherwise", func() { + Expect(turnDetectionActive(nil)).To(BeFalse()) + Expect(turnDetectionActive(&types.TurnDetectionUnion{})).To(BeFalse()) + Expect(turnDetectionActive(&types.TurnDetectionUnion{ServerVad: &types.ServerVad{}})).To(BeTrue()) + Expect(turnDetectionActive(&types.TurnDetectionUnion{SemanticVad: &types.RealtimeSessionSemanticVad{}})).To(BeTrue()) + }) +}) + +var _ = Describe("defaultTurnDetection", func() { + It("keeps the historical server_vad defaults for non-semantic pipelines", func() { + td := defaultTurnDetection(&config.ModelConfig{}) + Expect(td.ServerVad).NotTo(BeNil()) + Expect(td.SemanticVad).To(BeNil()) + Expect(td.ServerVad.SilenceDurationMs).To(Equal(int64(500))) + Expect(td.ServerVad.CreateResponse).To(BeTrue()) + }) + + It("seeds semantic_vad with the pipeline's eagerness", func() { + cfg := &config.ModelConfig{} + cfg.Pipeline.TurnDetection.Type = "semantic_vad" + cfg.Pipeline.TurnDetection.Eagerness = "high" + td := defaultTurnDetection(cfg) + Expect(td.SemanticVad).NotTo(BeNil()) + Expect(td.ServerVad).To(BeNil()) + Expect(td.SemanticVad.Eagerness).To(Equal("high")) + Expect(td.SemanticVad.CreateResponse).To(BeTrue()) + }) + + It("treats a nil config as server_vad", func() { + Expect(defaultTurnDetection(nil).ServerVad).NotTo(BeNil()) + }) +}) + +var _ = Describe("int16sToFloat32", func() { + It("scales like the VAD conversion", func() { + out := int16sToFloat32([]int16{0, 16384, -32768}) + Expect(out).To(HaveLen(3)) + Expect(out[0]).To(BeNumerically("~", 0.0, 1e-6)) + Expect(out[1]).To(BeNumerically("~", 0.5, 1e-6)) + Expect(out[2]).To(BeNumerically("~", -1.0, 1e-6)) + }) +}) + +var _ = Describe("liveTurnState", func() { + var ( + m *fakeModel + lts *liveTurnState + ftr *fakeTransport + ) + + newSemanticSession := func(m *fakeModel) *Session { + return &Session{ + InputAudioTranscription: &types.AudioTranscription{}, + ModelInterface: m, + } + } + + BeforeEach(func() { + m = &fakeModel{} + ftr = &fakeTransport{} + lts = newLiveTurnState(newSemanticSession(m), ftr) + }) + + Describe("openTurn", func() { + It("opens once per turn and reports open()", func() { + Expect(lts.open()).To(BeFalse()) + Expect(lts.openTurn(context.Background())).To(BeTrue()) + Expect(lts.open()).To(BeTrue()) + Expect(lts.openTurn(context.Background())).To(BeTrue(), "idempotent while open") + Expect(m.liveOpened).To(Equal(1)) + }) + + It("degrades stickily when the backend cannot do live transcription", func() { + m.liveErr = errors.New("rpc error: code = Unimplemented desc = live transcription unsupported") + Expect(lts.openTurn(context.Background())).To(BeFalse()) + Expect(lts.unavailable).To(BeTrue()) + + // Later turns never retry: the failure is per-session sticky. + m.liveErr = nil + Expect(lts.openTurn(context.Background())).To(BeFalse()) + Expect(m.liveOpened).To(Equal(0)) + }) + }) + + Describe("feedNewAudio", func() { + It("feeds only the unfed tail and holds back the final resampled sample", func() { + Expect(lts.openTurn(context.Background())).To(BeTrue()) + + lts.feedNewAudio([]int16{1, 2, 3, 4}) + Expect(m.liveSession.fed).To(HaveLen(1)) + Expect(m.liveSession.fed[0]).To(HaveLen(3), "last sample held back") + + // Same buffer grown by two samples: only the delta is fed. + lts.feedNewAudio([]int16{1, 2, 3, 4, 5, 6}) + Expect(m.liveSession.fed).To(HaveLen(2)) + Expect(m.liveSession.fed[1]).To(HaveLen(2)) + + // No growth past the holdback: nothing fed. + lts.feedNewAudio([]int16{1, 2, 3, 4, 5, 6}) + Expect(m.liveSession.fed).To(HaveLen(2)) + }) + + It("degrades and closes the turn when a feed fails", func() { + Expect(lts.openTurn(context.Background())).To(BeTrue()) + m.liveSession.feedErr = errors.New("backend gone") + sess := m.liveSession + + lts.feedNewAudio([]int16{1, 2, 3, 4}) + + Expect(lts.open()).To(BeFalse()) + Expect(lts.unavailable).To(BeTrue()) + Expect(sess.closed).To(Equal(1)) + }) + }) + + Describe("event handling and the dynamic threshold", func() { + sv := &types.RealtimeSessionSemanticVad{Eagerness: "high"} + + It("uses the eagerness fallback until an EOU is recorded, then commits without an extra window", func() { + Expect(lts.thresholdSec(false, sv)).To(Equal(2.0)) + Expect(lts.thresholdSec(true, sv)).To(Equal(semanticEouSilenceSec)) + + Expect(lts.openTurn(context.Background())).To(BeTrue()) + lts.session.ModelInterface.(*fakeModel).liveSession.onEvent(backend.LiveTranscriptionEvent{Delta: "hello ", Eou: false}) + lts.session.ModelInterface.(*fakeModel).liveSession.onEvent(backend.LiveTranscriptionEvent{Eou: true}) + lts.drainEvents(3.3) + + Expect(lts.eouAtSec).To(BeNumerically("~", 3.3, 1e-9)) + Expect(lts.previewText()).To(Equal("hello")) + }) + + // The bug this replaces: the (predictive) EOU routinely arrives while + // silero is still padding the speech segment open. eouPending must NOT + // read that as resumed speech. + It("keeps the EOU pending while silero is still closing the same segment", func() { + lts.eouAtSec = 3.3 + Expect(lts.eouPending([]schema.VADSegment{{Start: 0, End: 0}})).To(BeTrue(), "segment began before the EOU and is merely unclosed") + Expect(lts.eouPending([]schema.VADSegment{{Start: 0, End: 3.0}})).To(BeTrue(), "and still pending once it closes") + }) + + It("drops the EOU only when a new utterance starts after it (resumed speech)", func() { + lts.eouAtSec = 3.3 + Expect(lts.eouPending([]schema.VADSegment{{Start: 0, End: 3.0}, {Start: 4.0, End: 0}})).To(BeFalse()) + Expect(lts.eouPending([]schema.VADSegment{{Start: 0, End: 3.0}, {Start: 4.0, End: 5.0}})).To(BeFalse()) + }) + + It("has no pending EOU before one is recorded", func() { + Expect(lts.eouPending([]schema.VADSegment{{Start: 0, End: 3.0}})).To(BeFalse()) + Expect(lts.eouPending(nil)).To(BeFalse()) + }) + + It("does not arm the commit threshold on an EOB backchannel", func() { + Expect(lts.openTurn(context.Background())).To(BeTrue()) + lts.session.ModelInterface.(*fakeModel).liveSession.onEvent(backend.LiveTranscriptionEvent{Delta: "uh-huh", Eob: true}) + lts.drainEvents(2.0) + + Expect(lts.eouAtSec).To(BeZero(), "a backchannel is not the user yielding the turn") + Expect(lts.eouPending([]schema.VADSegment{{Start: 0, End: 1.8}})).To(BeFalse(), "still on the eagerness fallback") + Expect(lts.previewText()).To(Equal("uh-huh"), "the backchannel text still lands in the transcript") + }) + + It("reports the commit trigger and the EOU token's lag behind speech end", func() { + trigger, lag := lts.commitTrigger(false, 3.2) + Expect(trigger).To(Equal("timeout")) + Expect(lag).To(BeZero()) + + lts.eouAtSec = 3.5 + trigger, lag = lts.commitTrigger(true, 3.2) + Expect(trigger).To(Equal("eou")) + Expect(lag).To(BeNumerically("~", 0.3, 1e-9)) + }) + }) + + Describe("finishTurn", func() { + It("finalizes the stream, prefers the Final text, and resets for the next turn", func() { + m.liveCloseEvents = []backend.LiveTranscriptionEvent{ + {Delta: " world"}, + {Final: &schema.TranscriptionResult{Text: "hello world", Eou: true}}, + } + Expect(lts.openTurn(context.Background())).To(BeTrue()) + sess := m.liveSession + sess.onEvent(backend.LiveTranscriptionEvent{Delta: "hello", Eou: true}) + lts.drainEvents(2.0) + + ut := lts.finishTurn(2.5) + + Expect(sess.closed).To(Equal(1)) + Expect(ut).NotTo(BeNil()) + Expect(ut.Text).To(Equal("hello world"), "Final event text wins over joined deltas") + Expect(lts.open()).To(BeFalse()) + Expect(lts.eouAtSec).To(BeZero()) + Expect(lts.parts).To(BeEmpty()) + Expect(lts.fed16k).To(BeZero()) + }) + + It("returns nil when the stream heard nothing", func() { + Expect(lts.openTurn(context.Background())).To(BeTrue()) + Expect(lts.finishTurn(1.0)).To(BeNil()) + Expect(m.liveSession.closed).To(Equal(1)) + }) + + It("is a no-op without an open stream", func() { + Expect(lts.finishTurn(1.0)).To(BeNil()) + }) + }) + + Describe("discardTurn", func() { + It("closes the stream, drops the transcript and retracts streamed captions", func() { + Expect(lts.openTurn(context.Background())).To(BeTrue()) + sess := m.liveSession + sess.onEvent(backend.LiveTranscriptionEvent{Delta: "noise"}) + lts.drainEvents(1.0) + + lts.discardTurn() + + Expect(sess.closed).To(Equal(1)) + Expect(lts.open()).To(BeFalse()) + Expect(lts.parts).To(BeEmpty()) + Expect(ftr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionFailed)).To(Equal(1), + "the client saw caption deltas for this turn — it must be told to drop them") + }) + + It("sends no failed event when no captions ever reached the client", func() { + Expect(lts.openTurn(context.Background())).To(BeTrue()) + lts.discardTurn() + Expect(ftr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionFailed)).To(Equal(0)) + }) + }) + + Describe("live captions", func() { + It("streams each delta to the client under the turn's item id as it drains", func() { + Expect(lts.openTurn(context.Background())).To(BeTrue()) + turnID := lts.itemID + Expect(turnID).NotTo(BeEmpty(), "the item id exists from turn open so captions can reference it") + + m.liveSession.onEvent(backend.LiveTranscriptionEvent{Delta: "hel"}) + m.liveSession.onEvent(backend.LiveTranscriptionEvent{Delta: "lo"}) + lts.drainEvents(1.0) + + var got []types.ConversationItemInputAudioTranscriptionDeltaEvent + for _, e := range ftr.events { + if d, ok := e.(types.ConversationItemInputAudioTranscriptionDeltaEvent); ok { + got = append(got, d) + } + } + Expect(got).To(HaveLen(2)) + Expect(got[0].Delta).To(Equal("hel")) + Expect(got[1].Delta).To(Equal("lo")) + Expect(got[0].ItemID).To(Equal(turnID)) + Expect(got[1].ItemID).To(Equal(turnID)) + Expect(lts.deltasSent).To(BeTrue()) + }) + + It("finishTurn does not retract captions — the commit's completed event supersedes them", func() { + Expect(lts.openTurn(context.Background())).To(BeTrue()) + m.liveSession.onEvent(backend.LiveTranscriptionEvent{Delta: "hello"}) + lts.drainEvents(1.0) + + Expect(lts.finishTurn(1.5)).NotTo(BeNil()) + Expect(ftr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionFailed)).To(Equal(0)) + }) + }) +}) + +// commitUtteranceWithTranscript routes the three transcript sources: the +// retranscribe gate's batch decode, the live stream's accumulated text, and +// the historical file path. +var _ = Describe("commitUtteranceWithTranscript", func() { + newTranscriptionOnlySession := func(m *fakeModel, streamTranscription bool) *Session { + cfg := &config.ModelConfig{} + if streamTranscription { + on := true + cfg.Pipeline.Streaming.Transcription = &on + } + return &Session{ + TranscriptionOnly: true, // stop after the transcript: no LLM/TTS in these specs + InputAudioTranscription: &types.AudioTranscription{}, + ModelConfig: cfg, + ModelInterface: m, + } + } + + It("uses the gate's batch transcript and never re-runs the backend", func() { + m := &fakeModel{transcribeErr: errors.New("must not be called")} + session := newTranscriptionOnlySession(m, true) + tr := &fakeTransport{} + + commitUtteranceWithTranscript(context.Background(), []byte{1, 2}, nil, + &schema.TranscriptionResult{Text: "batch text", Eou: true}, "item_turn", session, &Conversation{}, tr) + + Expect(tr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionDelta)).To(Equal(0)) + Expect(tr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionCompleted)).To(Equal(1)) + }) + + It("emits only the completed event for a live transcript — captions already streamed during the turn", func() { + m := &fakeModel{transcribeErr: errors.New("must not be called")} + session := newTranscriptionOnlySession(m, true) + tr := &fakeTransport{} + + commitUtteranceWithTranscript(context.Background(), []byte{1, 2}, + &liveUtterance{Text: "hello"}, nil, "item_turn", session, &Conversation{}, tr) + + Expect(tr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionDelta)).To(Equal(0)) + Expect(tr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionCompleted)).To(Equal(1)) + + var completed types.ConversationItemInputAudioTranscriptionCompletedEvent + for _, e := range tr.events { + if c, ok := e.(types.ConversationItemInputAudioTranscriptionCompletedEvent); ok { + completed = c + } + } + Expect(completed.ItemID).To(Equal("item_turn"), + "completed must reuse the caption deltas' item id so the client replaces, not duplicates") + Expect(completed.Transcript).To(Equal("hello")) + }) + + It("falls back to the file path when the live stream heard nothing", func() { + m := &fakeModel{transcribeFinal: &schema.TranscriptionResult{Text: "from file"}} + session := newTranscriptionOnlySession(m, false) + tr := &fakeTransport{} + + commitUtteranceWithTranscript(context.Background(), []byte{1, 2}, + &liveUtterance{}, nil, "", session, &Conversation{}, tr) + + Expect(tr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionCompleted)).To(Equal(1)) + }) +}) + +// transcribeUtterance is the retranscribe gate's offline decode of the +// buffered turn. +var _ = Describe("transcribeUtterance", func() { + It("returns the batch decode with its Eou flag", func() { + m := &fakeModel{transcribeFinal: &schema.TranscriptionResult{Text: "confirmed", Eou: true}} + session := &Session{ + InputAudioTranscription: &types.AudioTranscription{}, + ModelInterface: m, + } + + tr, err := transcribeUtterance(context.Background(), []byte{0, 0, 1, 1}, session) + Expect(err).ToNot(HaveOccurred()) + Expect(tr.Text).To(Equal("confirmed")) + Expect(tr.Eou).To(BeTrue()) + }) + + It("propagates backend errors", func() { + m := &fakeModel{transcribeErr: errors.New("engine fell over")} + session := &Session{ + InputAudioTranscription: &types.AudioTranscription{}, + ModelInterface: m, + } + + _, err := transcribeUtterance(context.Background(), []byte{0, 0}, session) + Expect(err).To(MatchError(ContainSubstring("engine fell over"))) + }) +}) + +// emitPrecomputedTranscription replays an already-produced transcript as the +// standard delta/completed event sequence. +var _ = Describe("emitPrecomputedTranscription", func() { + It("emits deltas then completed, sharing the item id", func() { + tr := &fakeTransport{} + Expect(emitPrecomputedTranscription(tr, "item42", []string{"a", "", "b"}, "ab")).To(Succeed()) + + Expect(tr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionDelta)).To(Equal(2), "empty deltas skipped") + Expect(tr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionCompleted)).To(Equal(1)) + for _, e := range tr.events { + switch ev := e.(type) { + case types.ConversationItemInputAudioTranscriptionDeltaEvent: + Expect(ev.ItemID).To(Equal("item42")) + case types.ConversationItemInputAudioTranscriptionCompletedEvent: + Expect(ev.ItemID).To(Equal("item42")) + Expect(ev.Transcript).To(Equal("ab")) + } + } + }) + + It("emits only the completed event with no deltas", func() { + tr := &fakeTransport{} + Expect(emitPrecomputedTranscription(tr, "item1", nil, "hi")).To(Succeed()) + Expect(tr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionDelta)).To(Equal(0)) + Expect(tr.countEvents(types.ServerEventTypeConversationItemInputAudioTranscriptionCompleted)).To(Equal(1)) + }) +}) diff --git a/core/http/endpoints/openai/realtime_stream.go b/core/http/endpoints/openai/realtime_stream.go index 909fc50dc514..60ed9c281d96 100644 --- a/core/http/endpoints/openai/realtime_stream.go +++ b/core/http/endpoints/openai/realtime_stream.go @@ -161,24 +161,30 @@ func streamLLMResponse(ctx context.Context, session *Session, conv *Conversation streamer.announce = announce // Clause chunking (opt-in): synthesize each clause as soon as it completes - // instead of buffering the whole reply. streamedAudio accumulates the PCM - // across clauses for the conversation item record; ttsErr captures the first - // synthesis failure so the token callback can stop the prediction. emitSpeech - // runs synchronously here — the LLM keeps generating into the gRPC stream - // while a clause is synthesized, so audio still starts mid-generation. + // instead of buffering the whole reply. Synthesis runs on a worker goroutine + // (ttsPipeline) rather than inline in the token callback: emitSpeech blocks + // until the whole clause is synthesized (and, for WebRTC, played back at + // real time), and the callback runs on the goroutine that drains the LLM + // gRPC stream — so speaking inline stalls generation and freezes the + // assistant transcript at every clause boundary. The worker lets generation + // and the transcript stream keep flowing while audio is produced behind them. var chunker *clauseChunker + var ttsPipe *ttsPipeline if session.ModelConfig != nil && session.ModelConfig.Pipeline.ChunkClauses() { chunker = newClauseChunker(defaultClauseMinRunes, defaultClauseMaxRunes) + ttsPipe = newTTSPipeline(func(clause string) ([]byte, error) { + return emitSpeech(ctx, t, session, responseID, itemID, clause) + }) } var streamedAudio []byte var ttsErr error - speakClause := func(clause string) error { - a, err := emitSpeech(ctx, t, session, responseID, itemID, clause) - if err != nil { - return err - } - streamedAudio = append(streamedAudio, a...) - return nil + + // Backstop: always join the TTS worker, even on an unexpected early return. + // wait() is idempotent, so the explicit drain below (which captures the + // streamed audio and first error) stays authoritative; this only guarantees + // the goroutine can never leak if a new return path is added. + if ttsPipe != nil { + defer func() { _, _ = ttsPipe.wait() }() } // fail reports a mid-stream failure. A cancelled context means the client @@ -207,8 +213,12 @@ func streamLLMResponse(ctx context.Context, session *Session, conv *Conversation delta := streamer.onToken(text) if chunker != nil && delta != "" { for _, clause := range chunker.push(delta) { - if ttsErr = speakClause(clause); ttsErr != nil { - return false // stop the prediction; reported after predFunc returns + // Hand the clause to the worker and keep going — never block the + // recv loop on synthesis. A false return means a prior clause + // already failed; stop the prediction (the error is collected + // from the pipeline after predFunc returns). + if !ttsPipe.enqueue(clause) { + return false } } } @@ -217,10 +227,27 @@ func streamLLMResponse(ctx context.Context, session *Session, conv *Conversation predFunc, err := session.ModelInterface.Predict(ctx, history, images, nil, nil, cb, tools, toolChoice, nil, nil, nil) if err != nil { + // The deferred wait() joins the (idle) worker. sendError(t, "inference_failed", fmt.Sprintf("backend error: %v", err), "", itemID) return true } pred, err := predFunc() + + // Drain the TTS worker. On a clean finish, enqueue the trailing clause(s) the + // chunker was still holding; on an error or barge-in, stop synthesizing. + // wait() runs on every path so the worker goroutine never leaks, and it + // returns the audio streamed so far plus the first synthesis failure. + if ttsPipe != nil { + if err == nil && ctx.Err() == nil { + for _, clause := range chunker.flush() { + if !ttsPipe.enqueue(clause) { + break + } + } + } + streamedAudio, ttsErr = ttsPipe.wait() + } + // A clause synthesis failed mid-stream (the callback stopped the prediction); // report it as a TTS error rather than a prediction error. if ttsErr != nil { @@ -244,24 +271,19 @@ func streamLLMResponse(ctx context.Context, session *Session, conv *Conversation announce() } - // Synthesize the audio. With clause chunking the completed clauses were - // already spoken inside the token callback; flush the trailing clause(s) - // the segmenter was still holding. Otherwise buffer the whole message and - // synthesize it once. emitSpeech streams the audio chunks when the TTS - // backend supports TTSStream, otherwise it sends a single unary delta. + // With clause chunking the clauses were synthesized on the worker as the + // reply streamed (including the trailing flush drained above), so the + // audio is already accumulated. Otherwise buffer the whole message and + // synthesize it once now — emitSpeech streams the audio chunks when the + // TTS backend supports TTSStream, otherwise it sends a single unary delta. var audio []byte if chunker != nil { - for _, clause := range chunker.flush() { - if ttsErr = speakClause(clause); ttsErr != nil { - break - } - } audio = streamedAudio } else { audio, ttsErr = emitSpeech(ctx, t, session, responseID, itemID, content) - } - if ttsErr != nil { - return fail("tts_error", "TTS generation failed", ttsErr) + if ttsErr != nil { + return fail("tts_error", "TTS generation failed", ttsErr) + } } _, isWebRTC := t.(*WebRTCTransport) diff --git a/core/http/endpoints/openai/realtime_transcription.go b/core/http/endpoints/openai/realtime_transcription.go index 44456101c44f..28a5147c17e0 100644 --- a/core/http/endpoints/openai/realtime_transcription.go +++ b/core/http/endpoints/openai/realtime_transcription.go @@ -7,6 +7,33 @@ import ( "github.com/mudler/LocalAI/core/http/endpoints/openai/types" ) +// emitPrecomputedTranscription emits the transcription events for a turn +// whose transcript already exists (semantic_vad's live stream, or the +// retranscribe gate's batch decode): optional delta replays followed by the +// completed event — the same contract emitTranscription produces, sharing +// one itemID — without running the backend again. +func emitPrecomputedTranscription(t Transport, itemID string, deltas []string, transcript string) error { + for _, d := range deltas { + if d == "" { + continue + } + if err := t.SendEvent(types.ConversationItemInputAudioTranscriptionDeltaEvent{ + ServerEventBase: types.ServerEventBase{EventID: "event_TODO"}, + ItemID: itemID, + ContentIndex: 0, + Delta: d, + }); err != nil { + return err + } + } + return t.SendEvent(types.ConversationItemInputAudioTranscriptionCompletedEvent{ + ServerEventBase: types.ServerEventBase{EventID: "event_TODO"}, + ItemID: itemID, + ContentIndex: 0, + Transcript: transcript, + }) +} + // emitTranscription transcribes a committed utterance and emits the transcription // events for it, returning the final transcript text. With // pipeline.streaming.transcription enabled it streams each transcript fragment as diff --git a/core/http/endpoints/openai/realtime_tts_pipeline.go b/core/http/endpoints/openai/realtime_tts_pipeline.go new file mode 100644 index 000000000000..2677e27d5a8d --- /dev/null +++ b/core/http/endpoints/openai/realtime_tts_pipeline.go @@ -0,0 +1,120 @@ +package openai + +import ( + "sync" + "sync/atomic" +) + +// ttsPipeline decouples speech synthesis from LLM token generation. +// +// The LLM token callback runs on the same goroutine that drains the model's +// gRPC stream, so anything it does serially — including a blocking TTS call — +// stops the stream from being read and stalls generation (and, since the same +// goroutine also sends the assistant transcript, freezes the transcript the +// client sees). ttsPipeline lets the callback hand each completed clause to a +// single worker goroutine that synthesizes them in order, concurrently with +// continued generation. One worker preserves clause — and therefore audio — +// ordering. +// +// The clause queue is intentionally unbounded: clauses are short strings and a +// reply has a bounded number of them, while the expensive product (audio) is +// paced by the TTS backend regardless. So enqueue never blocks the callback, +// and the transcript streams to the client at generation speed while audio is +// produced behind it. +type ttsPipeline struct { + speak func(clause string) ([]byte, error) + + mu sync.Mutex + queue []string + closed bool + wake chan struct{} // buffered(1) wakeup signal for the worker + + done chan struct{} + failed atomic.Bool + + // audio and firstErr are owned by the worker goroutine and only safe to + // read after wait() has returned (it joins on the worker via done). + audio []byte + firstErr error +} + +// newTTSPipeline starts the worker. speak performs the actual synthesis and +// returns the PCM accumulated for the conversation-item record (empty for +// transports that stream audio out-of-band, e.g. WebRTC). +func newTTSPipeline(speak func(clause string) ([]byte, error)) *ttsPipeline { + p := &ttsPipeline{ + speak: speak, + wake: make(chan struct{}, 1), + done: make(chan struct{}), + } + go p.run() + return p +} + +func (p *ttsPipeline) run() { + defer close(p.done) + for { + p.mu.Lock() + for len(p.queue) == 0 && !p.closed { + p.mu.Unlock() + <-p.wake + p.mu.Lock() + } + if len(p.queue) == 0 && p.closed { + p.mu.Unlock() + return + } + clause := p.queue[0] + p.queue = p.queue[1:] + p.mu.Unlock() + + // Once a clause has failed, keep draining the queue without speaking so + // the producer's wait() returns promptly and the first error is kept. + if p.failed.Load() { + continue + } + a, err := p.speak(clause) + if err != nil { + p.firstErr = err + p.failed.Store(true) + continue + } + p.audio = append(p.audio, a...) + } +} + +// enqueue offers a clause for synthesis. It never blocks; it returns false once +// synthesis has failed, signalling the caller to stop the prediction. +func (p *ttsPipeline) enqueue(clause string) bool { + if p.failed.Load() { + return false + } + p.mu.Lock() + p.queue = append(p.queue, clause) + p.mu.Unlock() + p.signal() + return true +} + +// signal wakes the worker without blocking; the buffered channel coalesces +// signals, which is safe because the worker drains the whole queue per wake. +func (p *ttsPipeline) signal() { + select { + case p.wake <- struct{}{}: + default: + } +} + +// wait closes the queue and blocks until the worker has spoken every enqueued +// clause, then returns the accumulated audio and the first synthesis error. It +// is idempotent: calling it again returns the same result without blocking, so +// callers can drain it explicitly to read the audio and still defer a wait() as +// a leak-proof backstop. No clause may be enqueued after the first wait(). +func (p *ttsPipeline) wait() ([]byte, error) { + p.mu.Lock() + p.closed = true + p.mu.Unlock() + p.signal() + <-p.done + return p.audio, p.firstErr +} diff --git a/core/http/endpoints/openai/realtime_tts_pipeline_test.go b/core/http/endpoints/openai/realtime_tts_pipeline_test.go new file mode 100644 index 000000000000..a5e070248e5d --- /dev/null +++ b/core/http/endpoints/openai/realtime_tts_pipeline_test.go @@ -0,0 +1,114 @@ +package openai + +import ( + "errors" + "sync" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ttsPipeline", func() { + It("synthesizes clauses in order and accumulates their audio", func() { + p := newTTSPipeline(func(clause string) ([]byte, error) { + return []byte(clause), nil + }) + Expect(p.enqueue("a")).To(BeTrue()) + Expect(p.enqueue("b")).To(BeTrue()) + Expect(p.enqueue("c")).To(BeTrue()) + + audio, err := p.wait() + Expect(err).NotTo(HaveOccurred()) + Expect(string(audio)).To(Equal("abc")) + }) + + It("never blocks the producer even when synthesis is slow", func() { + var started sync.WaitGroup + started.Add(1) + release := make(chan struct{}) + first := true + p := newTTSPipeline(func(clause string) ([]byte, error) { + if first { + first = false + started.Done() + <-release // hold the worker on the first clause + } + return []byte(clause), nil + }) + + Expect(p.enqueue("1")).To(BeTrue()) + started.Wait() // worker is now blocked synthesizing the first clause + + // Enqueuing many more clauses must return immediately, not block on the + // stalled worker — this is what keeps the LLM recv loop flowing. + done := make(chan struct{}) + go func() { + defer close(done) + for _, c := range []string{"2", "3", "4", "5"} { + p.enqueue(c) + } + }() + Eventually(done, time.Second).Should(BeClosed()) + + close(release) + audio, err := p.wait() + Expect(err).NotTo(HaveOccurred()) + Expect(string(audio)).To(Equal("12345")) + }) + + It("keeps the first error, stops speaking, and signals the producer to stop", func() { + boom := errors.New("backend gone") + var spoken []string + var mu sync.Mutex + p := newTTSPipeline(func(clause string) ([]byte, error) { + mu.Lock() + spoken = append(spoken, clause) + mu.Unlock() + if clause == "b" { + return nil, boom + } + return []byte(clause), nil + }) + + Expect(p.enqueue("a")).To(BeTrue()) + Expect(p.enqueue("b")).To(BeTrue()) + + // Once the failure is observed, enqueue reports it so the caller stops + // the prediction; any further clauses are dropped, not spoken. + Eventually(func() bool { return !p.enqueue("c") }, time.Second).Should(BeTrue()) + + _, err := p.wait() + Expect(err).To(MatchError(boom)) + + mu.Lock() + defer mu.Unlock() + Expect(spoken).NotTo(ContainElement("c"), "clauses after the failure are not synthesized") + }) + + It("is idempotent: a second wait returns the same result without blocking", func() { + p := newTTSPipeline(func(clause string) ([]byte, error) { + return []byte(clause), nil + }) + Expect(p.enqueue("x")).To(BeTrue()) + + audio1, err1 := p.wait() + // A deferred backstop wait() in the caller runs after the explicit one; + // it must not block or change the result. + audio2, err2 := p.wait() + + Expect(err1).NotTo(HaveOccurred()) + Expect(err2).NotTo(HaveOccurred()) + Expect(string(audio1)).To(Equal("x")) + Expect(string(audio2)).To(Equal("x")) + }) + + It("returns cleanly when no clause was ever enqueued", func() { + p := newTTSPipeline(func(clause string) ([]byte, error) { + return []byte(clause), nil + }) + audio, err := p.wait() + Expect(err).NotTo(HaveOccurred()) + Expect(audio).To(BeEmpty()) + }) +}) diff --git a/core/http/endpoints/openai/realtime_vad_buffer_test.go b/core/http/endpoints/openai/realtime_vad_buffer_test.go new file mode 100644 index 000000000000..0fbef3e6b0fc --- /dev/null +++ b/core/http/endpoints/openai/realtime_vad_buffer_test.go @@ -0,0 +1,54 @@ +package openai + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// dropInspectedPrefix is what stands between the VAD loop's buffer clears and +// cutting the first word off an utterance: the no-speech clear must keep the +// holdback tail (silero hasn't crossed its onset threshold yet) and both +// clears must keep audio appended while the tick ran (the VAD never saw it). +var _ = Describe("dropInspectedPrefix", func() { + It("keeps the holdback tail of the inspected window and everything appended mid-tick", func() { + inspected := []byte{1, 2, 3, 4, 5, 6} + appended := []byte{7, 8} + buf := append(append([]byte(nil), inspected...), appended...) + + out := dropInspectedPrefix(buf, len(inspected), 2) + + Expect(out).To(Equal([]byte{5, 6, 7, 8}), "older confirmed-silent head dropped, possible onset + fresh audio kept") + }) + + It("returns the buffer unchanged when the inspected window fits in the holdback", func() { + buf := []byte{1, 2, 3} + + Expect(dropInspectedPrefix(buf, len(buf), 4)).To(Equal(buf)) + Expect(dropInspectedPrefix(buf, len(buf), len(buf))).To(Equal(buf)) + }) + + It("drops the whole inspected window with zero holdback, keeping only mid-tick appends", func() { + // The commit-time clear: the inspected audio was committed, audio + // appended while the tick ran belongs to the next turn. + buf := []byte{1, 2, 3, 4} + + Expect(dropInspectedPrefix(buf, 4, 0)).To(BeEmpty()) + Expect(dropInspectedPrefix(append(buf, 9), 4, 0)).To(Equal([]byte{9})) + }) + + It("clamps when told more was inspected than the buffer holds", func() { + buf := []byte{1, 2} + + Expect(dropInspectedPrefix(buf, 10, 0)).To(BeEmpty()) + }) + + It("returns a copy, not a sub-slice, when bytes are dropped", func() { + buf := []byte{1, 2, 3, 4} + + out := dropInspectedPrefix(buf, 4, 2) + + Expect(out).To(Equal([]byte{3, 4})) + buf[2] = 99 + Expect(out).To(Equal([]byte{3, 4}), "mutating the old backing array must not leak into the published buffer") + }) +}) diff --git a/core/http/react-ui/e2e/traces-audio.spec.js b/core/http/react-ui/e2e/traces-audio.spec.js new file mode 100644 index 000000000000..567fd56c26af --- /dev/null +++ b/core/http/react-ui/e2e/traces-audio.spec.js @@ -0,0 +1,87 @@ +import { test, expect } from './coverage-fixtures.js' + +// Audio snippets on the Traces page must play through a blob: object URL — +// the CSP's connect-src allows blob: but not data:, and the waveform peaks +// renderer fetch()es the player src — and must degrade to a readable note +// (not a broken player) when the stored payload is the "" +// marker an older server stamped into oversized fields. + +// Minimal valid 16 kHz mono 16-bit PCM WAV (0.1s 440 Hz sine), base64-encoded. +function wavBase64(samples = 1600, rate = 16000) { + const dataSize = samples * 2 + const buf = Buffer.alloc(44 + dataSize) + buf.write('RIFF', 0) + buf.writeUInt32LE(36 + dataSize, 4) + buf.write('WAVE', 8) + buf.write('fmt ', 12) + buf.writeUInt32LE(16, 16) + buf.writeUInt16LE(1, 20) // PCM + buf.writeUInt16LE(1, 22) // mono + buf.writeUInt32LE(rate, 24) + buf.writeUInt32LE(rate * 2, 28) + buf.writeUInt16LE(2, 32) + buf.writeUInt16LE(16, 34) + buf.write('data', 36) + buf.writeUInt32LE(dataSize, 40) + for (let i = 0; i < samples; i++) { + buf.writeInt16LE(Math.round(8000 * Math.sin((2 * Math.PI * 440 * i) / rate)), 44 + i * 2) + } + return buf.toString('base64') +} + +function transcriptionTrace(audioWavBase64) { + return { + type: 'transcription', + timestamp: Date.now() * 1_000_000, + model_name: 'parakeet-test', + summary: 'transcribed utterance', + duration: 500_000_000, + error: null, + data: { + audio_wav_base64: audioWavBase64, + audio_duration_s: 0.1, + audio_snippet_s: 0.1, + audio_sample_rate: 16000, + audio_samples: 1600, + audio_rms_dbfs: -12.0, + audio_peak_dbfs: -6.0, + audio_dc_offset: 0, + }, + } +} + +async function openBackendTraceRow(page, traces) { + await page.route('**/api/traces', (route) => { + route.fulfill({ contentType: 'application/json', body: JSON.stringify([]) }) + }) + await page.route('**/api/backend-traces', (route) => { + route.fulfill({ contentType: 'application/json', body: JSON.stringify(traces) }) + }) + await page.goto('/app/traces') + await expect(page.locator('text=Tracing is')).toBeVisible({ timeout: 10_000 }) + await page.locator('button', { hasText: 'Backend Traces' }).click() + await page.locator('td', { hasText: 'parakeet-test' }).first().click() +} + +test.describe('Traces - Audio Snippets', () => { + test('plays a clip through a blob: URL, not a CSP-blocked data: URL', async ({ page }) => { + await openBackendTraceRow(page, [transcriptionTrace(wavBase64())]) + + // The expanded row carries the snippet metrics and a player whose source + // is an object URL (connect-src allows blob:, so the peaks fetch works). + await expect(page.locator('text=Audio Snippet')).toBeVisible() + const audio = page.locator('audio') + await expect(audio).toHaveCount(1) + const src = await audio.getAttribute('src') + expect(src).toMatch(/^blob:/) + await expect(page.getByTestId('audio-snippet-unavailable')).toHaveCount(0) + }) + + test('shows a readable note instead of a broken player for truncated payloads', async ({ page }) => { + await openBackendTraceRow(page, [transcriptionTrace('')]) + + await expect(page.locator('text=Audio Snippet')).toBeVisible() + await expect(page.getByTestId('audio-snippet-unavailable')).toBeVisible() + await expect(page.locator('audio')).toHaveCount(0) + }) +}) diff --git a/core/http/react-ui/src/pages/Talk.jsx b/core/http/react-ui/src/pages/Talk.jsx index 5a6857a9e213..b25643aa7e0b 100644 --- a/core/http/react-ui/src/pages/Talk.jsx +++ b/core/http/react-ui/src/pages/Talk.jsx @@ -19,24 +19,31 @@ const STATUS_STYLES = { error: { icon: 'fa-solid fa-circle', color: 'var(--color-error)', bg: 'var(--color-error-light)' }, } -// upsertAssistant merges a streamed transcript fragment into the assistant entry -// identified by the server's item_id, or appends a new entry if none exists yet. -// Keying by item_id (not a mutable index tracked across handler/updater -// boundaries) makes streamed deltas idempotent and order-independent, so React's -// batching of non-React data-channel events cannot produce a duplicate bubble. -// mode 'append' adds to the running text; 'replace' sets the final transcript. -function upsertAssistant(prev, itemId, text, mode) { - // Only assistant entries carry an id, and the streaming entry is almost - // always the newest — search from the tail so per-delta cost stays constant. +// upsertEntry merges a streamed transcript fragment into the entry identified +// by the server's item_id, or appends a new entry (with the given role) if +// none exists yet. Keying by item_id (not a mutable index tracked across +// handler/updater boundaries) makes streamed deltas idempotent and +// order-independent, so React's batching of non-React data-channel events +// cannot produce a duplicate bubble. mode 'append' adds to the running text; +// 'replace' sets the final transcript — the server sends a completed event +// whose authoritative text supersedes any live captions (e.g. the +// semantic_vad retranscribe gate's batch decode). +function upsertEntry(prev, itemId, role, text, mode) { + // The streaming entry is almost always the newest — search from the tail + // so per-delta cost stays constant. const i = prev.findLastIndex(e => e.id === itemId) if (i === -1) { - return [...prev, { role: 'assistant', id: itemId, text }] + return [...prev, { role, id: itemId, text }] } const next = [...prev] next[i] = { ...next[i], text: mode === 'append' ? next[i].text + text : text } return next } +function upsertAssistant(prev, itemId, text, mode) { + return upsertEntry(prev, itemId, 'assistant', text, mode) +} + export default function Talk() { const { addToast } = useOutletContext() const navigate = useNavigate() @@ -252,12 +259,33 @@ export default function Talk() { case 'input_audio_buffer.speech_stopped': updateStatus('thinking', 'Processing...') break + case 'conversation.item.input_audio_transcription.delta': + // Live captions: semantic_vad streams the user's words while they + // are still speaking, keyed by the item id the commit will reuse. + if (event.delta && event.item_id) { + setTranscript(prev => upsertEntry(prev, event.item_id, 'user', event.delta, 'append')) + } + break case 'conversation.item.input_audio_transcription.completed': if (event.transcript) { - setTranscript(prev => [...prev, { role: 'user', text: event.transcript }]) + if (event.item_id) { + // Replaces any live captions with the authoritative transcript + // (which may differ, e.g. the retranscribe gate's batch decode); + // creates the entry when there were none (server_vad). + setTranscript(prev => upsertEntry(prev, event.item_id, 'user', event.transcript, 'replace')) + } else { + setTranscript(prev => [...prev, { role: 'user', text: event.transcript }]) + } } updateStatus('thinking', 'Generating response...') break + case 'conversation.item.input_audio_transcription.failed': + // The turn was discarded after captions were shown (e.g. the buffer + // was cleared as silence) — retract the partial entry. + if (event.item_id) { + setTranscript(prev => prev.filter(e => e.id !== event.item_id)) + } + break case 'response.output_audio_transcript.delta': if (event.delta) { inProgressIdRef.current = event.item_id @@ -712,7 +740,7 @@ export default function Talk() { )} {selectedModelInfo && !selectedModelInfo.self_contained && (
{[ @@ -724,9 +752,12 @@ export default function Talk() {
-
{item.label}
-
{item.value}
+
{item.label}
+ {/* full width for the value; wrap rather than overflow when the + model name is long (minWidth:0 lets the flex item shrink) */} +
{item.value || '—'}
))}
diff --git a/core/http/react-ui/src/pages/Traces.jsx b/core/http/react-ui/src/pages/Traces.jsx index 85387f815879..933acf344566 100644 --- a/core/http/react-ui/src/pages/Traces.jsx +++ b/core/http/react-ui/src/pages/Traces.jsx @@ -86,8 +86,40 @@ function typeBadgeStyle(type) { return { background: c.bg, color: c.color, padding: '2px 8px', borderRadius: 'var(--radius-sm)', fontSize: '0.75rem', fontWeight: 500 } } +// useWavObjectURL — decode a base64 WAV payload into a blob: object URL for +// the waveform player. A data: URL would render in