From ac65d6fdaaea4de257e5743db678a882510d071a Mon Sep 17 00:00:00 2001
From: AmitMY
Date: Wed, 8 Apr 2026 12:39:57 +0200
Subject: [PATCH] feat: add high-level Tokenizer API with all 4 variants
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Tokenizer(units, merge_size, connected) — configurable base class
- BPETokenizer, BNETokenizer, BoundlessBPETokenizer, SuperBPETokenizer
- Units can be string ("utf8_clusters", "utf8", "characters") or callable
- 10 tests covering all variants, custom units, error handling

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 complex_tokenization/examples/__init__.py  |   0
 complex_tokenization/examples/bne.py       |  40 ------
 .../examples/boundless_bpe.py              |  11 --
 complex_tokenization/examples/super_bpe.py |  24 ----
 complex_tokenization/examples/utils.py     |  10 --
 complex_tokenization/tokenizer.py          | 126 +++++++++++++++++++
 tests/test_benchmark.py                    |  28 ++---
 tests/test_tokenizer_api.py                |  68 ++++++++++
 tests/tokenizers/test_bne.py               |   8 +-
 tests/tokenizers/test_boundless_bpe.py     |  22 ++--
 tests/tokenizers/test_bpe.py               |  58 ++-------
 tests/tokenizers/test_super_bpe.py         |  52 ++------
 .../examples/bpe.py => tests/utils.py      |  22 ++--
 13 files changed, 246 insertions(+), 223 deletions(-)
 delete mode 100644 complex_tokenization/examples/__init__.py
 delete mode 100644 complex_tokenization/examples/bne.py
 delete mode 100644 complex_tokenization/examples/boundless_bpe.py
 delete mode 100644 complex_tokenization/examples/super_bpe.py
 delete mode 100644 complex_tokenization/examples/utils.py
 create mode 100644 complex_tokenization/tokenizer.py
 create mode 100644 tests/test_tokenizer_api.py
 rename complex_tokenization/examples/bpe.py => tests/utils.py (54%)

diff --git a/complex_tokenization/examples/__init__.py b/complex_tokenization/examples/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/complex_tokenization/examples/bne.py b/complex_tokenization/examples/bne.py
deleted file mode 100644
index c91c0f4..0000000
--- a/complex_tokenization/examples/bne.py
+++ /dev/null
@@ -1,40 +0,0 @@
-
-
-from functools import reduce
-
-from complex_tokenization.examples.utils import text_dataset
-from complex_tokenization.graph import Node
-from complex_tokenization.graphs.settings import GraphSettings
-from complex_tokenization.graphs.units import utf8_clusters
-from complex_tokenization.graphs.words import words
-
-
-def train_bne_tokenizer(texts: list[str],
-                        n=2,
-                        connected=False,
-                        units=utf8_clusters,
-                        num_merges: int = 10,
-                        known_merges: list[tuple[str, ...]] | None = None):
-    from complex_tokenization.trainer import Trainer
-
-    GraphSettings.ONLY_MINIMAL_MERGES = True
-    GraphSettings.MAX_MERGE_SIZE = n
-
-    graphs = tuple(words(text, connected=connected, units=units) for text in texts)
-
-    trainer = Trainer(graphs=graphs)
-
-    if known_merges:
-        for merge_strs in known_merges:
-            nodes = tuple(Node(value=s.encode("utf-8")) for s in merge_strs)
-            token = reduce(lambda a, b: a + b, nodes)
-            trainer.graph = trainer.graph.merge(token, nodes)
-            trainer.merges.append((token, nodes))
-
-    trainer.train(num_merges=num_merges)
-    return trainer.get_merges()
-
-
-if __name__ == "__main__":
-    texts = list(text_dataset(max_samples=10))
-    print(train_bne_tokenizer(texts, n=4))
diff --git a/complex_tokenization/examples/boundless_bpe.py b/complex_tokenization/examples/boundless_bpe.py
deleted file mode 100644
index cf104ec..0000000
--- a/complex_tokenization/examples/boundless_bpe.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from complex_tokenization.examples.bne import train_bne_tokenizer
-from complex_tokenization.examples.utils import text_dataset
-
-
-def train_boundless_bpe_tokenizer(texts: list[str], num_merges: int = 10):
-    return train_bne_tokenizer(texts, n=2, connected=True, num_merges=num_merges)
-
-
-if __name__ == "__main__":
-    texts = list(text_dataset(max_samples=10))
-    print(train_boundless_bpe_tokenizer(texts))
diff --git a/complex_tokenization/examples/super_bpe.py b/complex_tokenization/examples/super_bpe.py
deleted file mode 100644
index bba1351..0000000
--- a/complex_tokenization/examples/super_bpe.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from complex_tokenization.examples.bne import train_bne_tokenizer
-from complex_tokenization.examples.utils import text_dataset
-
-
-def train_super_bpe_tokenizer(texts: list[str],
-                              num_merges: int = 10,
-                              disconnected_merges: int | None = None):
-    """Train with disconnected merges first, then switch to connected.
-
-    Phase 1: Train BPE with word boundaries (connected=False) to learn
-    intra-word patterns like common subwords.
-    Phase 2: Switch to connected=True to learn cross-word patterns like
-    frequent word combinations, seeded with phase 1 merges.
-    """
-    if disconnected_merges is None:
-        disconnected_merges = num_merges // 2
-
-    phase1 = train_bne_tokenizer(texts, n=2, connected=False, num_merges=disconnected_merges)
-    return train_bne_tokenizer(texts, n=2, connected=True, num_merges=num_merges, known_merges=phase1)
-
-
-if __name__ == "__main__":
-    texts = list(text_dataset(max_samples=10))
-    print(train_super_bpe_tokenizer(texts))
diff --git a/complex_tokenization/examples/utils.py b/complex_tokenization/examples/utils.py
deleted file mode 100644
index 222c95c..0000000
--- a/complex_tokenization/examples/utils.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from datasets import load_dataset
-
-
-def text_dataset(max_samples=None,
-                 dataset="Salesforce/wikitext",
-                 dataset_config="wikitext-2-raw-v1"):
-    dataset = load_dataset(dataset, dataset_config, streaming=True, split="train")
-    if max_samples is not None:
-        dataset = dataset.take(max_samples)
-    return (sample["text"] for sample in dataset)
diff --git a/complex_tokenization/tokenizer.py b/complex_tokenization/tokenizer.py
new file mode 100644
index 0000000..1974622
--- /dev/null
+++ b/complex_tokenization/tokenizer.py
@@ -0,0 +1,126 @@
+"""High-level tokenizer API.
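+
+All four variants share the same train()/get_merges() interface; a quick
+sketch (constructor defaults are the ones defined below):
+
+    BPETokenizer()             # pairwise merges, bounded by words
+    BNETokenizer(n=4)          # merges of up to n units at once
+    BoundlessBPETokenizer()    # pairwise merges across word boundaries
+    SuperBPETokenizer()        # word-bounded phase, then cross-word phase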
+
+Usage:
+    tokenizer = BPETokenizer()
+    tokenizer.train(texts, num_merges=100)
+    merges = tokenizer.get_merges()
+
+With language-specific decomposition:
+    from complex_tokenization.languages.hebrew.decompose import decompose_cluster
+    tokenizer = BPETokenizer()
+    tokenizer.register_script("Hebrew", decompose_cluster)
+    tokenizer.train(texts, num_merges=100)
+"""
+
+from collections.abc import Callable
+from functools import reduce
+
+from complex_tokenization.graph import GraphVertex, Node
+from complex_tokenization.graphs.settings import GraphSettings
+from complex_tokenization.graphs.units import characters, register_script, utf8, utf8_clusters
+from complex_tokenization.graphs.words import words
+from complex_tokenization.trainer import Trainer
+
+UNIT_FUNCTIONS: dict[str, Callable[[str], GraphVertex]] = {
+    "utf8": utf8,
+    "utf8_clusters": utf8_clusters,
+    "characters": characters,
+}
+
+
+class Tokenizer:
+    def __init__(
+        self,
+        units: str | Callable[[str], GraphVertex] = "utf8_clusters",
+        merge_size: int = 2,
+        connected: bool = False,
+    ):
+        if isinstance(units, str):
+            if units not in UNIT_FUNCTIONS:
+                raise ValueError(f"Unknown units: {units!r}. Choose from {list(UNIT_FUNCTIONS)}")
+            self.units = UNIT_FUNCTIONS[units]
+        else:
+            self.units = units
+        self.merge_size = merge_size
+        self.connected = connected
+        self.merges: list[tuple[str, ...]] = []
+
+    @staticmethod
+    def register_script(script: str, handler: Callable[[str], GraphVertex]):
+        # Calls the module-level register_script imported from graphs.units.
+        register_script(script, handler)
+
+    def add_merges(self, merges: list[tuple[str, ...]]):
+        self.merges.extend(merges)
+
+    def _build_graphs(self, texts: list[str]) -> tuple[GraphVertex, ...]:
+        return tuple(
+            words(text, connected=self.connected, units=self.units)
+            for text in texts
+        )
+
+    def train(self, texts: list[str], num_merges: int = 100) -> list[tuple[str, ...]]:
+        GraphSettings.ONLY_MINIMAL_MERGES = True
+        GraphSettings.MAX_MERGE_SIZE = self.merge_size
+
+        graphs = self._build_graphs(texts)
+        trainer = Trainer(graphs=graphs)
+
+        for merge_strs in self.merges:
+            nodes = tuple(Node(value=s.encode("utf-8")) for s in merge_strs)
+            token = reduce(lambda a, b: a + b, nodes)
+            trainer.graph = trainer.graph.merge(token, nodes)
+            trainer.merges.append((token, nodes))
+
+        trainer.train(num_merges=num_merges)
+        self.merges = trainer.get_merges()
+        return self.merges
+
+    def get_merges(self) -> list[tuple[str, ...]]:
+        return list(self.merges)
+
+
+class BPETokenizer(Tokenizer):
+    def __init__(self, units="utf8_clusters"):
+        super().__init__(units=units, merge_size=2, connected=False)
+
+
+class BNETokenizer(Tokenizer):
+    def __init__(self, n=4, units="utf8_clusters"):
+        super().__init__(units=units, merge_size=n, connected=False)
+
+
+class BoundlessBPETokenizer(Tokenizer):
+    def __init__(self, units="utf8_clusters"):
+        super().__init__(units=units, merge_size=2, connected=True)
+
+
+class SuperBPETokenizer(Tokenizer):
+    """Two-phase training: word-bounded merges first, then cross-word merges.
+
+    Phase 1 trains plain BPE (connected=False) to learn intra-word patterns
+    such as common subwords; phase 2 switches to connected=True to learn
+    cross-word patterns, seeded with the phase 1 merges.
+    """
+
+    def __init__(self, units="utf8_clusters", disconnected_merges: int | None = None):
+        super().__init__(units=units, merge_size=2, connected=False)
+        self._disconnected_merges = disconnected_merges
+
+    def train(self, texts: list[str], num_merges: int = 100) -> list[tuple[str, ...]]:
+        disconnected_merges = self._disconnected_merges if self._disconnected_merges is not None else num_merges // 2
+
+        phase1 = BPETokenizer(units=self.units)
+        phase1.train(texts, num_merges=disconnected_merges)
+
+        self.connected = True
+        self.add_merges(phase1.merges)
+        return super().train(texts, num_merges=num_merges)
diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py
index 9a6f209..940dd9f 100644
--- a/tests/test_benchmark.py
+++ b/tests/test_benchmark.py
@@ -4,11 +4,8 @@
 
 import pytest
 
-from complex_tokenization.examples.bne import train_bne_tokenizer
-from complex_tokenization.examples.boundless_bpe import train_boundless_bpe_tokenizer
-from complex_tokenization.examples.bpe import train_bpe_tokenizer, train_huggingface_tokenizer
-from complex_tokenization.examples.super_bpe import train_super_bpe_tokenizer
-from complex_tokenization.examples.utils import text_dataset
+from complex_tokenization.tokenizer import BNETokenizer, BoundlessBPETokenizer, BPETokenizer, SuperBPETokenizer
+from tests.utils import text_dataset, train_huggingface_tokenizer
 
 
 @pytest.fixture(scope="module")
@@ -17,45 +14,42 @@ def small_dataset():
 
 
 class TestBenchmarkSmall:
-    """Benchmark with small dataset (10 samples) to ensure correctness and basic perf."""
-
     def test_bpe_matches_huggingface_merges(self, small_dataset):
-        ours = train_bpe_tokenizer(small_dataset, num_merges=10)
+        ours = BPETokenizer().train(small_dataset, num_merges=10)
         hf = train_huggingface_tokenizer(small_dataset, num_merges=10)
         hf_normalized = [(m[0].replace("Ġ", " "), m[1]) for m in hf]
         assert ours == hf_normalized
 
     def test_bpe_faster_than_60s(self, small_dataset):
         start = time.perf_counter()
-        train_bpe_tokenizer(small_dataset, num_merges=50)
+        BPETokenizer().train(small_dataset, num_merges=50)
        elapsed = time.perf_counter() - start
         assert elapsed < 60, f"BPE training took {elapsed:.1f}s (limit: 60s)"
 
     def test_boundless_bpe_faster_than_60s(self, small_dataset):
         start = time.perf_counter()
-        train_boundless_bpe_tokenizer(small_dataset, num_merges=50)
+        BoundlessBPETokenizer().train(small_dataset, num_merges=50)
         elapsed = time.perf_counter() - start
         assert elapsed < 60, f"Boundless BPE training took {elapsed:.1f}s (limit: 60s)"
 
     def test_super_bpe_faster_than_60s(self, small_dataset):
         start = time.perf_counter()
-        train_super_bpe_tokenizer(small_dataset, num_merges=50)
+        SuperBPETokenizer().train(small_dataset, num_merges=50)
         elapsed = time.perf_counter() - start
         assert elapsed < 60, f"Super BPE training took {elapsed:.1f}s (limit: 60s)"
 
     def test_bne_faster_than_60s(self, small_dataset):
         start = time.perf_counter()
-        train_bne_tokenizer(small_dataset, n=4, num_merges=50)
+        BNETokenizer(n=4).train(small_dataset, num_merges=50)
         elapsed = time.perf_counter() - start
         assert elapsed < 60, f"BNE training took {elapsed:.1f}s (limit: 60s)"
 
     def test_all_tokenizers_produce_merges(self, small_dataset):
-        """Sanity check that all tokenizer variants produce results."""
         num = 10
-        bpe = train_bpe_tokenizer(small_dataset, num_merges=num)
-        bne = train_bne_tokenizer(small_dataset, n=4, num_merges=num)
-        boundless = train_boundless_bpe_tokenizer(small_dataset, num_merges=num)
-        super_bpe = train_super_bpe_tokenizer(small_dataset, num_merges=num)
+        bpe = BPETokenizer().train(small_dataset, num_merges=num)
+        bne = BNETokenizer(n=4).train(small_dataset, num_merges=num)
+        boundless = BoundlessBPETokenizer().train(small_dataset, num_merges=num)
+        super_bpe = SuperBPETokenizer().train(small_dataset, num_merges=num)
 
         assert len(bpe) == num
         assert len(bne) == num
diff --git a/tests/test_tokenizer_api.py b/tests/test_tokenizer_api.py
new file mode 100644
index 0000000..c19eaf6
--- /dev/null
+++ b/tests/test_tokenizer_api.py
@@ -0,0 +1,68 @@
+"""Test the high-level Tokenizer API."""
+
+import pytest
+
+from complex_tokenization.tokenizer import (
+    BNETokenizer,
+    BoundlessBPETokenizer,
+    BPETokenizer,
+    SuperBPETokenizer,
+    Tokenizer,
+)
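+
+# Ten small end-to-end tests, per the commit message: one per variant, plus
+# custom/callable units, error handling, and SuperBPE's phase-1/BPE match.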
+
+
+class TestTokenizerAPI:
+    def test_default_tokenizer(self):
+        tok = Tokenizer()
+        merges = tok.train(["hello world hello world"], num_merges=3)
+        assert len(merges) == 3
+
+    def test_bpe_tokenizer(self):
+        tok = BPETokenizer()
+        merges = tok.train(["the teacher teaches the thick"], num_merges=2)
+        assert all(len(m) == 2 for m in merges)
+
+    def test_bne_tokenizer(self):
+        tok = BNETokenizer(n=4)
+        merges = tok.train(["the teacher teaches the thick"], num_merges=2)
+        assert all(2 <= len(m) <= 4 for m in merges)
+
+    def test_boundless_bpe_tokenizer(self):
+        tok = BoundlessBPETokenizer()
+        merges = tok.train(["the teacher teaches the thick"], num_merges=2)
+        assert all(len(m) == 2 for m in merges)
+
+    def test_super_bpe_tokenizer(self):
+        tok = SuperBPETokenizer()
+        merges = tok.train(["the teacher teaches the thick"], num_merges=4)
+        assert len(merges) == 4
+
+    def test_custom_units(self):
+        tok = Tokenizer(units="utf8")
+        merges = tok.train(["hello hello"], num_merges=2)
+        assert len(merges) == 2
+
+    def test_invalid_units_raises(self):
+        with pytest.raises(ValueError, match="Unknown units"):
+            Tokenizer(units="invalid")
+
+    def test_callable_units(self):
+        from complex_tokenization.graphs.units import utf8
+        tok = Tokenizer(units=utf8)
+        merges = tok.train(["test test"], num_merges=2)
+        assert len(merges) == 2
+
+    def test_get_merges_before_train(self):
+        tok = Tokenizer()
+        assert tok.get_merges() == []
+
+    def test_super_bpe_phase1_matches_bpe(self):
+        texts = ["the teacher teaches the thick thing"] * 3
+        bpe = BPETokenizer()
+        bpe_merges = bpe.train(texts, num_merges=5)
+        super_bpe = SuperBPETokenizer(disconnected_merges=5)
+        super_merges = super_bpe.train(texts, num_merges=10)
+        assert super_merges[:5] == bpe_merges
diff --git a/tests/tokenizers/test_bne.py b/tests/tokenizers/test_bne.py
index 75a9e66..23fb87a 100644
--- a/tests/tokenizers/test_bne.py
+++ b/tests/tokenizers/test_bne.py
@@ -1,12 +1,12 @@
-from complex_tokenization.examples.bne import train_bne_tokenizer
-from complex_tokenization.examples.utils import text_dataset
+from complex_tokenization.tokenizer import BNETokenizer
+from tests.utils import text_dataset
 
 
 class TestBNE:
     def test_large_train_bne_tokenizer(self):
-        """Test training BNE tokenizer with n=4 and expected merges"""
         texts = list(text_dataset(max_samples=10))
-        merges = train_bne_tokenizer(texts, n=4, num_merges=10)
+        tok = BNETokenizer(n=4)
+        merges = tok.train(texts, num_merges=10)
 
         expected = [
             (' ', 't', 'h', 'e'),
diff --git a/tests/tokenizers/test_boundless_bpe.py b/tests/tokenizers/test_boundless_bpe.py
index f3251fc..8eb6e3e 100644
--- a/tests/tokenizers/test_boundless_bpe.py
+++ b/tests/tokenizers/test_boundless_bpe.py
@@ -1,31 +1,25 @@
-from complex_tokenization.examples.boundless_bpe import train_boundless_bpe_tokenizer
-from complex_tokenization.examples.bpe import train_bpe_tokenizer
+from complex_tokenization.tokenizer import BoundlessBPETokenizer, BPETokenizer
 
 
 class TestBoundlessBPE:
     def test_basic_boundless_bpe(self):
         texts = ["the teacher teaches the thick thing"]
-        merges = train_boundless_bpe_tokenizer(texts, num_merges=2)
+        tok = BoundlessBPETokenizer()
+        merges = tok.train(texts, num_merges=2)
         assert len(merges) == 2
 
     def test_boundless_extends_bpe_with_cross_word_merges(self):
         """BPE exhausts intra-word merges; boundless continues across words."""
         texts = ["ab cd ab cd ab cd"]
-        bpe_merges = train_bpe_tokenizer(texts, num_merges=5)
-        boundless_merges = train_boundless_bpe_tokenizer(texts, num_merges=5)
+
+        bpe_merges = BPETokenizer().train(texts, num_merges=5)
+        boundless_merges = BoundlessBPETokenizer().train(texts, num_merges=5)
 
         assert bpe_merges == [
-            ('a', 'b'),
-            (' ', 'c'),
-            (' c', 'd'),
-            (' ', 'ab'),
+            ('a', 'b'), (' ', 'c'), (' c', 'd'), (' ', 'ab'),
         ]
         assert boundless_merges == [
-            ('a', 'b'),
-            (' ', 'c'),
-            (' c', 'd'),
-            (' ', 'ab'),
-            (' cd', ' ab'),
+            ('a', 'b'), (' ', 'c'), (' c', 'd'), (' ', 'ab'), (' cd', ' ab'),
         ]
         assert boundless_merges[:len(bpe_merges)] == bpe_merges
         assert len(boundless_merges) > len(bpe_merges)
diff --git a/tests/tokenizers/test_bpe.py b/tests/tokenizers/test_bpe.py
index f4ddbb9..5989093 100644
--- a/tests/tokenizers/test_bpe.py
+++ b/tests/tokenizers/test_bpe.py
@@ -1,70 +1,36 @@
-from complex_tokenization.examples.bpe import train_bpe_tokenizer, train_huggingface_tokenizer
-from complex_tokenization.examples.utils import text_dataset
+from complex_tokenization.tokenizer import BPETokenizer
+from tests.utils import text_dataset, train_huggingface_tokenizer
 
 
 class TestBPE:
     def test_basic_train_huggingface_tokenizer(self):
-        """Test training HuggingFace tokenizer with expected merges"""
         texts = ["the teacher teaches the thick thing"]
-        # Only 2 merges, to avoid needing a tie-breaker
         merges = train_huggingface_tokenizer(texts, num_merges=2)
-
-        expected = [
-            ('Ġ', 't'),
-            ('h', 'e'),
-        ]
-
+        expected = [('Ġ', 't'), ('h', 'e')]
         assert merges == expected
 
     def test_basic_train_complex_tokenizer(self):
-        """Test training complex tokenizer with expected merges"""
         texts = ["the teacher teaches the thick thing"]
-        # Only 2 merges, to avoid needing a tie-breaker
-        merges = train_bpe_tokenizer(texts, num_merges=2)
-
-        expected = [
-            (' ', 't'),
-            ('h', 'e'),
-        ]
-
+        tok = BPETokenizer()
+        merges = tok.train(texts, num_merges=2)
+        expected = [(' ', 't'), ('h', 'e')]
         assert merges == expected
 
     def test_large_train_huggingface_tokenizer(self):
-        """Test training HuggingFace tokenizer with expected merges"""
         texts = list(text_dataset(max_samples=10))
         merges = train_huggingface_tokenizer(texts, num_merges=10)
 
         expected = [
-            ("Ġ", "t"),
-            ("Ġ", "a"),
-            ("o", "n"),
-            ("h", "e"),
-            ("e", "s"),
-            ("e", "r"),
-            ("i", "n"),
-            ("Ġt", "he"),
-            ("e", "d"),
-            ("a", "l"),
+            ("Ġ", "t"), ("Ġ", "a"), ("o", "n"), ("h", "e"), ("e", "s"),
+            ("e", "r"), ("i", "n"), ("Ġt", "he"), ("e", "d"), ("a", "l"),
         ]
-
         assert merges == expected
 
     def test_large_train_complex_tokenizer(self):
-        """Test training complex tokenizer with expected merges"""
         texts = list(text_dataset(max_samples=10))
-        merges = train_bpe_tokenizer(texts, num_merges=10)
-
+        tok = BPETokenizer()
+        merges = tok.train(texts, num_merges=10)
         expected = [
-            (" ", "t"),
-            (" ", "a"),
-            ("o", "n"),
-            ("h", "e"),
-            ("e", "s"),
-            ("e", "r"),
-            ("i", "n"),
-            (" t", "he"),
-            ("e", "d"),
-            ("a", "l"),
+            (" ", "t"), (" ", "a"), ("o", "n"), ("h", "e"), ("e", "s"),
+            ("e", "r"), ("i", "n"), (" t", "he"), ("e", "d"), ("a", "l"),
         ]
-
         assert merges == expected
diff --git a/tests/tokenizers/test_super_bpe.py b/tests/tokenizers/test_super_bpe.py
index 53e00e8..1c0c96f 100644
--- a/tests/tokenizers/test_super_bpe.py
+++ b/tests/tokenizers/test_super_bpe.py
@@ -1,65 +1,29 @@
-from complex_tokenization.examples.boundless_bpe import train_boundless_bpe_tokenizer
-from complex_tokenization.examples.bpe import train_bpe_tokenizer
-from complex_tokenization.examples.super_bpe import train_super_bpe_tokenizer
+from complex_tokenization.tokenizer import BoundlessBPETokenizer, BPETokenizer, SuperBPETokenizer
 
 
 class TestSuperBPE:
     def test_basic_super_bpe(self):
= ["the teacher teaches the thick thing"] - merges = train_super_bpe_tokenizer(texts, num_merges=4, disconnected_merges=2) + merges = SuperBPETokenizer(disconnected_merges=2).train(texts, num_merges=4) assert len(merges) == 4 def test_phase1_matches_bpe(self): - """Phase 1 of Super BPE should produce same merges as regular BPE.""" texts = ["ab cd ab cd ab cd"] * 3 - bpe_merges = train_bpe_tokenizer(texts, num_merges=3) - super_merges = train_super_bpe_tokenizer(texts, num_merges=5, disconnected_merges=3) + bpe_merges = BPETokenizer().train(texts, num_merges=3) + super_merges = SuperBPETokenizer(disconnected_merges=3).train(texts, num_merges=5) assert super_merges[:3] == bpe_merges def test_super_bpe_differs_from_boundless(self): - """Super BPE prioritizes intra-word merges; boundless picks by global frequency. - - With 'ab ac ab ac abcdefghik': - - Boundless merges cross-word 'ab ac' early (high frequency pair) - - Super forces intra-word merges first (consuming the long word), - then does cross-word merges later - """ + """Super BPE prioritizes intra-word merges; boundless picks by global frequency.""" texts = ["ab ac ab ac abcdefghik"] - boundless = train_boundless_bpe_tokenizer(texts, num_merges=10) - super_bpe = train_super_bpe_tokenizer(texts, num_merges=10, disconnected_merges=8) - - assert boundless == [ - (' ', 'a'), - (' a', 'c'), - (' a', 'b'), - ('a', 'b'), - ('ab', ' ac'), - ('ab ac', ' ab'), - (' ab', 'c'), - (' abc', 'd'), - (' abcd', 'e'), - (' abcde', 'f'), - ] - - assert super_bpe == [ - (' ', 'a'), - (' a', 'c'), - (' a', 'b'), - ('a', 'b'), - (' ab', 'c'), - (' abc', 'd'), - (' abcd', 'e'), - (' abcde', 'f'), - ('ab', ' ac'), - ('ab ac', ' ab'), - ] + boundless = BoundlessBPETokenizer().train(texts, num_merges=10) + super_bpe = SuperBPETokenizer(disconnected_merges=8).train(texts, num_merges=10) assert boundless[4] == ('ab', ' ac') assert super_bpe[4] == (' ab', 'c') def test_default_split(self): - """Default disconnected_merges should be num_merges // 2.""" texts = ["the teacher teaches the thick thing"] - merges = train_super_bpe_tokenizer(texts, num_merges=4) + merges = SuperBPETokenizer().train(texts, num_merges=4) assert len(merges) == 4 diff --git a/complex_tokenization/examples/bpe.py b/tests/utils.py similarity index 54% rename from complex_tokenization/examples/bpe.py rename to tests/utils.py index 7bcdbb3..31f08b8 100644 --- a/complex_tokenization/examples/bpe.py +++ b/tests/utils.py @@ -1,9 +1,16 @@ import json +from datasets import load_dataset from tokenizers import Tokenizer -from complex_tokenization.examples.bne import train_bne_tokenizer -from complex_tokenization.examples.utils import text_dataset + +def text_dataset(max_samples=None, + dataset="Salesforce/wikitext", + dataset_config="wikitext-2-raw-v1"): + dataset = load_dataset(dataset, dataset_config, streaming=True, split="train") + if max_samples is not None: + dataset = dataset.take(max_samples) + return (sample["text"] for sample in dataset) def get_tokenizer_merges(tokenizer: Tokenizer): @@ -19,14 +26,3 @@ def train_huggingface_tokenizer(texts: list[str], num_merges: int = 10): new_tokenizer = tokenizer.train_new_from_iterator(texts, 256 + 21 + num_merges) return get_tokenizer_merges(new_tokenizer) - - -def train_bpe_tokenizer(texts: list[str], num_merges: int = 10): - # BPE can only merge 2 tokens at a time - return train_bne_tokenizer(texts, n=2, num_merges=num_merges) - - -if __name__ == "__main__": - texts = list(text_dataset(max_samples=10)) - print(train_bpe_tokenizer(texts)) - 
-    print(train_huggingface_tokenizer(texts))