Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions sdk/batch/speechmatics/batch/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ def transcript_text(self) -> str:
# Group results by speaker and process
transcript_parts = []
current_speaker = None
current_group: list[str] = []
current_group: list[RecognitionResult] = []

for result in self.results:
if not result.alternatives:
Expand All @@ -861,9 +861,9 @@ def transcript_text(self) -> str:
if speaker != current_speaker:
# Process accumulated group for previous speaker
if current_group:
text = self._join_content_items(current_group, word_delimiter)
text = self._join_results(current_group, word_delimiter)
if current_speaker:
transcript_parts.append(f"SPEAKER {current_speaker}: {text}") # type: ignore[unreachable]
transcript_parts.append(f"SPEAKER {current_speaker}: {text}")
else:
transcript_parts.append(text)
current_group = []
Expand All @@ -872,51 +872,69 @@ def transcript_text(self) -> str:

# Add content to current group
if content:
current_group.append(content)
current_group.append(result)

# Process final group
if current_group:
text = self._join_content_items(current_group, word_delimiter)
text = self._join_results(current_group, word_delimiter)
if current_speaker:
transcript_parts.append(f"SPEAKER {current_speaker}: {text}")
else:
transcript_parts.append(text)

return "\n".join(transcript_parts)

def _join_content_items(self, content_items: list[str], word_delimiter: str) -> str:
def _join_results(self, results: list[RecognitionResult], word_delimiter: str) -> str:
"""
Join content items with appropriate spacing and punctuation handling.
Join results with attachment-aware punctuation spacing.

Args:
content_items: List of content strings to join.
results: List of recognition results to join.
word_delimiter: Delimiter to use between words.

Returns:
Properly formatted text string.
"""
if not content_items:
if not results:
return ""

result: list[str] = []
output: list[str] = []
previous_result: Optional[RecognitionResult] = None

for i, content in enumerate(content_items):
for result in results:
if not result.alternatives:
continue

content = result.alternatives[0].content
if not content:
continue

# Check if this content is punctuation
is_punctuation = content.strip() in ".,!?;:()[]{}\"'-"
if previous_result and self._needs_word_delimiter(previous_result, result):
output.append(word_delimiter)

output.append(content)
previous_result = result

return "".join(output).strip()

def _needs_word_delimiter(self, previous_result: RecognitionResult, result: RecognitionResult) -> bool:
previous_attaches_to = self._punctuation_attachment(previous_result)
if previous_attaches_to in {"next", "both"}:
return False

attaches_to = self._punctuation_attachment(result)
return attaches_to not in {"previous", "both"}

# Add delimiter before content unless:
# - It's the first item
# - It's punctuation
# - Previous item ended with whitespace
if i > 0 and not is_punctuation and result and not result[-1].endswith(" "):
result.append(word_delimiter)
@staticmethod
def _punctuation_attachment(result: RecognitionResult) -> Optional[str]:
if result.type != "punctuation":
return None

result.append(content)
attaches_to = result.attaches_to or "previous"
if attaches_to not in {"previous", "next", "both", "none"}:
return "previous"

return "".join(result).strip()
return attaches_to

@property
def confidence(self) -> Optional[float]:
Expand Down
156 changes: 140 additions & 16 deletions tests/batch/test_models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,57 @@
from speechmatics.batch._models import JobConfig, TranscriptFilteringConfig, TranscriptionConfig
from dataclasses import asdict
from typing import Optional

from speechmatics.batch._models import JobConfig
from speechmatics.batch._models import RecognitionResult
from speechmatics.batch._models import Transcript
from speechmatics.batch._models import TranscriptFilteringConfig
from speechmatics.batch._models import TranscriptionConfig


def _transcript_payload(results: list[dict], word_delimiter: str = " ") -> dict:
return {
"format": "2.9",
"job": {
"id": "job-id",
"created_at": "2026-06-12T00:00:00Z",
"data_name": "audio.wav",
},
"metadata": {
"created_at": "2026-06-12T00:00:00Z",
"type": "transcription",
"language_pack_info": {"word_delimiter": word_delimiter},
},
"results": results,
}


