Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions batchalign/cli/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@click.command()
@click.argument("command", type=click.Choice(["align", "transcribe", "transcribe_s", "morphotag", "translate", "utseg", "benchmark", "opensmile", "coref"]))
@click.argument("command", type=click.Choice(["align", "transcribe", "transcribe_s", "morphotag", "translate", "utseg", "benchmark", "opensmile", "coref", "compare"]))
@click.argument("in_dir", type=click.Path(exists=True, file_okay=False))
@click.argument("out_dir", type=click.Path(exists=True, file_okay=False))
@click.option("--runs", type=int, default=1, show_default=True, help="Number of benchmark runs.")
Expand All @@ -33,7 +33,7 @@ def bench(ctx, command, in_dir, out_dir, runs, no_pool, no_lazy_audio, no_adapti
if workers is not None:
run_ctx.obj["workers"] = workers
start = time.time()
if command in ["align", "morphotag", "translate", "utseg", "coref"]:
if command in ["align", "morphotag", "translate", "utseg", "coref", "compare"]:
extensions = ["cha"]
elif command in ["transcribe", "transcribe_s", "benchmark", "opensmile"]:
extensions = ["wav", "mp3", "mp4"]
Expand Down
24 changes: 24 additions & 0 deletions batchalign/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,30 @@ def writer(doc, output):
**kwargs)


#################### COMPARE ################################

@batchalign.command()
@common_options
@click.option("--lang",
              # "eng" is a language code (ISO 639-3), not an ISO 3166-1
              # country code, so the help text names the correct standard.
              help="sample language as a three-letter ISO 639-3 code",
              show_default=True,
              default="eng",
              type=str)
@click.option("--merge-abbrev/--no-merge-abbrev",
              default=False, help="Merge abbreviations in output. Default: no.")
@click.pass_context
def compare(ctx, in_dir, out_dir, lang, **kwargs):
    """Compare transcripts against gold-standard references.

    For each FILE.cha in IN_DIR, expects a companion FILE.gold.cha in the
    same directory. Runs morphosyntax analysis on the main transcript, then
    produces a word-level diff stored as %%xsrep / %%xsmor tiers and writes
    error metrics to a .compare.csv file in OUT_DIR.
    """
    # Positional arguments to _dispatch, in order: command name, language,
    # num_speakers (fixed at 1 here) and the input extensions to collect.
    # The two None values are presumably the loader/writer callbacks, which
    # the worker handles itself for "compare" -- TODO confirm against
    # _dispatch's full signature. `C` is a module-level object defined
    # elsewhere in this file.
    _dispatch("compare", lang, 1, ["cha"], ctx,
              in_dir, out_dir, None, None, C, **kwargs)

#################### AVQI ################################

@batchalign.command()
Expand Down
41 changes: 40 additions & 1 deletion batchalign/cli/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"seamless_translate",
"opensmile_egemaps",
"opensmile_gemaps",
"compare_engine",
"opensmile_compare",
"opensmile_eGeMAPSv01b",
}
Expand All @@ -68,6 +69,7 @@
"gtrans",
"replacement",
"ngram",
"compare_analysis_engine",
}

warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
Expand Down Expand Up @@ -196,6 +198,41 @@ def progress_callback(completed, total, tasks):
CHATFile(doc=doc["doc"]).write(str(P(output).with_suffix(".asr.cha")),
write_wor=local_kwargs.get("wor", False))

elif command == "compare":
from pathlib import Path as P
# Skip gold files that dispatch picked up
if file.endswith(".gold.cha"):
return

# Find companion gold file
p = P(file)
gold_path = p.parent / (p.stem + ".gold.cha")
if not gold_path.exists():
raise FileNotFoundError(
f"No gold .cha file found for comparison. "
f"main: {p.name}, expected: {gold_path.name}, looked in: {str(gold_path)}"
)

main_doc = CHATFile(path=str(p)).doc
gold_doc = CHATFile(path=str(gold_path), special_mor_=True).doc

# Pipeline: morphosyntax(main) -> compare -> compare_analysis
result = pipeline(main_doc, callback=progress_callback, gold=gold_doc)

# Write annotated CHAT
CHATFile(doc=result["doc"]).write(output,
merge_abbrev=local_kwargs.get("merge_abbrev", False))

