# Reconstructed from a whitespace-mangled `git diff` of complex_tokenization/graph.py
# (hunk @@ -81,39 +81,43 @@, post-image).  The enclosing class header is outside
# this view — presumably NodesSequence, since merge() rebuilds one (TODO confirm) —
# and the `oid` property appears only as truncated diff context, so both are
# omitted.  These two methods belong at class-body indentation when restored.

def get_merges(self):
    """Yield every candidate merge in this sequence.

    Recurses into each child first, then yields contiguous spans of the
    sequence itself, from length 2 up to GraphSettings.MAX_MERGE_SIZE.
    With GraphSettings.ONLY_MINIMAL_MERGES set, spans start only at plain
    ``Node`` children and stop growing once they would have to absorb a
    non-``Node`` — except at the right boundary of the sequence, where the
    span is still allowed to close.
    """
    nodes = self.nodes
    count = len(nodes)
    # Hoist the settings lookups out of the loops.
    minimal_only = GraphSettings.ONLY_MINIMAL_MERGES
    size_limit = GraphSettings.MAX_MERGE_SIZE

    for start, child in enumerate(nodes):
        yield from child.get_merges()
        if minimal_only and not isinstance(child, Node):
            continue
        for stop in range(start + 2, min(start + size_limit + 1, count + 1)):
            if minimal_only and stop < count and not isinstance(nodes[stop], Node):
                break
            if stop - start == 2:
                # Common case: build the pair directly instead of slicing.
                yield (nodes[start], nodes[stop - 1])
            else:
                yield tuple(nodes[start:stop])


def merge(self, token: Node, merge: tuple["GraphVertex", ...]):
    """Replace each non-overlapping occurrence of ``merge`` with ``token``.

    Scans left to right, substituting ``token`` for every matched span and
    copying every other child through; surviving children are then merged
    recursively.  Collapses to a single vertex when only one child remains.
    """
    width = len(merge)
    nodes = self.nodes
    total = len(nodes)
    out: list[GraphVertex] = []

    pos = 0
    while pos <= total - width:
        if nodes[pos:pos + width] == merge:
            out.append(token)  # reuse the token vertex itself
            pos += width       # skip past the matched span
        else:
            out.append(nodes[pos])
            pos += 1
    out.extend(nodes[pos:])  # tail shorter than the pattern

    if len(out) == 1:
        return out[0]
    return NodesSequence(tuple(child.merge(token, merge) for child in out))
@dataclass(frozen=True, slots=True)
class FullyConnectedGraph(GraphVertex):
    """A set of nodes where every pair is a valid merge candidate.

    Used for Hebrew diacritics: dagesh, nikkud, and cantillation marks
    on the same letter are interchangeable in merge order.
    """

    # Members of the clique; every ordered pair is offered as a merge.
    nodes: tuple[GraphVertex, ...]

    def __bytes__(self):
        """Concatenate the members' byte renderings in stored order."""
        return b"".join(bytes(member) for member in self.nodes)

    @property
    def oid(self) -> str:
        """Graphviz node id: delegate to the first member."""
        return self.nodes[0].oid

    def get_merges(self) -> Iterator[tuple]:
        """Yield the members' own merges, then every ordered pair of members."""
        members = self.nodes
        for member in members:
            yield from member.get_merges()
        # Both (a, b) and (b, a) are yielded: order is interchangeable here.
        for i, first in enumerate(members):
            for j, second in enumerate(members):
                if i != j:
                    yield (first, second)

    def merge(self, token: Node, merge: tuple):
        """Apply ``merge`` to this clique, fusing a matching pair into ``token``.

        For a two-element merge, the lowest-indexed member equal to the first
        element and a distinct member equal to the second are removed and
        ``token`` is appended; a clique reduced to a single member collapses
        to that member.  When no member pair matches (or the merge is wider),
        the merge is forwarded into the children instead, returning ``self``
        unchanged if nothing changed.
        """
        if len(merge) == 2:
            want_a, want_b = merge
            members = list(self.nodes)
            for i, candidate in enumerate(members):
                if candidate != want_a:
                    continue
                for j, other in enumerate(members):
                    if j == i or other != want_b:
                        continue
                    kept = [m for k, m in enumerate(members) if k != i and k != j]
                    kept.append(token)
                    if len(kept) == 1:
                        return kept[0]
                    return FullyConnectedGraph(nodes=tuple(kept))

        children = tuple(member.merge(token, merge) for member in self.nodes)
        if children == self.nodes:
            return self
        return FullyConnectedGraph(nodes=children)

    def dot(self, level=0) -> Iterable[str]:
        """Emit a Graphviz cluster with dashed bidirectional all-pairs edges."""
        # Alternate fill colors by nesting depth.
        fill = "#ffe0cc" if level % 2 else "#ffd0b0"
        yield f"subgraph cluster_{id(self)} {{"
        yield f'\tlabel="{dot_escape(str(self))}";'
        yield f'\tstyle=filled; color="{fill}";'
        yield '\tnode [style=filled, color=white];'
        yield '\tedge [arrowhead=none, style=dashed];'
        yield ''
        for member in self.nodes:
            yield from member.dot(level + 1)
        members = self.nodes
        for i, a in enumerate(members):
            for b in members[i + 1:]:
                yield f'\t{a.oid} -> {b.oid} [dir=both];'
        yield "}"


