Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions complex_tokenization/graphs/units.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
from collections.abc import Callable

import regex

from complex_tokenization.graph import GraphVertex, Node, NodesSequence

_cluster_handlers: dict[str, Callable[[str], GraphVertex]] = {}


def register_script(script: str, handler: Callable[[str], GraphVertex]) -> None:
    """Register a handler for grapheme clusters matching a Unicode script.

    The script name must be a valid Unicode script property (e.g. "Han",
    "Hebrew"). When utf8_clusters processes a cluster whose first character
    matches the script, the handler is called instead of the default utf8.

    Registering the same script name again replaces the previous handler.
    """
    # Stored in the module-level registry that _get_handler consults.
    _cluster_handlers[script] = handler


def _get_handler(cluster: str) -> Callable[[str], GraphVertex] | None:
    """Return the first registered handler whose script matches the
    cluster's leading character, or None when nothing matches.

    Handlers are tried in registration (dict insertion) order.
    """
    if not _cluster_handlers:
        return None
    leading = cluster[0]
    matching = (
        handler
        for script, handler in _cluster_handlers.items()
        if regex.match(rf'\p{{{script}}}', leading)
    )
    return next(matching, None)


def characters(s: str) -> GraphVertex:
nodes = [Node(c.encode("utf-8")) for c in s]
Expand All @@ -20,10 +44,14 @@ def utf8(s: str) -> GraphVertex:


def utf8_clusters(s: str) -> GraphVertex:
# Split string into grapheme clusters using regex
# \X matches extended grapheme clusters
clusters = regex.findall(r'\X', s)
nodes = [utf8(cluster) for cluster in clusters]
nodes = []
for cluster in clusters:
handler = _get_handler(cluster)
if handler is not None:
nodes.append(handler(cluster))
else:
nodes.append(utf8(cluster))

if len(nodes) == 1:
return nodes[0]
Expand Down
31 changes: 31 additions & 0 deletions complex_tokenization/languages/chinese/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Convert Chinese characters into graph structures using IDS decomposition."""

from complex_tokenization.graph import GraphVertex, Tree
from complex_tokenization.graphs.units import utf8
from complex_tokenization.languages.chinese.ideographic_description_sequences import (
IDSNode,
get_ids_for_character,
parse_ideographic_description_sequences,
)


def ids_node_to_graph(node: IDSNode) -> GraphVertex:
    """Recursively convert an IDS parse tree into a graph vertex.

    A leaf becomes a plain utf8 vertex; an interior node (an IDS operator
    such as ⿰) becomes a Tree rooted at the operator with the converted
    sub-components as children.
    """
    vertex = utf8(node.value)
    if node.is_leaf():
        return vertex
    subgraphs = tuple(map(ids_node_to_graph, node.children))
    return Tree(root=vertex, children=subgraphs)


def chinese_character_to_graph(cluster: str) -> GraphVertex:
    """Convert a Chinese character cluster to a graph, decomposing via IDS if possible.

    Single-character clusters with a known Ideographic Description Sequence
    are decomposed into a component tree; multi-codepoint clusters and
    characters without IDS data fall back to a plain utf8 vertex.
    """
    if len(cluster) == 1:
        ids = get_ids_for_character(cluster)
        if ids is not None:
            # Keep the try minimal: only parsing may legitimately fail on a
            # malformed IDS entry. The original wrapped ids_node_to_graph
            # too, which would silently swallow a ValueError raised while
            # converting a *valid* parse tree — a real bug should surface.
            try:
                tree = parse_ideographic_description_sequences(ids)
            except ValueError:
                pass  # malformed IDS data — fall through to raw utf8
            else:
                return ids_node_to_graph(tree)
    return utf8(cluster)
14 changes: 1 addition & 13 deletions complex_tokenization/languages/hebrew/decompose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Decompose Hebrew text into graph structure.
"""Decompose Hebrew grapheme clusters into graph structure.

Each grapheme cluster becomes:
- A NodesSequence of [base_letter, diacritics_graph]
Expand All @@ -8,8 +8,6 @@

import unicodedata

import regex

from complex_tokenization.graph import FullyConnectedGraph, GraphVertex, NodesSequence
from complex_tokenization.graphs.units import utf8

Expand Down Expand Up @@ -48,13 +46,3 @@ def decompose_cluster(cluster: str) -> GraphVertex:
return diacritics

return NodesSequence(nodes=(base_node, diacritics))


def hebrew_grapheme_clusters(text: str) -> GraphVertex:
"""Convert Hebrew text to a graph using grapheme cluster decomposition."""
clusters = regex.findall(r'\X', text)
nodes = [decompose_cluster(c) for c in clusters]

if len(nodes) == 1:
return nodes[0]
return NodesSequence(nodes=tuple(nodes))
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from complex_tokenization.graphs.settings import GraphSettings
from complex_tokenization.graphs.units import _cluster_handlers


@pytest.fixture(autouse=True)
Expand All @@ -12,3 +13,10 @@ def reset_graph_settings():
yield
GraphSettings.MAX_MERGE_SIZE = original["MAX_MERGE_SIZE"]
GraphSettings.ONLY_MINIMAL_MERGES = original["ONLY_MINIMAL_MERGES"]


@pytest.fixture(autouse=True)
def clear_script_registry():
    """Keep the global script-handler registry empty around every test.

    Cleared on both sides of the yield so a test never inherits handlers
    registered by a previous test, regardless of how that test ended.
    """
    _cluster_handlers.clear()
    try:
        yield
    finally:
        _cluster_handlers.clear()
