29 changes: 28 additions & 1 deletion tokenizers/src/pre_tokenizers/byte_level.rs
@@ -41,6 +41,14 @@ lazy_static! {
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
)
.unwrap();
static ref RE_VEC: Vec<SysRegex> = {
let pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
let mut vec = Vec::with_capacity(MAX_NUM_THREADS);
for _ in 0..MAX_NUM_THREADS {
vec.push(SysRegex::new(pattern).unwrap());
}
vec
};
static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
static ref CHAR_BYTES: HashMap<char, u8> =
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
@@ -111,12 +119,31 @@ impl ByteLevel {
}
}

use std::num::NonZeroU64;
use std::thread;

pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
// It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
// that works great for our use case of avoiding collisions in our array. Unfortunately,
// it's private. However, there are only so many ways you can lay out a u64, so just transmute
// https://github.com/rust-lang/rust/issues/67939
const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
let x =
unsafe { std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0 };
u64::from(x) as usize - 1
}

const MAX_NUM_THREADS: usize = 128;

/// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into
/// their byte-level counterpart. It also splits the input according to the configured regex.
// TODO: Give the ability to modify this regex
impl PreTokenizer for ByteLevel {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
let re_ref: &SysRegex = &RE;
let re_ref: &SysRegex = &RE_VEC[hash_current_thread() % MAX_NUM_THREADS]; // TODO use the thread thing here as well!
pretokenized.split(|_, mut normalized| {
if self.add_prefix_space && !normalized.get().starts_with(' ') {
normalized.prepend(" ");
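The byte_level.rs change swaps the single shared RE for RE_VEC, an array of MAX_NUM_THREADS regex instances indexed by a hash of the current thread id, so threads pre-tokenizing in parallel each hit their own regex instead of contending on one. Below is a minimal sketch of that one-slot-per-thread idea in isolation; it derives the slot from a thread_local counter rather than the PR's ThreadId transmute, and the names (THREAD_COUNTER, THREAD_SLOT, current_thread_slot) are illustrative, not part of the crate.

use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

const MAX_NUM_THREADS: usize = 128;
static THREAD_COUNTER: AtomicUsize = AtomicUsize::new(0);

thread_local! {
    // Each thread computes its slot once, on first access.
    static THREAD_SLOT: usize =
        THREAD_COUNTER.fetch_add(1, Ordering::Relaxed) % MAX_NUM_THREADS;
}

fn current_thread_slot() -> usize {
    THREAD_SLOT.with(|slot| *slot)
}

fn main() {
    // Every thread keeps hitting the same index, so a Vec of per-slot
    // resources (regexes, tries, ...) is never shared across threads in practice.
    let handles: Vec<_> = (0..4)
        .map(|_| thread::spawn(|| (thread::current().id(), current_thread_slot())))
        .collect();
    for h in handles {
        let (id, slot) = h.join().unwrap();
        println!("{id:?} -> slot {slot}");
    }
}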
66 changes: 61 additions & 5 deletions tokenizers/src/tokenizer/added_vocabulary.rs
@@ -92,6 +92,25 @@ impl std::hash::Hash for AddedToken {
}
}

use std::num::NonZeroU64;
use std::thread;

pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
// It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
// that works great for our use case of avoiding collisions in our array. Unfortunately,
// it's private. However, there are only so many ways you can lay out a u64, so just transmute
// https://github.com/rust-lang/rust/issues/67939
const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
let x =
unsafe { std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0 };
u64::from(x) as usize
}

const MAX_NUM_THREADS: usize = 128;

type MatchingSet = (AhoCorasick, Vec<u32>);

lazy_static! {
Expand Down Expand Up @@ -156,11 +175,16 @@ pub struct AddedVocabulary {
/// us remove them easily with an O(1) complexity.
special_tokens_set: HashSet<String>,

/// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
split_trie: MatchingSet,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
split_normalized_trie: MatchingSet,

/// Per-thread copies of the non-normalized matching set used to split on AddedTokens
split_trie_vec: Vec<MatchingSet>,
/// Per-thread copies of the normalized matching set used to split on AddedTokens
split_normalized_trie_vec: Vec<MatchingSet>,

/// Whether or not special tokens should be split when encoding. This is equivalent to ignoring them
encode_special_tokens: bool,
}
@@ -181,8 +205,10 @@ impl AddedVocabulary {
added_tokens: vec![],
special_tokens: vec![],
special_tokens_set: HashSet::new(),
split_trie: (trie, vec![]),
split_normalized_trie: (normalized_trie, vec![]),
split_trie: (trie.clone(), vec![]),
split_normalized_trie: (normalized_trie.clone(), vec![]),
split_trie_vec: vec![(trie, vec![]); MAX_NUM_THREADS],
split_normalized_trie_vec: vec![(normalized_trie, vec![]); MAX_NUM_THREADS],
encode_special_tokens: false,
}
}
@@ -345,6 +371,7 @@ impl AddedVocabulary {
.build(tokens.iter().map(|token| &token.content))
.expect("Failed to build tried when refreshing tokens");
self.split_trie = (trie, ids);
self.split_trie_vec = vec![self.split_trie.clone(); MAX_NUM_THREADS];

let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
let patterns: Vec<_> = ntokens
@@ -362,6 +389,7 @@
.build(patterns.iter().map(|content| content.get()))
.expect("Failed to build tried when refreshing tokens (normalized)");
self.split_normalized_trie = (normalized_trie, nids);
self.split_normalized_trie_vec = vec![self.split_normalized_trie.clone(); MAX_NUM_THREADS];
}

/// Find any AddedToken in the given sentence, using the provided MatchingSet.
@@ -425,6 +453,26 @@ impl AddedVocabulary {
splits
}


fn fast_split_with_indices(
&self,
sentence: NormalizedString,
split_re: &MatchingSet,
) -> Vec<(NormalizedString, Option<Vec<Token>>)> {
self.find_matches(sentence.get(), split_re)
.into_iter()
.map(|(id, byte_offsets)| {
let slice = sentence
.slice(Range::Normalized(byte_offsets.0..byte_offsets.1))
.expect("AddedVocabulary bad split");
if let Some(id) = id {
(slice, Some(vec![Token::new(id, String::new(), (0, 0))]))
} else {
(slice, None)
}
})
.collect()
}

/// Split the input sentence to extract anything we found from the `MatchingSet`, as well as
/// the list of corresponding IDs
/// The list of IDs has exactly the same number of elements as the Iterator.
@@ -465,7 +513,12 @@ impl AddedVocabulary {

// 1. We extract all the non-normalized tokens from the non-normalized string
pretokenized
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
.split(|_, sequence| {
Ok(self.fast_split_with_indices(
sequence,
&self.split_trie_vec[hash_current_thread() % MAX_NUM_THREADS],
))
})
.expect("AddedVocabulary bad split");

// <s> normalized = False
Expand All @@ -484,7 +537,10 @@ impl AddedVocabulary {
pretokenized
.split(|_, mut sequence| {
normalizer.map(|n| n.normalize(&mut sequence));
Ok(self.split_with_indices(sequence, &self.split_normalized_trie))
Ok(self.fast_split_with_indices(
sequence,
&self.split_normalized_trie_vec[hash_current_thread() % MAX_NUM_THREADS],
))
})
.expect("AddedVocabulary bad split");

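In added_vocabulary.rs the Aho-Corasick matching sets are duplicated into split_trie_vec and split_normalized_trie_vec (one clone per thread slot), and fast_split_with_indices returns placeholder tokens (empty value, (0, 0) offsets) instead of materializing each match's text. A rough sketch of the kind of lookup find_matches performs over those matching sets, assuming the aho-corasick 1.x API; the patterns and sentence below are made up for illustration.

use aho_corasick::AhoCorasick;

fn main() {
    // Stand-ins for added-token contents; in the crate they come from AddedToken.content.
    let patterns = ["<s>", "</s>", "<mask>"];
    let ac = AhoCorasick::new(patterns).unwrap();

    let sentence = "<s> hello <mask> world </s>";
    for m in ac.find_iter(sentence) {
        // Pattern index + byte offsets: the same shape the MatchingSet lookups
        // produce before the splits (and placeholder Tokens) are built.
        println!(
            "token #{} at bytes {}..{}",
            m.pattern().as_usize(),
            m.start(),
            m.end()
        );
    }
}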
2 changes: 1 addition & 1 deletion tokenizers/src/tokenizer/mod.rs
@@ -881,7 +881,7 @@ where
) -> Result<Encoding> {
let mut pretokenized: PreTokenizedString = pretokenized.into();
pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?;
pretokenized.into_encoding(word_idx, type_id, offsets_type)
pretokenized.fast_into_encoding()
}
}

19 changes: 19 additions & 0 deletions tokenizers/src/tokenizer/pre_tokenizer.rs
@@ -186,6 +186,25 @@ impl PreTokenizedString {
}
}

pub fn fast_into_encoding(self) -> Result<Encoding> {
if self.splits.is_empty() {
Ok(Encoding::default())
} else if !self.splits.iter().all(|split| split.tokens.is_some()) {
Err("Split has not been tokenized.".into())
} else {
let tokens = self
.splits
.into_iter()
.flat_map(|split| {
split.tokens.unwrap().into_iter().map(|token| {
// Fast path: keep only the token id; value, offsets, word index and type id are placeholders
(token.id, String::new(), (0, 0), None, 0)
})
})
.collect();
Ok(tokens)
}
}
/// Returns a list of splits, each of them being a slice of the normalized
/// string, the associated offsets either in original or normalized
/// referential, as well as the potential tokens
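fast_into_encoding collects the flattened tokens straight into an Encoding, filling everything except the id with placeholders, which is what lets mod.rs drop the word_idx/type_id/offsets_type arguments at the call site. A small self-contained sketch of that collect-into-an-encoding pattern with a stand-in type (MiniEncoding and its fields are illustrative, not the crate's real Encoding layout):

#[derive(Debug, Default)]
struct MiniEncoding {
    ids: Vec<u32>,
    tokens: Vec<String>,
    offsets: Vec<(usize, usize)>,
}

impl FromIterator<(u32, String, (usize, usize))> for MiniEncoding {
    fn from_iter<I: IntoIterator<Item = (u32, String, (usize, usize))>>(iter: I) -> Self {
        let mut enc = MiniEncoding::default();
        for (id, token, offsets) in iter {
            enc.ids.push(id);
            enc.tokens.push(token);
            enc.offsets.push(offsets);
        }
        enc
    }
}

fn main() {
    // The fast path keeps the real ids but fills placeholders everywhere else.
    let enc: MiniEncoding = (0u32..3)
        .map(|id| (id, String::new(), (0, 0)))
        .collect();
    println!("{enc:?}");
}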