# Truncated diff context: this definition continues beyond the visible source.
@dataclass(frozen=True, slots=True)
class UnconnectedGraphs(GraphVertex):
    subgraphs: tuple[GraphVertex, ...]
"""Decompose Hebrew text into graph structure.

Each grapheme cluster becomes:
- A NodesSequence of [base_letter, diacritics_graph]
- Where diacritics_graph is a FullyConnectedGraph of all marks
- Single diacritics or bare letters collapse to plain Nodes
"""

# NOTE: the originating diff also adds an empty
# complex_tokenization/languages/hebrew/__init__.py package marker.

import unicodedata

import regex

from complex_tokenization.graph import FullyConnectedGraph, GraphVertex, NodesSequence
from complex_tokenization.graphs.units import utf8

# Hoisted: bind the compiled grapheme-cluster pattern once at import time
# instead of going through the regex-cache lookup on every call.
_GRAPHEME_CLUSTER = regex.compile(r"\X")


def is_hebrew_mark(char: str) -> bool:
    """Return True if ``char`` is a Unicode nonspacing mark (category Mn).

    NOTE(review): this matches nonspacing marks of *any* script, not only
    Hebrew; harmless for Hebrew-only input, but the name overpromises.
    """
    return unicodedata.category(char) == "Mn"


def decompose_cluster(cluster: str) -> GraphVertex:
    """Decompose a single grapheme cluster into a graph vertex.

    Partitions the cluster into base characters and nonspacing marks:

    - no marks: the base alone (plain vertex)
    - one mark: ``NodesSequence(base, mark)`` — no clique wrapper needed
    - several marks: ``NodesSequence(base, FullyConnectedGraph(marks))`` so
      the marks can merge in any order
    - marks with no base character: just the marks' graph
    """
    base_chars = []
    marks = []

    for char in cluster:
        if is_hebrew_mark(char):
            marks.append(char)
        else:
            base_chars.append(char)

    base_text = "".join(base_chars)
    base_node = utf8(base_text) if base_text else None

    if not marks:
        if base_node is None:
            # Only reachable for an empty cluster (regex \X never yields "");
            # kept as a defensive fallback for direct callers.
            return utf8(cluster)
        return base_node

    mark_nodes = [utf8(m) for m in marks]

    if len(mark_nodes) == 1:
        # A single diacritic collapses: no fully-connected wrapper.
        diacritics = mark_nodes[0]
    else:
        diacritics = FullyConnectedGraph(nodes=tuple(mark_nodes))

    if base_node is None:
        return diacritics

    return NodesSequence(nodes=(base_node, diacritics))