# Write metrics CSV
import csv
metrics = result["metrics"]
csv_path = P(output).with_suffix(".compare.csv")
with open(csv_path, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(["metric", "value"])
for k, v in metrics.items():
writer.writerow([k, v])

elif command == "opensmile":
from batchalign.document import Document
doc = Document.new(media_path=file, lang=local_kwargs.get("lang", kwargs.get("lang", "eng")))
Expand Down Expand Up @@ -347,6 +384,7 @@ def _safe_peak_rss():
"coref": "coref",
"translate": "translate",
"opensmile": "opensmile",
"compare": "morphosyntax,compare,compare_analysis",
}

# this is the main runner used by all functions
Expand All @@ -366,6 +404,7 @@ def _dispatch(command, lang, num_speakers,
"coref",
"benchmark",
"opensmile",
"compare",
}
if command in worker_handled:
# Avoid pickling CLI-local loader/writer functions when the worker
Expand Down Expand Up @@ -478,7 +517,7 @@ def _dispatch(command, lang, num_speakers,
memory_history_path.parent.mkdir(parents=True, exist_ok=True)

# Pre-download stanza resources if needed to avoid interleaved downloads in workers
if command in ["morphotag", "utseg", "coref"]:
if command in ["morphotag", "utseg", "coref", "compare"]:
try:
import stanza
stanza.download_resources_json()
Expand Down
12 changes: 12 additions & 0 deletions batchalign/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class Task(IntEnum):
COREF = 12
WER = 13
TRANSLATE = 14
COMPARE = 15
COMPARE_ANALYSIS = 16


DEBUG__G = 0
Expand All @@ -57,6 +59,8 @@ class TaskType(IntEnum):
Task.COREF: TaskType.PROCESSING,
Task.WER: TaskType.ANALYSIS,
Task.TRANSLATE: TaskType.PROCESSING,
Task.COMPARE: TaskType.PROCESSING,
Task.COMPARE_ANALYSIS: TaskType.ANALYSIS,

Task.DEBUG__G: TaskType.GENERATION,
Task.DEBUG__P: TaskType.PROCESSING,
Expand All @@ -77,6 +81,8 @@ class TaskType(IntEnum):
Task.COREF: "Coreference Resolution",
Task.WER: "Word Error Rate",
Task.TRANSLATE: "Translation",
Task.COMPARE: "Transcript Comparison",
Task.COMPARE_ANALYSIS: "Comparison Analysis",
Task.DEBUG__G: "TEST_GENERATION",
Task.DEBUG__P: "TEST_PROCESSING",
Task.DEBUG__A: "TEST_ANALYSIS",
Expand All @@ -100,6 +106,11 @@ class CustomLine(BaseModel):
type: CustomLineType # % or @
content: Optional[str] = Field(default=None) # the contents of the line

class CompareToken(BaseModel):
    # One token of a main-vs-gold word-level comparison.

    # The word itself, in its conformed/expanded form.
    text: str
    # Uppercased POS tag for the word; None when no tag is available.
    pos: Optional[str] = None
    # Comparison outcome: "match" | "extra_main" | "extra_gold".
    status: str = "match"

class Dependency(BaseModel):
id: int # first number, 1 indexed
dep_id: int # second number (where the arrow points to)
Expand Down Expand Up @@ -158,6 +169,7 @@ class Utterance(BaseModel):
translation: Optional[str] = Field(default=None)
time: Optional[Tuple[int,int]] = Field(default=None)
custom_dependencies: List[CustomLine] = Field(default=[])
comparison: Optional[List[CompareToken]] = Field(default=None)

@property
def delim(self) -> str:
Expand Down
11 changes: 11 additions & 0 deletions batchalign/formats/chat/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,17 @@ def generate_chat_utterance(utterance: Utterance, special_mor=False, write_wor=T
if special.content:
result.append(f"%{special.id}:\t"+special.content)

#### COMPARISON LINE GENERATION ####
if utterance.comparison is not None:
xsrep_parts = []
xsmor_parts = []
for tok in utterance.comparison:
prefix = "+" if tok.status == "extra_main" else ("-" if tok.status == "extra_gold" else "")
xsrep_parts.append(f"{prefix}{tok.text}")
xsmor_parts.append(f"{prefix}{tok.pos or '?'}")
result.append(f"%xsrep:\t" + " ".join(xsrep_parts))
result.append(f"%xsmor:\t" + " ".join(xsmor_parts))

return "\n".join(result)

def check_utterances_ordered(doc):
Expand Down
32 changes: 31 additions & 1 deletion batchalign/formats/chat/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from batchalign.document import (
Document, Utterance, Form, Tier, Media, MediaType,
CustomLine, CustomLineType, Morphology, Dependency, ENDING_PUNCT
CustomLine, CustomLineType, Morphology, Dependency, CompareToken, ENDING_PUNCT
)
from batchalign.utils import *
from batchalign.errors import CHATValidationException
Expand All @@ -14,6 +14,27 @@

import re

def _parse_comparison(xsrep_str, xsmor_str=None):
    """Parse %xsrep and %xsmor tier contents into a list of CompareToken.

    Each whitespace-separated word of ``xsrep_str`` becomes one token. A
    leading "+" marks the word as ``extra_main`` and a leading "-" marks it
    as ``extra_gold``; anything else is a ``match``. The marker is removed
    from the word, and from the POS tag at the same position if the tag
    carries the same marker.

    Parameters
    ----------
    xsrep_str : str
        Content of the %xsrep tier (the words).
    xsmor_str : Optional[str]
        Content of the %xsmor tier (the POS tags), aligned by position with
        ``xsrep_str``. May be None or empty.

    Returns
    -------
    List[CompareToken]
        One token per word in ``xsrep_str``, in order.
    """
    marker_status = {"+": "extra_main", "-": "extra_gold"}

    words = xsrep_str.split()
    tags = xsmor_str.split() if xsmor_str else []

    tokens = []
    for idx, word in enumerate(words):
        # Index the tags rather than zip() them: a short or missing %xsmor
        # tier must not silently drop trailing words. Surplus tags with no
        # matching word are ignored.
        pos = tags[idx] if idx < len(tags) else None

        marker = word[:1]
        status = marker_status.get(marker, "match")
        if status != "match":
            word = word[1:]
            if pos and pos.startswith(marker):
                pos = pos[1:]

        # The CHAT generator emits "?" for a token without a POS tag; map it
        # (and a tag that was only a marker) back to None so that a missing
        # POS survives a write/parse round trip.
        if not pos or pos == "?":
            pos = None

        tokens.append(CompareToken(text=word, pos=pos, status=status))
    return tokens

def chat_parse_utterance(text, mor, gra, wor, additional):
"""Encode a CHAT utterance into a Batchalign utterance.

Expand Down Expand Up @@ -299,6 +320,8 @@ def chat_parse_doc(lines, special_mor=False):
gra = None
wor = None
translation = None
xsrep_line = None
xsmor_line = None
additional = []

while raw[0][0] == "%":
Expand All @@ -312,6 +335,10 @@ def chat_parse_doc(lines, special_mor=False):
wor = line
elif beg.strip() == "xtra":
translation = line
elif beg.strip() == "xsrep":
xsrep_line = line
elif beg.strip() == "xsmor":
xsmor_line = line
else:
additional.append(CustomLine(id=beg.strip(),
type=CustomLineType.DEPENDENT,
Expand All @@ -336,6 +363,9 @@ def chat_parse_doc(lines, special_mor=False):
"override_lang": None if len(multilingual) == 0 else multilingual[0]
})

if xsrep_line:
ut.comparison = _parse_comparison(xsrep_line, xsmor_line)

timing = re.findall(rf"\x15(\d+)_(\d+)\x15", text)
if len(timing) != 0:
x,y = timing[0]
Expand Down
8 changes: 7 additions & 1 deletion batchalign/pipelines/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ def __getattr__(name):
if name == 'EvaluationEngine':
from .eval import EvaluationEngine
return EvaluationEngine
if name == 'CompareEngine':
from .compare import CompareEngine
return CompareEngine
if name == 'CompareAnalysisEngine':
from .compare import CompareAnalysisEngine
return CompareAnalysisEngine
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

__all__ = ['EvaluationEngine']
__all__ = ['EvaluationEngine', 'CompareEngine', 'CompareAnalysisEngine']
Loading
Loading