diff --git a/batchalign/cli/bench.py b/batchalign/cli/bench.py index cbf3f4f..b88cf56 100644 --- a/batchalign/cli/bench.py +++ b/batchalign/cli/bench.py @@ -9,7 +9,7 @@ @click.command() -@click.argument("command", type=click.Choice(["align", "transcribe", "transcribe_s", "morphotag", "translate", "utseg", "benchmark", "opensmile", "coref"])) +@click.argument("command", type=click.Choice(["align", "transcribe", "transcribe_s", "morphotag", "translate", "utseg", "benchmark", "opensmile", "coref", "compare"])) @click.argument("in_dir", type=click.Path(exists=True, file_okay=False)) @click.argument("out_dir", type=click.Path(exists=True, file_okay=False)) @click.option("--runs", type=int, default=1, show_default=True, help="Number of benchmark runs.") @@ -33,7 +33,7 @@ def bench(ctx, command, in_dir, out_dir, runs, no_pool, no_lazy_audio, no_adapti if workers is not None: run_ctx.obj["workers"] = workers start = time.time() - if command in ["align", "morphotag", "translate", "utseg", "coref"]: + if command in ["align", "morphotag", "translate", "utseg", "coref", "compare"]: extensions = ["cha"] elif command in ["transcribe", "transcribe_s", "benchmark", "opensmile"]: extensions = ["wav", "mp3", "mp4"] diff --git a/batchalign/cli/cli.py b/batchalign/cli/cli.py index 931db5c..440ed56 100644 --- a/batchalign/cli/cli.py +++ b/batchalign/cli/cli.py @@ -400,6 +400,30 @@ def writer(doc, output): **kwargs) +#################### COMPARE ################################ + +@batchalign.command() +@common_options +@click.option("--lang", + help="sample language in three-letter ISO 3166-1 alpha-3 code", + show_default=True, + default="eng", + type=str) +@click.option("--merge-abbrev/--no-merge-abbrev", + default=False, help="Merge abbreviations in output. Default: no.") +@click.pass_context +def compare(ctx, in_dir, out_dir, lang, **kwargs): + """Compare transcripts against gold-standard references. + + For each FILE.cha in IN_DIR, expects a companion FILE.gold.cha in the + same directory. Runs morphosyntax analysis on the main transcript, then + produces a word-level diff stored as %%xsrep / %%xsmor tiers and writes + error metrics to a .compare.csv file in OUT_DIR. + """ + + _dispatch("compare", lang, 1, ["cha"], ctx, + in_dir, out_dir, None, None, C, **kwargs) + #################### AVQI ################################ @batchalign.command() diff --git a/batchalign/cli/dispatch.py b/batchalign/cli/dispatch.py index bd0e2a0..77d0088 100644 --- a/batchalign/cli/dispatch.py +++ b/batchalign/cli/dispatch.py @@ -57,6 +57,7 @@ "seamless_translate", "opensmile_egemaps", "opensmile_gemaps", + "compare_engine", "opensmile_compare", "opensmile_eGeMAPSv01b", } @@ -68,6 +69,7 @@ "gtrans", "replacement", "ngram", + "compare_analysis_engine", } warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') @@ -196,6 +198,41 @@ def progress_callback(completed, total, tasks): CHATFile(doc=doc["doc"]).write(str(P(output).with_suffix(".asr.cha")), write_wor=local_kwargs.get("wor", False)) + elif command == "compare": + from pathlib import Path as P + # Skip gold files that dispatch picked up + if file.endswith(".gold.cha"): + return + + # Find companion gold file + p = P(file) + gold_path = p.parent / (p.stem + ".gold.cha") + if not gold_path.exists(): + raise FileNotFoundError( + f"No gold .cha file found for comparison. " + f"main: {p.name}, expected: {gold_path.name}, looked in: {str(gold_path)}" + ) + + main_doc = CHATFile(path=str(p)).doc + gold_doc = CHATFile(path=str(gold_path), special_mor_=True).doc + + # Pipeline: morphosyntax(main) -> compare -> compare_analysis + result = pipeline(main_doc, callback=progress_callback, gold=gold_doc) + + # Write annotated CHAT + CHATFile(doc=result["doc"]).write(output, + merge_abbrev=local_kwargs.get("merge_abbrev", False)) + + # Write metrics CSV + import csv + metrics = result["metrics"] + csv_path = P(output).with_suffix(".compare.csv") + with open(csv_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(["metric", "value"]) + for k, v in metrics.items(): + writer.writerow([k, v]) + elif command == "opensmile": from batchalign.document import Document doc = Document.new(media_path=file, lang=local_kwargs.get("lang", kwargs.get("lang", "eng"))) @@ -347,6 +384,7 @@ def _safe_peak_rss(): "coref": "coref", "translate": "translate", "opensmile": "opensmile", + "compare": "morphosyntax,compare,compare_analysis", } # this is the main runner used by all functions @@ -366,6 +404,7 @@ def _dispatch(command, lang, num_speakers, "coref", "benchmark", "opensmile", + "compare", } if command in worker_handled: # Avoid pickling CLI-local loader/writer functions when the worker @@ -478,7 +517,7 @@ def _dispatch(command, lang, num_speakers, memory_history_path.parent.mkdir(parents=True, exist_ok=True) # Pre-download stanza resources if needed to avoid interleaved downloads in workers - if command in ["morphotag", "utseg", "coref"]: + if command in ["morphotag", "utseg", "coref", "compare"]: try: import stanza stanza.download_resources_json() diff --git a/batchalign/document.py b/batchalign/document.py index baefec9..53485b2 100644 --- a/batchalign/document.py +++ b/batchalign/document.py @@ -33,6 +33,8 @@ class Task(IntEnum): COREF = 12 WER = 13 TRANSLATE = 14 + COMPARE = 15 + COMPARE_ANALYSIS = 16 DEBUG__G = 0 @@ -57,6 +59,8 @@ class TaskType(IntEnum): Task.COREF: TaskType.PROCESSING, Task.WER: TaskType.ANALYSIS, Task.TRANSLATE: TaskType.PROCESSING, + Task.COMPARE: TaskType.PROCESSING, + Task.COMPARE_ANALYSIS: TaskType.ANALYSIS, Task.DEBUG__G: TaskType.GENERATION, Task.DEBUG__P: TaskType.PROCESSING, @@ -77,6 +81,8 @@ class TaskType(IntEnum): Task.COREF: "Coreference Resolution", Task.WER: "Word Error Rate", Task.TRANSLATE: "Translation", + Task.COMPARE: "Transcript Comparison", + Task.COMPARE_ANALYSIS: "Comparison Analysis", Task.DEBUG__G: "TEST_GENERATION", Task.DEBUG__P: "TEST_PROCESSING", Task.DEBUG__A: "TEST_ANALYSIS", @@ -100,6 +106,11 @@ class CustomLine(BaseModel): type: CustomLineType # % or @ content: Optional[str] = Field(default=None) # the contents of the line +class CompareToken(BaseModel): + text: str # the word (conformed/expanded form) + pos: Optional[str] = Field(default=None) # POS tag (uppercased) + status: str = Field(default="match") # "match" | "extra_main" | "extra_gold" + class Dependency(BaseModel): id: int # first number, 1 indexed dep_id: int # second number (where the arrow points to) @@ -158,6 +169,7 @@ class Utterance(BaseModel): translation: Optional[str] = Field(default=None) time: Optional[Tuple[int,int]] = Field(default=None) custom_dependencies: List[CustomLine] = Field(default=[]) + comparison: Optional[List[CompareToken]] = Field(default=None) @property def delim(self) -> str: diff --git a/batchalign/formats/chat/generator.py b/batchalign/formats/chat/generator.py index 960f76d..fe6e6bc 100644 --- a/batchalign/formats/chat/generator.py +++ b/batchalign/formats/chat/generator.py @@ -123,6 +123,17 @@ def generate_chat_utterance(utterance: Utterance, special_mor=False, write_wor=T if special.content: result.append(f"%{special.id}:\t"+special.content) + #### COMPARISON LINE GENERATION #### + if utterance.comparison is not None: + xsrep_parts = [] + xsmor_parts = [] + for tok in utterance.comparison: + prefix = "+" if tok.status == "extra_main" else ("-" if tok.status == "extra_gold" else "") + xsrep_parts.append(f"{prefix}{tok.text}") + xsmor_parts.append(f"{prefix}{tok.pos or '?'}") + result.append(f"%xsrep:\t" + " ".join(xsrep_parts)) + result.append(f"%xsmor:\t" + " ".join(xsmor_parts)) + return "\n".join(result) def check_utterances_ordered(doc): diff --git a/batchalign/formats/chat/parser.py b/batchalign/formats/chat/parser.py index ef7d8e6..8cd776f 100644 --- a/batchalign/formats/chat/parser.py +++ b/batchalign/formats/chat/parser.py @@ -2,7 +2,7 @@ from batchalign.document import ( Document, Utterance, Form, Tier, Media, MediaType, - CustomLine, CustomLineType, Morphology, Dependency, ENDING_PUNCT + CustomLine, CustomLineType, Morphology, Dependency, CompareToken, ENDING_PUNCT ) from batchalign.utils import * from batchalign.errors import CHATValidationException @@ -14,6 +14,27 @@ import re +def _parse_comparison(xsrep_str, xsmor_str=None): + """Parse %xsrep and %xsmor lines into a list of CompareToken.""" + tokens = [] + xsrep_parts = xsrep_str.split() + xsmor_parts = xsmor_str.split() if xsmor_str else [None] * len(xsrep_parts) + for word, pos in zip(xsrep_parts, xsmor_parts): + if word.startswith("+"): + status = "extra_main" + word = word[1:] + if pos and pos.startswith("+"): + pos = pos[1:] + elif word.startswith("-"): + status = "extra_gold" + word = word[1:] + if pos and pos.startswith("-"): + pos = pos[1:] + else: + status = "match" + tokens.append(CompareToken(text=word, pos=pos, status=status)) + return tokens + def chat_parse_utterance(text, mor, gra, wor, additional): """Encode a CHAT utterance into a Batchalign utterance. @@ -299,6 +320,8 @@ def chat_parse_doc(lines, special_mor=False): gra = None wor = None translation = None + xsrep_line = None + xsmor_line = None additional = [] while raw[0][0] == "%": @@ -312,6 +335,10 @@ def chat_parse_doc(lines, special_mor=False): wor = line elif beg.strip() == "xtra": translation = line + elif beg.strip() == "xsrep": + xsrep_line = line + elif beg.strip() == "xsmor": + xsmor_line = line else: additional.append(CustomLine(id=beg.strip(), type=CustomLineType.DEPENDENT, @@ -336,6 +363,9 @@ def chat_parse_doc(lines, special_mor=False): "override_lang": None if len(multilingual) == 0 else multilingual[0] }) + if xsrep_line: + ut.comparison = _parse_comparison(xsrep_line, xsmor_line) + timing = re.findall(rf"\x15(\d+)_(\d+)\x15", text) if len(timing) != 0: x,y = timing[0] diff --git a/batchalign/pipelines/analysis/__init__.py b/batchalign/pipelines/analysis/__init__.py index fba4377..1d34879 100644 --- a/batchalign/pipelines/analysis/__init__.py +++ b/batchalign/pipelines/analysis/__init__.py @@ -10,6 +10,12 @@ def __getattr__(name): if name == 'EvaluationEngine': from .eval import EvaluationEngine return EvaluationEngine + if name == 'CompareEngine': + from .compare import CompareEngine + return CompareEngine + if name == 'CompareAnalysisEngine': + from .compare import CompareAnalysisEngine + return CompareAnalysisEngine raise AttributeError(f"module '{__name__}' has no attribute '{name}'") -__all__ = ['EvaluationEngine'] +__all__ = ['EvaluationEngine', 'CompareEngine', 'CompareAnalysisEngine'] diff --git a/batchalign/pipelines/analysis/compare.py b/batchalign/pipelines/analysis/compare.py new file mode 100644 index 0000000..1160bc8 --- /dev/null +++ b/batchalign/pipelines/analysis/compare.py @@ -0,0 +1,353 @@ +""" +compare.py +Engines for transcript comparison against gold-standard references. + +CompareEngine (PROCESSING): Aligns main vs gold transcripts word-by-word +using the same conform/match_fn logic as WER evaluation, then annotates +each main utterance with comparison tokens (%xsrep / %xsmor). + +CompareAnalysisEngine (ANALYSIS): Reads the comparison annotations and +computes error-rate metrics for CSV output. +""" + +import re +import logging +from batchalign.document import * +from batchalign.pipelines.base import * +from batchalign.utils.dp import align, ExtraType, Extra, Match +from batchalign.utils.names import names +from batchalign.utils.compounds import compounds +from batchalign.utils.abbrev import abbrev + +L = logging.getLogger("batchalign") + +# --- Duplicated from eval.py to avoid heavy import chain (asr.utils -> num2words) --- + +joined_compounds = ["".join(k) for k in compounds] +lowered_abbrev = [k for k in abbrev] + +fillers = ["um", "uhm", "em", "mhm", "uhhm", "eh", "uh", "hm"] + +def conform(x): + result = [] + for i in x: + if i.strip().lower() in joined_compounds: + for k in compounds[joined_compounds.index(i.strip().lower())]: + result.append(k) + elif i.strip() in lowered_abbrev: + for j in i.strip(): + result.append(j.strip()) + elif "'s" in i.strip().lower(): + result.append(i.split("\u2019")[0] if "\u2019" in i else i.split("'")[0]) + result.append("is") + elif "\u2019ve" in i.strip().lower() or "'ve" in i.strip().lower(): + result.append(i.split("\u2019")[0] if "\u2019" in i else i.split("'")[0]) + result.append("have") + elif "\u2019d" in i.strip().lower() or "'d" in i.strip().lower(): + result.append(i.split("\u2019")[0] if "\u2019" in i else i.split("'")[0]) + result.append("had") + elif "\u2019m" in i.strip().lower() or "'m" in i.strip().lower(): + result.append(i.split("\u2019")[0] if "\u2019" in i else i.split("'")[0]) + result.append("am") + elif i.strip().lower() in fillers: + result.append("um") + elif "-" in i.strip().lower(): + result += [k.strip() for k in i.lower().split("-")] + elif "ok" == i.strip().lower(): + result.append("okay") + elif "gimme" == i.strip().lower(): + result.append("give") + result.append("me") + elif "hafta" == i.strip().lower() or "havta" == i.strip().lower(): + result.append("have") + result.append("to") + elif i.strip().lower() in names: + result.append("name") + elif "dunno" == i.strip().lower(): + result.append("don't") + result.append("know") + elif "wanna" == i.strip().lower(): + result.append("want") + result.append("to") + elif "gonna" == i.strip().lower(): + result.append("going") + result.append("to") + elif "gotta" == i.strip().lower(): + result.append("got") + result.append("to") + elif "kinda" == i.strip().lower(): + result.append("kind") + result.append("of") + elif "sorta" == i.strip().lower(): + result.append("sort") + result.append("of") + elif "alright" == i.strip().lower() or "alrightie" == i.strip().lower(): + result.append("all") + result.append("right") + elif "shoulda" == i.strip().lower(): + result.append("should") + result.append("have") + elif "sposta" == i.strip().lower(): + result.append("supposed") + result.append("to") + elif "hadta" == i.strip().lower(): + result.append("had") + result.append("to") + elif "til" == i.strip().lower(): + result.append("until") + elif "ed" == i.strip().lower(): + result.append("education") + elif "mm" == i.strip().lower() or "hmm" == i.strip().lower(): + result.append("hm") + elif "eh" == i.strip().lower(): + result.append("uh") + elif "em" == i.strip().lower(): + result.append("them") + elif "farmhouse" == i.strip().lower(): + result.append("farm") + result.append("house") + elif "this'll" == i.strip().lower(): + result.append("this") + result.append("will") + elif "i'd" == i.strip().lower(): + result.append("i") + result.append("had") + elif "mba" == i.strip().lower(): + result.append("m") + result.append("b") + result.append("a") + elif "tli" == i.strip().lower(): + result.append("t") + result.append("l") + result.append("i") + elif "bbc" == i.strip().lower(): + result.append("b") + result.append("b") + result.append("c") + elif "ai" == i.strip().lower(): + result.append("a") + result.append("i") + elif "ii" == i.strip().lower(): + result.append("i") + result.append("i") + elif "aa" == i.strip().lower(): + result.append("a") + result.append("a") + elif "_" in i.strip().lower(): + for j in i.strip().split("_"): + result.append(j) + else: + result.append(i.lower()) + + return result + +def match_fn(x, y): + x = x.lower() + y = y.lower() + return (y == x or + y.replace("(", "").replace(")", "") == x.replace("(", "").replace(")", "") or + re.sub(r"\((.*)\)", r"", y) == x or re.sub(r"\((.*)\)", r"", x) == y) + +# --- End of eval.py duplicates --- + + +def _get_pos(form): + """Extract uppercased POS from a Form's morphology, or '?' if absent.""" + if form is not None and form.morphology: + return form.morphology[0].pos.upper() + return "?" + + +def conform_with_mapping(words, conform_fn): + """Apply conform() per word, returning expanded tokens and an index mapping. + + Parameters + ---------- + words : list[str] + Original word list. + conform_fn : callable + The conform function. + + Returns + ------- + conformed : list[str] + The conformed (expanded) token list. + mapping : list[int] + mapping[j] = index into the original `words` list that conformed[j] + originated from. + """ + conformed = [] + mapping = [] + for idx, word in enumerate(words): + expanded = conform_fn([word]) + for token in expanded: + conformed.append(token) + mapping.append(idx) + return conformed, mapping + + +class CompareEngine(BatchalignEngine): + tasks = [Task.COMPARE] + + def process(self, doc, **kwargs): + gold = kwargs.get("gold") + if not gold or not isinstance(gold, Document): + raise ValueError( + f"CompareEngine requires a 'gold' Document kwarg, got '{type(gold)}'" + ) + + # --- 1. Extract words from main utterances --- + main_utterances = [ + u for u in doc.content if isinstance(u, Utterance) + ] + main_info = [] # (utt_idx, form_idx, Form) + main_words = [] + main_punct = {} # utt_idx -> list of (form_idx, Form) + + for utt_idx, utt in enumerate(main_utterances): + main_punct[utt_idx] = [] + for form_idx, form in enumerate(utt.content): + if form.text.strip() in MOR_PUNCT + ENDING_PUNCT: + main_punct[utt_idx].append((form_idx, form)) + continue + if form.text.strip().lower() in fillers: + continue + main_info.append((utt_idx, form_idx, form)) + main_words.append(form.text) + + # --- 2. Extract words from gold utterances --- + gold_utterances = [ + u for u in gold.content if isinstance(u, Utterance) + ] + gold_info = [] # (utt_idx, form_idx, Form) + gold_words = [] + + for utt_idx, utt in enumerate(gold_utterances): + for form_idx, form in enumerate(utt.content): + if form.text.strip() in MOR_PUNCT + ENDING_PUNCT: + continue + if form.text.strip().lower() in fillers: + continue + gold_info.append((utt_idx, form_idx, form)) + gold_words.append(form.text) + + # --- 3. Apply conform() with mapping --- + conformed_main, main_map = conform_with_mapping(main_words, conform) + conformed_gold, gold_map = conform_with_mapping(gold_words, conform) + + # --- 4. Align --- + alignment = align(conformed_main, conformed_gold, False, match_fn) + + # --- 5. Redistribute alignment results per main utterance --- + # Store (position, CompareToken) pairs so we can interleave punct + utt_positioned = {i: [] for i in range(len(main_utterances))} + current_main_utt = 0 + last_main_form_idx = -1 + main_cursor = 0 + gold_cursor = 0 + + for item in alignment: + if isinstance(item, Match): + orig_main_idx = main_map[main_cursor] + main_utt_idx = main_info[orig_main_idx][0] + main_form_idx = main_info[orig_main_idx][1] + main_form = main_info[orig_main_idx][2] + current_main_utt = main_utt_idx + last_main_form_idx = main_form_idx + + utt_positioned[main_utt_idx].append((main_form_idx, CompareToken( + text=item.key, + pos=_get_pos(main_form), + status="match" + ))) + main_cursor += 1 + gold_cursor += 1 + + elif isinstance(item, Extra): + if item.extra_type == ExtraType.PAYLOAD: + # Word in main but not in gold -> extra_main (+) + orig_main_idx = main_map[main_cursor] + main_utt_idx = main_info[orig_main_idx][0] + main_form_idx = main_info[orig_main_idx][1] + main_form = main_info[orig_main_idx][2] + current_main_utt = main_utt_idx + last_main_form_idx = main_form_idx + + utt_positioned[main_utt_idx].append((main_form_idx, CompareToken( + text=item.key, + pos=_get_pos(main_form), + status="extra_main" + ))) + main_cursor += 1 + + else: + # Word in gold but not in main -> extra_gold (-) + orig_gold_idx = gold_map[gold_cursor] + gold_form = gold_info[orig_gold_idx][2] + + # Position just after last main form for correct ordering + pos = last_main_form_idx + 0.5 + utt_positioned[current_main_utt].append((pos, CompareToken( + text=item.key, + pos=_get_pos(gold_form), + status="extra_gold" + ))) + gold_cursor += 1 + + # --- 6. Merge punctuation at original positions --- + for utt_idx in range(len(main_utterances)): + for form_idx, form in main_punct[utt_idx]: + utt_positioned[utt_idx].append((form_idx, CompareToken( + text=form.text, + pos="PUNCT", + status="match" + ))) + # Stable sort by position preserves order within same form_idx + utt_positioned[utt_idx].sort(key=lambda x: x[0]) + + # --- 7. Set comparison on each utterance --- + for utt_idx, utt in enumerate(main_utterances): + tokens = [tok for _, tok in utt_positioned[utt_idx]] + utt.comparison = tokens if tokens else None + + return doc + + +class CompareAnalysisEngine(BatchalignEngine): + tasks = [Task.COMPARE_ANALYSIS] + + def analyze(self, doc, **kwargs): + matches = 0 + extra_main = 0 + extra_gold = 0 + + for utt in doc.content: + if not isinstance(utt, Utterance) or utt.comparison is None: + continue + for tok in utt.comparison: + if tok.status == "match": + matches += 1 + elif tok.status == "extra_main": + extra_main += 1 + elif tok.status == "extra_gold": + extra_gold += 1 + + total_gold = matches + extra_gold + total_main = matches + extra_main + wer = (extra_main + extra_gold) / total_gold if total_gold > 0 else 0.0 + accuracy = 1.0 - wer + + metrics = { + "wer": round(wer, 4), + "accuracy": round(accuracy, 4), + "matches": matches, + "insertions": extra_main, + "deletions": extra_gold, + "total_gold_words": total_gold, + "total_main_words": total_main, + } + + return { + "doc": doc, + "metrics": metrics, + } diff --git a/batchalign/pipelines/dispatch.py b/batchalign/pipelines/dispatch.py index c8e9092..2c35de2 100644 --- a/batchalign/pipelines/dispatch.py +++ b/batchalign/pipelines/dispatch.py @@ -25,6 +25,8 @@ "coref": "stanza_coref", "translate": "gtrans", "opensmile": "opensmile_egemaps", + "compare": "compare_engine", + "compare_analysis": "compare_analysis_engine", } LANGUAGE_OVERRIDE_PACKAGES: dict = { @@ -169,6 +171,12 @@ def dispatch_pipeline(pkg_str, lang, num_speakers=None, **arg_overrides): elif engine == "opensmile_eGeMAPSv01b": from batchalign.pipelines.opensmile import OpenSMILEEngine engines.append(OpenSMILEEngine(feature_set='eGeMAPSv01b')) + elif engine == "compare_engine": + from batchalign.pipelines.analysis import CompareEngine + engines.append(CompareEngine()) + elif engine == "compare_analysis_engine": + from batchalign.pipelines.analysis import CompareAnalysisEngine + engines.append(CompareAnalysisEngine()) L.debug(f"Done initalizing packages.") diff --git a/batchalign/tests/cli/test_dispatch_memory.py b/batchalign/tests/cli/test_dispatch_memory.py index bbacf36..4129dbf 100644 --- a/batchalign/tests/cli/test_dispatch_memory.py +++ b/batchalign/tests/cli/test_dispatch_memory.py @@ -4,6 +4,9 @@ import pytest +# cli → cli.py → models.training.run → torch +pytest.importorskip("torch") + from batchalign.cli import dispatch as dispatch_module from batchalign import constants as constants_module diff --git a/batchalign/tests/formats/textgrid/test_textgrid.py b/batchalign/tests/formats/textgrid/test_textgrid.py index 715b743..29542b2 100644 --- a/batchalign/tests/formats/textgrid/test_textgrid.py +++ b/batchalign/tests/formats/textgrid/test_textgrid.py @@ -3,6 +3,8 @@ import pytest import pathlib +pytest.importorskip("praatio") + from batchalign.document import * from batchalign.formats.textgrid import TextGridFile diff --git a/batchalign/tests/models/test_audio_io.py b/batchalign/tests/models/test_audio_io.py index 3fada1b..75a3e1b 100644 --- a/batchalign/tests/models/test_audio_io.py +++ b/batchalign/tests/models/test_audio_io.py @@ -8,9 +8,10 @@ import tempfile from pathlib import Path -import numpy as np import pytest -import torch + +np = pytest.importorskip("numpy") +torch = pytest.importorskip("torch") from batchalign.models import audio_io diff --git a/batchalign/tests/models/test_audio_lazy.py b/batchalign/tests/models/test_audio_lazy.py index 85e73f6..0dd0354 100644 --- a/batchalign/tests/models/test_audio_lazy.py +++ b/batchalign/tests/models/test_audio_lazy.py @@ -1,4 +1,6 @@ -import torch +import pytest + +torch = pytest.importorskip("torch") from batchalign.models.utils import ASRAudioFile diff --git a/batchalign/tests/pipelines/analysis/test_eval.py b/batchalign/tests/pipelines/analysis/test_eval.py index 7e1603d..6f7506f 100644 --- a/batchalign/tests/pipelines/analysis/test_eval.py +++ b/batchalign/tests/pipelines/analysis/test_eval.py @@ -1,3 +1,8 @@ +import pytest + +# eval.py → asr.utils → num2words +pytest.importorskip("num2words") + from batchalign.pipelines.analysis.eval import EvaluationEngine from batchalign.document import * from copy import deepcopy diff --git a/batchalign/tests/pipelines/asr/test_asr_pipeline.py b/batchalign/tests/pipelines/asr/test_asr_pipeline.py index ba0520b..3b0b606 100644 --- a/batchalign/tests/pipelines/asr/test_asr_pipeline.py +++ b/batchalign/tests/pipelines/asr/test_asr_pipeline.py @@ -1,5 +1,8 @@ import pytest +# WhisperASRModel → numpy +pytest.importorskip("numpy") + from batchalign.pipelines import BatchalignPipeline from batchalign.models import WhisperASRModel diff --git a/batchalign/tests/pipelines/asr/test_asr_utils.py b/batchalign/tests/pipelines/asr/test_asr_utils.py index 97cecad..8d56d37 100644 --- a/batchalign/tests/pipelines/asr/test_asr_utils.py +++ b/batchalign/tests/pipelines/asr/test_asr_utils.py @@ -1,3 +1,7 @@ +import pytest + +pytest.importorskip("num2words") + from batchalign.pipelines.asr.utils import * from batchalign.document import * diff --git a/batchalign/tests/pipelines/cache/test_cache.py b/batchalign/tests/pipelines/cache/test_cache.py index b4a3674..4f4262d 100644 --- a/batchalign/tests/pipelines/cache/test_cache.py +++ b/batchalign/tests/pipelines/cache/test_cache.py @@ -15,6 +15,8 @@ import pytest +pytest.importorskip("filelock") + from batchalign.document import ( Utterance, Form, Morphology, Dependency, TokenType, Tier ) diff --git a/batchalign/tests/pipelines/fa/test_fa_pipeline.py b/batchalign/tests/pipelines/fa/test_fa_pipeline.py index 73c2dc8..4ab10c6 100644 --- a/batchalign/tests/pipelines/fa/test_fa_pipeline.py +++ b/batchalign/tests/pipelines/fa/test_fa_pipeline.py @@ -1,3 +1,7 @@ +import pytest + +pytest.importorskip("torch") + from batchalign.pipelines import BatchalignPipeline def test_whisper_fa_pipeline(en_doc): diff --git a/batchalign/tests/pipelines/fa/test_fa_short_segments.py b/batchalign/tests/pipelines/fa/test_fa_short_segments.py index 26de155..81534dd 100644 --- a/batchalign/tests/pipelines/fa/test_fa_short_segments.py +++ b/batchalign/tests/pipelines/fa/test_fa_short_segments.py @@ -19,6 +19,9 @@ import pytest import logging from unittest.mock import MagicMock, patch + +torch = pytest.importorskip("torch") + from batchalign.document import Document, Utterance, Form, Media, MediaType, TokenType