def hebrew_grapheme_clusters(text: str) -> GraphVertex:
    """Convert Hebrew text to a graph using grapheme cluster decomposition.

    NOTE(review): empty input yields ``NodesSequence(nodes=())`` — confirm
    callers never pass "" or add a guard here.
    """
    clusters = _GRAPHEME_CLUSTER.findall(text)
    nodes = [decompose_cluster(c) for c in clusters]

    if len(nodes) == 1:
        return nodes[0]
    return NodesSequence(nodes=tuple(nodes))
import unicodedata

from complex_tokenization.graph import FullyConnectedGraph, Node, NodesSequence
from complex_tokenization.languages.hebrew.decompose import (
    decompose_cluster,
    hebrew_grapheme_clusters,
)


class TestDecomposeCluster:
    """Cluster-level decomposition: one base letter plus its marks."""

    def test_bare_letter(self):
        vertex = decompose_cluster("א")
        assert isinstance(vertex, (NodesSequence, Node))
        assert bytes(vertex) == "א".encode()

    def test_letter_with_one_mark(self):
        cluster = "בָ"  # bet + qamats
        vertex = decompose_cluster(cluster)
        assert isinstance(vertex, NodesSequence)
        assert bytes(vertex) == cluster.encode()

    def test_letter_with_dagesh_and_vowel(self):
        cluster = "בְּ"  # bet + sheva + dagesh
        vertex = decompose_cluster(cluster)
        assert isinstance(vertex, NodesSequence)
        assert bytes(vertex) == cluster.encode()
        base, diacritics = vertex.nodes
        assert isinstance(diacritics, FullyConnectedGraph)
        assert len(diacritics.nodes) == 2

    def test_letter_with_three_marks(self):
        cluster = "שִׁ֖"  # shin + hiriq + shin dot + tipeha
        vertex = decompose_cluster(cluster)
        assert isinstance(vertex, NodesSequence)
        assert bytes(vertex) == cluster.encode()
        base, diacritics = vertex.nodes
        assert isinstance(diacritics, FullyConnectedGraph)
        assert len(diacritics.nodes) == 3

    def test_single_mark_collapses(self):
        cluster = "בָ"  # bet + qamats (one mark)
        base, mark = decompose_cluster(cluster).nodes
        # A lone diacritic must not be wrapped in a clique.
        assert not isinstance(mark, FullyConnectedGraph)

    def test_bytes_roundtrip(self):
        word = "בְּרֵאשִׁ֖ית"
        graph = hebrew_grapheme_clusters(word)
        assert bytes(graph) == word.encode()


class TestHebrewGraphemeClusters:
    """Word-level conversion via grapheme clusters."""

    def test_simple_word(self):
        graph = hebrew_grapheme_clusters("שלום")
        assert isinstance(graph, NodesSequence)
        assert bytes(graph) == "שלום".encode()

    def test_word_with_nikkud(self):
        word = "שָׁלוֹם"  # shalom with nikkud
        graph = hebrew_grapheme_clusters(word)
        assert isinstance(graph, NodesSequence)
        assert bytes(graph) == word.encode()

    def test_bereshit_structure(self):
        word = "בְּרֵאשִׁ֖ית"
        assert isinstance(hebrew_grapheme_clusters(word), NodesSequence)

    def test_mark_categories(self):
        """Verify that we correctly identify all Hebrew mark types."""
        marks = "ְִֵּׁ֖"
        for ch in marks:
            cat = unicodedata.category(ch)
            assert cat == "Mn", f"{ch!r} (U+{ord(ch):04X}) is {cat}, not Mn"

    def test_fully_connected_merges(self):
        cluster = "בְּ"  # bet + sheva + dagesh
        vertex = decompose_cluster(cluster)
        rendered = [b"".join(bytes(node) for node in m) for m in vertex.get_merges()]
        sheva = "ְ".encode()
        dagesh = "ּ".encode()
        # Both orderings of the pair must be offered by the clique.
        assert sheva + dagesh in rendered
        assert dagesh + sheva in rendered