73 changes: 73 additions & 0 deletions tests/languages/test_chinese_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Test training a tokenizer on Chinese text with IDS decomposition."""

from complex_tokenization.graph import Node, NodesSequence, Tree
from complex_tokenization.graphs.settings import GraphSettings
from complex_tokenization.graphs.units import register_script, utf8_clusters
from complex_tokenization.languages.chinese.graph import chinese_character_to_graph
from complex_tokenization.trainer import Trainer


class TestChineseGraph:
    """Unit tests for IDS-based decomposition of Chinese characters."""

    def test_decomposable_character(self):
        # 林 has a known IDS (⿰木木), so it must decompose into a Tree
        # whose root is the IDS operator.
        result = chinese_character_to_graph("林")
        assert isinstance(result, Tree)
        assert bytes(result.root) == "⿰".encode()

    def test_non_decomposable_character(self):
        # A Latin letter has no IDS entry and stays an atomic Node.
        result = chinese_character_to_graph("a")
        assert isinstance(result, Node)

    def test_chinese_text_via_registry(self):
        register_script("Han", chinese_character_to_graph)
        result = utf8_clusters("林木")
        assert isinstance(result, NodesSequence)
        assert bytes(result) == "⿰木木木".encode()

    def test_mixed_text_via_registry(self):
        # Non-Han text must pass through untouched by the Han handler.
        register_script("Han", chinese_character_to_graph)
        result = utf8_clusters("hello")
        assert bytes(result) == b"hello"


class TestChineseTraining:
    """End-to-end trainer runs over IDS-decomposed Chinese graphs."""

    def _build_trainer(self, texts):
        # Shared setup: register the Han handler, pin the merge settings,
        # and build one graph per input text.
        register_script("Han", chinese_character_to_graph)
        GraphSettings.ONLY_MINIMAL_MERGES = True
        GraphSettings.MAX_MERGE_SIZE = 2
        return Trainer(graphs=tuple(utf8_clusters(t) for t in texts))

    def test_train_on_repeated_characters(self):
        trainer = self._build_trainer(["林森木本末朱机杏"] * 3)
        trainer.train(num_merges=5)
        assert len(trainer.get_merges()) > 0

    def test_train_on_mixed_chinese_text(self):
        trainer = self._build_trainer(["你好世界 hello 你好"])
        trainer.train(num_merges=3)
        assert len(trainer.merges) <= 3

    def test_common_radicals_merge_early(self):
        trainer = self._build_trainer(["林森林森林森"] * 5)
        trainer.train(num_merges=20)

        learned = [
            b"".join(bytes(node) for node in pair)
            for _, pair in trainer.merges
        ]
        wood = "木".encode()
        assert any(wood in entry for entry in learned), (
            "Expected '木' in merge bytes within 20 merges"
        )
51 changes: 23 additions & 28 deletions tests/languages/test_hebrew.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import unicodedata

from complex_tokenization.graph import FullyConnectedGraph, Node, NodesSequence
from complex_tokenization.languages.hebrew.decompose import (
decompose_cluster,
hebrew_grapheme_clusters,
)
from complex_tokenization.graphs.units import register_script, utf8_clusters
from complex_tokenization.languages.hebrew.decompose import decompose_cluster


class TestDecomposeCluster:
def test_bare_letter(self):
result = decompose_cluster("א")
assert isinstance(result, NodesSequence) or isinstance(result, Node)
assert isinstance(result, (NodesSequence, Node))
assert bytes(result) == "א".encode()

def test_letter_with_one_mark(self):
Expand Down Expand Up @@ -43,42 +41,39 @@ def test_single_mark_collapses(self):
base, mark = result.nodes
assert not isinstance(mark, FullyConnectedGraph)

def test_fully_connected_merges(self):
cluster = "בְּ" # bet + sheva + dagesh
result = decompose_cluster(cluster)
merges = list(result.get_merges())
merge_bytes = [b"".join(bytes(n) for n in m) for m in merges]
sheva = "ְ".encode()
dagesh = "ּ".encode()
assert sheva + dagesh in merge_bytes
assert dagesh + sheva in merge_bytes


class TestHebrewViaRegistry:
def test_bytes_roundtrip(self):
register_script("Hebrew", decompose_cluster)
word = "בְּרֵאשִׁ֖ית"
graph = hebrew_grapheme_clusters(word)
graph = utf8_clusters(word)
assert bytes(graph) == word.encode()


class TestHebrewGraphemeClusters:
def test_simple_word(self):
result = hebrew_grapheme_clusters("שלום")
register_script("Hebrew", decompose_cluster)
result = utf8_clusters("שלום")
assert isinstance(result, NodesSequence)
assert bytes(result) == "שלום".encode()

def test_word_with_nikkud(self):
word = "שָׁלוֹם" # shalom with nikkud
result = hebrew_grapheme_clusters(word)
register_script("Hebrew", decompose_cluster)
word = "שָׁלוֹם"
result = utf8_clusters(word)
assert isinstance(result, NodesSequence)
assert bytes(result) == word.encode()

def test_bereshit_structure(self):
word = "בְּרֵאשִׁ֖ית"
result = hebrew_grapheme_clusters(word)
assert isinstance(result, NodesSequence)

def test_mark_categories(self):
"""Verify that we correctly identify all Hebrew mark types."""
marks = "ְִֵּׁ֖"
marks = "ְִֵּׁ֖"
for ch in marks:
cat = unicodedata.category(ch)
assert cat == "Mn", f"{ch!r} (U+{ord(ch):04X}) is {cat}, not Mn"

def test_fully_connected_merges(self):
cluster = "בְּ" # bet + sheva + dagesh
result = decompose_cluster(cluster)
merges = list(result.get_merges())
merge_bytes = [b"".join(bytes(n) for n in m) for m in merges]
sheva = "ְ".encode()
dagesh = "ּ".encode()
assert sheva + dagesh in merge_bytes
assert dagesh + sheva in merge_bytes
Loading