from collections.abc import Callable

import regex

from complex_tokenization.graph import GraphVertex, Node, NodesSequence

# Registry of per-script cluster handlers, keyed by Unicode script property
# name (e.g. "Han"). Module-level mutable state: tests reset it between runs.
_cluster_handlers: dict[str, Callable[[str], GraphVertex]] = {}


def register_script(script: str, handler: Callable[[str], GraphVertex]) -> None:
    """Register a handler for grapheme clusters matching a Unicode script.

    The script name must be a valid Unicode script property (e.g. "Han",
    "Hebrew"). When utf8_clusters processes a cluster whose first character
    matches the script, the handler is called instead of the default utf8.

    Args:
        script: Unicode script property name.
        handler: Callable mapping one grapheme cluster to a graph vertex.
    """
    _cluster_handlers[script] = handler


def _get_handler(cluster: str) -> Callable[[str], GraphVertex] | None:
    """Return the handler registered for this cluster's script, if any.

    Only the first character of the cluster is tested against each
    registered script's ``\\p{Script}`` property; handlers are tried in
    registration order and the first match wins. Returns ``None`` when no
    registered script matches, letting callers fall back to plain utf8.
    """
    if not _cluster_handlers:
        return None
    first_char = cluster[0]
    for script, handler in _cluster_handlers.items():
        # The regex module caches compiled patterns, so rebuilding the
        # pattern string per call does not recompile it each time.
        if regex.match(rf'\p{{{script}}}', first_char):
            return handler
    return None
"""Convert Chinese characters into graph structures using IDS decomposition."""

from complex_tokenization.graph import GraphVertex, Tree
from complex_tokenization.graphs.units import utf8
from complex_tokenization.languages.chinese.ideographic_description_sequences import (
    IDSNode,
    get_ids_for_character,
    parse_ideographic_description_sequences,
)


def ids_node_to_graph(node: IDSNode) -> GraphVertex:
    """Recursively translate an IDS parse tree into a graph vertex."""
    root = utf8(node.value)
    if node.is_leaf():
        return root
    subgraphs = tuple(ids_node_to_graph(child) for child in node.children)
    return Tree(root=root, children=subgraphs)


def chinese_character_to_graph(cluster: str) -> GraphVertex:
    """Convert a Chinese character cluster to a graph, decomposing via IDS if possible."""
    # Only single code points have IDS entries; multi-char clusters fall through.
    if len(cluster) != 1:
        return utf8(cluster)
    ids = get_ids_for_character(cluster)
    if ids is None:
        return utf8(cluster)
    try:
        # Keep graph construction inside the try: a ValueError anywhere in
        # the decomposition path falls back to the undecomposed form.
        return ids_node_to_graph(parse_ideographic_description_sequences(ids))
    except ValueError:
        return utf8(cluster)
@pytest.fixture(autouse=True)
def clear_script_registry():
    """Isolate tests from each other's register_script() calls.

    The script-handler registry in complex_tokenization.graphs.units is
    module-level mutable state; clear it before and after every test so a
    handler registered in one test can never leak into another.
    """
    _cluster_handlers.clear()
    try:
        yield
    finally:
        # finally guarantees cleanup even if an exception is thrown into
        # the fixture generator during teardown.
        _cluster_handlers.clear()
class TestChineseGraph:
    """Unit tests for chinese_character_to_graph and registry dispatch."""

    def test_decomposable_character(self):
        result = chinese_character_to_graph("林")
        assert isinstance(result, Tree)
        assert bytes(result.root) == "⿰".encode()

    def test_non_decomposable_character(self):
        result = chinese_character_to_graph("a")
        assert isinstance(result, Node)

    def test_chinese_text_via_registry(self):
        register_script("Han", chinese_character_to_graph)
        result = utf8_clusters("林木")
        assert isinstance(result, NodesSequence)
        assert bytes(result) == "⿰木木木".encode()

    def test_mixed_text_via_registry(self):
        register_script("Han", chinese_character_to_graph)
        result = utf8_clusters("hello")
        assert bytes(result) == b"hello"