def _recognition_result(
result_type: str,
content: str,
attaches_to: Optional[str] = None,
is_eos: Optional[bool] = None,
) -> dict:
result = {
"type": result_type,
"start_time": 1.0,
"end_time": 1.0,
"alternatives": [{"content": content, "confidence": 1.0, "language": "en"}],
}
if attaches_to is not None:
result["attaches_to"] = attaches_to
if is_eos is not None:
result["is_eos"] = is_eos
return result


class TestTranscriptFilteringConfigToDict:
def test_remove_disfluencies_true_serializes_correctly(self):
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True)
)
config = TranscriptionConfig(transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True))
result = config.to_dict()
assert result["transcript_filtering_config"] == {"remove_disfluencies": True}

def test_remove_disfluencies_false_included_in_output(self):
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=False)
)
config = TranscriptionConfig(transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=False))
result = config.to_dict()
assert result["transcript_filtering_config"] == {"remove_disfluencies": False}

Expand All @@ -23,28 +62,22 @@ def test_none_excluded_from_output(self):

def test_replacements_serialized(self):
replacements = [{"from": "um", "to": ""}, {"from": "uh", "to": ""}]
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(replacements=replacements)
)
config = TranscriptionConfig(transcript_filtering_config=TranscriptFilteringConfig(replacements=replacements))
result = config.to_dict()
assert result["transcript_filtering_config"] == {
"remove_disfluencies": False,
"replacements": replacements,
}

def test_replacements_absent_when_none(self):
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True)
)
config = TranscriptionConfig(transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True))
result = config.to_dict()
assert "replacements" not in result["transcript_filtering_config"]

def test_replacements_and_remove_disfluencies_together(self):
replacements = [{"from": "gonna", "to": "going to"}]
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(
remove_disfluencies=True, replacements=replacements
)
transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True, replacements=replacements)
)
result = config.to_dict()
assert result["transcript_filtering_config"] == {
Expand Down Expand Up @@ -127,3 +160,94 @@ def test_absent_output_config_is_none(self):
data = {"type": "transcription"}
job_config = JobConfig.from_dict(data)
assert job_config.output_config is None


class TestRecognitionResultFromDict:
def test_preserves_punctuation_metadata(self):
result = RecognitionResult.from_dict(
_recognition_result("punctuation", ".", attaches_to="previous", is_eos=True)
)

assert result.attaches_to == "previous"
assert result.is_eos is True

def test_transcript_payload_preserves_punctuation_metadata_in_asdict(self):
transcript = Transcript.from_dict(
_transcript_payload(
[
_recognition_result("word", "Hello"),
_recognition_result("punctuation", ".", attaches_to="previous", is_eos=True),
]
)
)

assert transcript.results[1].attaches_to == "previous"
assert asdict(transcript)["results"][1]["attaches_to"] == "previous"
assert asdict(transcript)["results"][1]["is_eos"] is True


class TestTranscriptText:
def test_word_only_transcript_uses_word_delimiter(self):
transcript = Transcript.from_dict(
_transcript_payload(
[
_recognition_result("word", "Hello"),
_recognition_result("word", "world"),
]
)
)

assert transcript.transcript_text == "Hello world"

def test_punctuation_attached_to_previous(self):
transcript = Transcript.from_dict(
_transcript_payload(
[
_recognition_result("word", "Hello"),
_recognition_result("punctuation", ",", attaches_to="previous"),
_recognition_result("word", "world"),
_recognition_result("punctuation", ".", attaches_to="previous"),
]
)
)

assert transcript.transcript_text == "Hello, world."

def test_punctuation_attached_to_next(self):
transcript = Transcript.from_dict(
_transcript_payload(
[
_recognition_result("punctuation", "¿", attaches_to="next"),
_recognition_result("word", "Hola"),
_recognition_result("punctuation", "?", attaches_to="previous"),
]
)
)

assert transcript.transcript_text == "¿Hola?"

def test_punctuation_attached_to_neither_side(self):
transcript = Transcript.from_dict(
_transcript_payload(
[
_recognition_result("word", "hello"),
_recognition_result("punctuation", "-", attaches_to="none"),
_recognition_result("word", "world"),
]
)
)

assert transcript.transcript_text == "hello - world"

def test_punctuation_attached_to_both_sides(self):
transcript = Transcript.from_dict(
_transcript_payload(
[
_recognition_result("word", "and"),
_recognition_result("punctuation", "/", attaches_to="both"),
_recognition_result("word", "or"),
]
)
)

assert transcript.transcript_text == "and/or"