diff --git a/sdk/batch/speechmatics/batch/_models.py b/sdk/batch/speechmatics/batch/_models.py index 7375c0f..21d0911 100644 --- a/sdk/batch/speechmatics/batch/_models.py +++ b/sdk/batch/speechmatics/batch/_models.py @@ -847,7 +847,7 @@ def transcript_text(self) -> str: # Group results by speaker and process transcript_parts = [] current_speaker = None - current_group: list[str] = [] + current_group: list[RecognitionResult] = [] for result in self.results: if not result.alternatives: @@ -861,9 +861,9 @@ def transcript_text(self) -> str: if speaker != current_speaker: # Process accumulated group for previous speaker if current_group: - text = self._join_content_items(current_group, word_delimiter) + text = self._join_results(current_group, word_delimiter) if current_speaker: - transcript_parts.append(f"SPEAKER {current_speaker}: {text}") # type: ignore[unreachable] + transcript_parts.append(f"SPEAKER {current_speaker}: {text}") else: transcript_parts.append(text) current_group = [] @@ -872,11 +872,11 @@ def transcript_text(self) -> str: # Add content to current group if content: - current_group.append(content) + current_group.append(result) # Process final group if current_group: - text = self._join_content_items(current_group, word_delimiter) + text = self._join_results(current_group, word_delimiter) if current_speaker: transcript_parts.append(f"SPEAKER {current_speaker}: {text}") else: @@ -884,39 +884,57 @@ def transcript_text(self) -> str: return "\n".join(transcript_parts) - def _join_content_items(self, content_items: list[str], word_delimiter: str) -> str: + def _join_results(self, results: list[RecognitionResult], word_delimiter: str) -> str: """ - Join content items with appropriate spacing and punctuation handling. + Join results with attachment-aware punctuation spacing. Args: - content_items: List of content strings to join. + results: List of recognition results to join. word_delimiter: Delimiter to use between words. Returns: Properly formatted text string. """ - if not content_items: + if not results: return "" - result: list[str] = [] + output: list[str] = [] + previous_result: Optional[RecognitionResult] = None - for i, content in enumerate(content_items): + for result in results: + if not result.alternatives: + continue + + content = result.alternatives[0].content if not content: continue - # Check if this content is punctuation - is_punctuation = content.strip() in ".,!?;:()[]{}\"'-" + if previous_result and self._needs_word_delimiter(previous_result, result): + output.append(word_delimiter) + + output.append(content) + previous_result = result + + return "".join(output).strip() + + def _needs_word_delimiter(self, previous_result: RecognitionResult, result: RecognitionResult) -> bool: + previous_attaches_to = self._punctuation_attachment(previous_result) + if previous_attaches_to in {"next", "both"}: + return False + + attaches_to = self._punctuation_attachment(result) + return attaches_to not in {"previous", "both"} - # Add delimiter before content unless: - # - It's the first item - # - It's punctuation - # - Previous item ended with whitespace - if i > 0 and not is_punctuation and result and not result[-1].endswith(" "): - result.append(word_delimiter) + @staticmethod + def _punctuation_attachment(result: RecognitionResult) -> Optional[str]: + if result.type != "punctuation": + return None - result.append(content) + attaches_to = result.attaches_to or "previous" + if attaches_to not in {"previous", "next", "both", "none"}: + return "previous" - return "".join(result).strip() + return attaches_to @property def confidence(self) -> Optional[float]: diff --git a/tests/batch/test_models.py b/tests/batch/test_models.py index d262685..04058f1 100644 --- a/tests/batch/test_models.py +++ b/tests/batch/test_models.py @@ -1,18 +1,57 @@ -from speechmatics.batch._models import JobConfig, TranscriptFilteringConfig, TranscriptionConfig +from dataclasses import asdict +from typing import Optional + +from speechmatics.batch._models import JobConfig +from speechmatics.batch._models import RecognitionResult +from speechmatics.batch._models import Transcript +from speechmatics.batch._models import TranscriptFilteringConfig +from speechmatics.batch._models import TranscriptionConfig + + +def _transcript_payload(results: list[dict], word_delimiter: str = " ") -> dict: + return { + "format": "2.9", + "job": { + "id": "job-id", + "created_at": "2026-06-12T00:00:00Z", + "data_name": "audio.wav", + }, + "metadata": { + "created_at": "2026-06-12T00:00:00Z", + "type": "transcription", + "language_pack_info": {"word_delimiter": word_delimiter}, + }, + "results": results, + } + + +def _recognition_result( + result_type: str, + content: str, + attaches_to: Optional[str] = None, + is_eos: Optional[bool] = None, +) -> dict: + result = { + "type": result_type, + "start_time": 1.0, + "end_time": 1.0, + "alternatives": [{"content": content, "confidence": 1.0, "language": "en"}], + } + if attaches_to is not None: + result["attaches_to"] = attaches_to + if is_eos is not None: + result["is_eos"] = is_eos + return result class TestTranscriptFilteringConfigToDict: def test_remove_disfluencies_true_serializes_correctly(self): - config = TranscriptionConfig( - transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True) - ) + config = TranscriptionConfig(transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True)) result = config.to_dict() assert result["transcript_filtering_config"] == {"remove_disfluencies": True} def test_remove_disfluencies_false_included_in_output(self): - config = TranscriptionConfig( - transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=False) - ) + config = TranscriptionConfig(transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=False)) result = config.to_dict() assert result["transcript_filtering_config"] == {"remove_disfluencies": False} @@ -23,9 +62,7 @@ def test_none_excluded_from_output(self): def test_replacements_serialized(self): replacements = [{"from": "um", "to": ""}, {"from": "uh", "to": ""}] - config = TranscriptionConfig( - transcript_filtering_config=TranscriptFilteringConfig(replacements=replacements) - ) + config = TranscriptionConfig(transcript_filtering_config=TranscriptFilteringConfig(replacements=replacements)) result = config.to_dict() assert result["transcript_filtering_config"] == { "remove_disfluencies": False, @@ -33,18 +70,14 @@ def test_replacements_serialized(self): } def test_replacements_absent_when_none(self): - config = TranscriptionConfig( - transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True) - ) + config = TranscriptionConfig(transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True)) result = config.to_dict() assert "replacements" not in result["transcript_filtering_config"] def test_replacements_and_remove_disfluencies_together(self): replacements = [{"from": "gonna", "to": "going to"}] config = TranscriptionConfig( - transcript_filtering_config=TranscriptFilteringConfig( - remove_disfluencies=True, replacements=replacements - ) + transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True, replacements=replacements) ) result = config.to_dict() assert result["transcript_filtering_config"] == { @@ -127,3 +160,94 @@ def test_absent_output_config_is_none(self): data = {"type": "transcription"} job_config = JobConfig.from_dict(data) assert job_config.output_config is None + + +class TestRecognitionResultFromDict: + def test_preserves_punctuation_metadata(self): + result = RecognitionResult.from_dict( + _recognition_result("punctuation", ".", attaches_to="previous", is_eos=True) + ) + + assert result.attaches_to == "previous" + assert result.is_eos is True + + def test_transcript_payload_preserves_punctuation_metadata_in_asdict(self): + transcript = Transcript.from_dict( + _transcript_payload( + [ + _recognition_result("word", "Hello"), + _recognition_result("punctuation", ".", attaches_to="previous", is_eos=True), + ] + ) + ) + + assert transcript.results[1].attaches_to == "previous" + assert asdict(transcript)["results"][1]["attaches_to"] == "previous" + assert asdict(transcript)["results"][1]["is_eos"] is True + + +class TestTranscriptText: + def test_word_only_transcript_uses_word_delimiter(self): + transcript = Transcript.from_dict( + _transcript_payload( + [ + _recognition_result("word", "Hello"), + _recognition_result("word", "world"), + ] + ) + ) + + assert transcript.transcript_text == "Hello world" + + def test_punctuation_attached_to_previous(self): + transcript = Transcript.from_dict( + _transcript_payload( + [ + _recognition_result("word", "Hello"), + _recognition_result("punctuation", ",", attaches_to="previous"), + _recognition_result("word", "world"), + _recognition_result("punctuation", ".", attaches_to="previous"), + ] + ) + ) + + assert transcript.transcript_text == "Hello, world." + + def test_punctuation_attached_to_next(self): + transcript = Transcript.from_dict( + _transcript_payload( + [ + _recognition_result("punctuation", "¿", attaches_to="next"), + _recognition_result("word", "Hola"), + _recognition_result("punctuation", "?", attaches_to="previous"), + ] + ) + ) + + assert transcript.transcript_text == "¿Hola?" + + def test_punctuation_attached_to_neither_side(self): + transcript = Transcript.from_dict( + _transcript_payload( + [ + _recognition_result("word", "hello"), + _recognition_result("punctuation", "-", attaches_to="none"), + _recognition_result("word", "world"), + ] + ) + ) + + assert transcript.transcript_text == "hello - world" + + def test_punctuation_attached_to_both_sides(self): + transcript = Transcript.from_dict( + _transcript_payload( + [ + _recognition_result("word", "and"), + _recognition_result("punctuation", "/", attaches_to="both"), + _recognition_result("word", "or"), + ] + ) + ) + + assert transcript.transcript_text == "and/or"