Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions benchmarks/bench_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Benchmark how training scales with text size and merge count."""

import time

from complex_tokenization.fast_bpe_trainer import FastBPETrainer
from complex_tokenization.graphs.settings import GraphSettings
from complex_tokenization.graphs.units import utf8_clusters
from complex_tokenization.graphs.words import words
from complex_tokenization.trainer import Trainer


def train_graph_bpe(texts, num_merges):
    """Train the graph-based BPE trainer on *texts* and return its merges.

    NOTE(review): this mutates ``GraphSettings`` class attributes globally
    and never restores them — confirm that is acceptable for the benchmark.
    """
    # Configure the graph trainer to mimic classic pairwise BPE so the
    # comparison against FastBPETrainer is apples-to-apples.
    GraphSettings.ONLY_MINIMAL_MERGES = True
    GraphSettings.MAX_MERGE_SIZE = 2
    GraphSettings.USE_SINGLETONS = False

    word_graphs = tuple(
        words(text, connected=False, units=utf8_clusters) for text in texts
    )
    graph_trainer = Trainer(graphs=word_graphs)
    graph_trainer.train(num_merges=num_merges)
    return graph_trainer.get_merges()


def train_fast_bpe(texts, num_merges):
    """Train ``FastBPETrainer`` on *texts* and return the learned merges."""
    trainer = FastBPETrainer(texts)
    trainer.train(num_merges=num_merges)
    return trainer.get_merges()


BASE_TEXT = "the teacher teaches the thick thing about the theorem "


def run():
    """Print a timing table comparing Graph BPE against Fast BPE.

    Sweeps corpus size (number of texts x repetitions of BASE_TEXT) and
    merge count, timing both trainers on each configuration and reporting
    the speedup plus whether the two produced identical merge lists.
    """
    separator = "=" * 80
    print(f"\n{separator}")
    print("Scaling Benchmark: Graph BPE vs Fast BPE")
    print(separator)
    print(f"{'Config':30s} {'Graph BPE':>10s} {'Fast BPE':>10s} {'Speedup':>8s}")
    print("-" * 80)

    for num_texts in [10, 50, 100]:
        for repeat in [10, 50]:
            for num_merges in [50, 100, 200]:
                corpus = [BASE_TEXT * repeat] * num_texts
                char_count = sum(map(len, corpus))

                t0 = time.perf_counter()
                merges_graph = train_graph_bpe(corpus, num_merges)
                elapsed_graph = time.perf_counter() - t0

                t0 = time.perf_counter()
                merges_fast = train_fast_bpe(corpus, num_merges)
                elapsed_fast = time.perf_counter() - t0

                # Guard against division by zero when the fast run is
                # too quick for the timer to register.
                if elapsed_fast > 0:
                    speedup = elapsed_graph / elapsed_fast
                else:
                    speedup = float('inf')
                match = "ok" if merges_graph == merges_fast else "MISMATCH"
                label = f"{num_texts}x{repeat}rep m={num_merges} ({char_count:,}ch)"
                print(f"{label:30s} {elapsed_graph:>9.3f}s {elapsed_fast:>9.3f}s {speedup:>7.1f}x {match}")


if __name__ == "__main__":
run()
27 changes: 27 additions & 0 deletions tests/test_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Test that FastBPE scales well with larger inputs."""

import time

from complex_tokenization.fast_bpe_trainer import FastBPETrainer

BASE = "the teacher teaches the thick thing about the theorem "


class TestScaling:
    """FastBPE should stay fast on large corpora and be size-invariant."""

    def test_100k_chars_under_5s(self):
        # ~270k chars total (100 texts x 50 repetitions of BASE).
        corpus = [BASE * 50] * 100  # ~270k chars
        t0 = time.perf_counter()
        trainer = FastBPETrainer(corpus)
        trainer.train(num_merges=25)
        elapsed = time.perf_counter() - t0
        assert elapsed < 5, f"FastBPE on 270k chars took {elapsed:.1f}s (limit: 5s)"
        assert len(trainer.merges) == 25

    def test_merges_scale_with_data(self):
        # For this uniformly repetitive text, the first 50 merges must not
        # depend on how many copies of it the trainer sees.
        trainer_small = FastBPETrainer([BASE * 10] * 10)
        trainer_small.train(num_merges=50)
        trainer_large = FastBPETrainer([BASE * 50] * 50)
        trainer_large.train(num_merges=50)
        assert trainer_small.get_merges() == trainer_large.get_merges()
Loading