class TestChineseTraining:
    """Training smoke tests over IDS-decomposed Chinese graphs."""

    def test_train_on_repeated_characters(self):
        register_script("Han", chinese_character_to_graph)
        GraphSettings.ONLY_MINIMAL_MERGES = True
        GraphSettings.MAX_MERGE_SIZE = 2

        corpus = ["林森木本末朱机杏"] * 3
        trainer = Trainer(graphs=tuple(utf8_clusters(text) for text in corpus))
        trainer.train(num_merges=5)

        # NOTE(review): sibling tests read trainer.merges directly; confirm
        # get_merges() and .merges expose the same data.
        assert len(trainer.get_merges()) > 0

    def test_train_on_mixed_chinese_text(self):
        register_script("Han", chinese_character_to_graph)
        GraphSettings.ONLY_MINIMAL_MERGES = True
        GraphSettings.MAX_MERGE_SIZE = 2

        corpus = ["你好世界 hello 你好"]
        trainer = Trainer(graphs=tuple(utf8_clusters(text) for text in corpus))
        trainer.train(num_merges=3)
        assert len(trainer.merges) <= 3

    def test_common_radicals_merge_early(self):
        register_script("Han", chinese_character_to_graph)
        GraphSettings.ONLY_MINIMAL_MERGES = True
        GraphSettings.MAX_MERGE_SIZE = 2

        corpus = ["林森林森林森"] * 5
        trainer = Trainer(graphs=tuple(utf8_clusters(text) for text in corpus))
        trainer.train(num_merges=20)

        wood = "木".encode()
        found = False
        for _, merge_nodes in trainer.merges:
            joined = b"".join(bytes(node) for node in merge_nodes)
            if wood in joined:
                found = True
                break
        assert found, (
            "Expected '木' in merge bytes within 20 merges"
        )
a/tests/languages/test_hebrew.py b/tests/languages/test_hebrew.py index d885557..de08197 100644 --- a/tests/languages/test_hebrew.py +++ b/tests/languages/test_hebrew.py @@ -1,16 +1,14 @@ import unicodedata from complex_tokenization.graph import FullyConnectedGraph, Node, NodesSequence -from complex_tokenization.languages.hebrew.decompose import ( - decompose_cluster, - hebrew_grapheme_clusters, -) +from complex_tokenization.graphs.units import register_script, utf8_clusters +from complex_tokenization.languages.hebrew.decompose import decompose_cluster class TestDecomposeCluster: def test_bare_letter(self): result = decompose_cluster("א") - assert isinstance(result, NodesSequence) or isinstance(result, Node) + assert isinstance(result, (NodesSequence, Node)) assert bytes(result) == "א".encode() def test_letter_with_one_mark(self): @@ -43,42 +41,39 @@ def test_single_mark_collapses(self): base, mark = result.nodes assert not isinstance(mark, FullyConnectedGraph) + def test_fully_connected_merges(self): + cluster = "בְּ" # bet + sheva + dagesh + result = decompose_cluster(cluster) + merges = list(result.get_merges()) + merge_bytes = [b"".join(bytes(n) for n in m) for m in merges] + sheva = "ְ".encode() + dagesh = "ּ".encode() + assert sheva + dagesh in merge_bytes + assert dagesh + sheva in merge_bytes + + +class TestHebrewViaRegistry: def test_bytes_roundtrip(self): + register_script("Hebrew", decompose_cluster) word = "בְּרֵאשִׁ֖ית" - graph = hebrew_grapheme_clusters(word) + graph = utf8_clusters(word) assert bytes(graph) == word.encode() - -class TestHebrewGraphemeClusters: def test_simple_word(self): - result = hebrew_grapheme_clusters("שלום") + register_script("Hebrew", decompose_cluster) + result = utf8_clusters("שלום") assert isinstance(result, NodesSequence) assert bytes(result) == "שלום".encode() def test_word_with_nikkud(self): - word = "שָׁלוֹם" # shalom with nikkud - result = hebrew_grapheme_clusters(word) + register_script("Hebrew", decompose_cluster) + 
word = "שָׁלוֹם" + result = utf8_clusters(word) assert isinstance(result, NodesSequence) assert bytes(result) == word.encode() - def test_bereshit_structure(self): - word = "בְּרֵאשִׁ֖ית" - result = hebrew_grapheme_clusters(word) - assert isinstance(result, NodesSequence) - def test_mark_categories(self): - """Verify that we correctly identify all Hebrew mark types.""" - marks = "ְִֵּׁ֖" + marks = "ְִֵּׁ֖" for ch in marks: cat = unicodedata.category(ch) assert cat == "Mn", f"{ch!r} (U+{ord(ch):04X}) is {cat}, not Mn" - - def test_fully_connected_merges(self): - cluster = "בְּ" # bet + sheva + dagesh - result = decompose_cluster(cluster) - merges = list(result.get_merges()) - merge_bytes = [b"".join(bytes(n) for n in m) for m in merges] - sheva = "ְ".encode() - dagesh = "ּ".encode() - assert sheva + dagesh in merge_bytes - assert dagesh + sheva in merge_bytes