From e4a011df5b24e96f2014c5520c2c41e84b926210 Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Sat, 20 Jun 2026 23:32:17 +0200 Subject: [PATCH] Validate SimpleVocab special tokens --- lingvo/core/ops/simple_vocab.cc | 54 ++++++++++++++++++++++----- lingvo/core/ops/simple_vocab_test.py | 56 ++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 10 deletions(-) diff --git a/lingvo/core/ops/simple_vocab.cc b/lingvo/core/ops/simple_vocab.cc index 48b0e2cb3..5fa3c5d9b 100644 --- a/lingvo/core/ops/simple_vocab.cc +++ b/lingvo/core/ops/simple_vocab.cc @@ -99,32 +99,66 @@ Status Vocab::Load(const std::vector& lines, bool load_token_ids) { VLOG(2) << "Vocab " << token_to_id_[tok] << " " << tok; } use_upper_token_symbols_ = false; - std::vector expected_tokens = {kSosToken, kEosToken, kUnkToken}; - std::vector unexpected_tokens = {kSosTokenUpper, kEosTokenUpper, - kUnkTokenUpper}; - if (token_to_id_.find(sos_token()) == token_to_id_.end()) { + const std::vector lower_tokens = {kSosToken, kEosToken, kUnkToken}; + const std::vector upper_tokens = {kSosTokenUpper, kEosTokenUpper, + kUnkTokenUpper}; + const auto token_id = [this](const string& token) { + return token_to_id_.find(token)->second; + }; + const auto token_exists = [this](const string& token) { + return token_to_id_.find(token) != token_to_id_.end(); + }; + const auto any_token_exists = + [&token_exists](const std::vector& tokens) { + for (const auto& token : tokens) { + if (token_exists(token)) { + return true; + } + } + return false; + }; + const bool has_lower_tokens = any_token_exists(lower_tokens); + const bool has_upper_tokens = any_token_exists(upper_tokens); + if (has_lower_tokens && has_upper_tokens) { + return errors::InvalidArgument( + "Mixed lower-case and upper-case special tokens are found in the " + "vocab."); + } + std::vector expected_tokens = lower_tokens; + std::vector unexpected_tokens = upper_tokens; + if (!has_lower_tokens && has_upper_tokens) { use_upper_token_symbols_ = true; expected_tokens.swap(unexpected_tokens); } - sos_id_ = token_to_id_[sos_token()]; for (const auto& token : expected_tokens) { - if (token_to_id_.find(token) == token_to_id_.end()) { + if (!token_exists(token)) { return errors::InvalidArgument(token, " is not found in the vocab."); } } for (const auto& token : unexpected_tokens) { - if (token_to_id_.find(token) != token_to_id_.end()) { + if (token_exists(token)) { return errors::InvalidArgument("Invalid token ", token, " is found in the vocab."); } } + for (size_t i = 0; i < expected_tokens.size(); ++i) { + const int32 id = token_id(expected_tokens[i]); + for (size_t j = 0; j < i; ++j) { + const int32 previous_id = token_id(expected_tokens[j]); + if (id == previous_id) { + return errors::InvalidArgument( + "Special tokens ", expected_tokens[j], " and ", expected_tokens[i], + " must have different ids, but both are ", id, "."); + } + } + } unk_id_ = -1; - sos_id_ = TokenToId(sos_token()); - eos_id_ = TokenToId(eos_token()); + sos_id_ = token_id(sos_token()); + eos_id_ = token_id(eos_token()); sow_id_ = TokenToId(sow_token()); eow_id_ = TokenToId(eow_token()); - unk_id_ = TokenToId(unk_token()); + unk_id_ = token_id(unk_token()); return Status(); } diff --git a/lingvo/core/ops/simple_vocab_test.py b/lingvo/core/ops/simple_vocab_test.py index b9857dc91..13908db20 100644 --- a/lingvo/core/ops/simple_vocab_test.py +++ b/lingvo/core/ops/simple_vocab_test.py @@ -14,12 +14,25 @@ # ============================================================================== """Tests for simple_vocab.""" +from lingvo import compat as tf from lingvo.core import ops from lingvo.core import test_utils class VocabOpsTest(test_utils.TestCase): + def _AssertInvalidVocab(self, + vocab, + error_message, + load_token_ids_from_vocab=False): + with self.session(use_gpu=False): + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + error_message): + ops.vocab_token_to_id( + '', + vocab=vocab, + load_token_ids_from_vocab=load_token_ids_from_vocab).eval() + def testVocabTokenToId(self): with self.session(use_gpu=False): vocab = [ @@ -156,6 +169,49 @@ def testTokenInVocab(self): ops.token_in_vocab(['b c d e', '♣'], vocab=vocab).eval().all()) self.assertFalse(ops.token_in_vocab('unknown', vocab=vocab).eval()) + def testVocabTokenToIdRejectsMissingLowercaseSos(self): + vocab = [ + '', + '', + '', + 'a', + ] + self._AssertInvalidVocab(vocab, ' is not found in the vocab.') + + def testVocabTokenToIdRejectsMissingUppercaseSos(self): + vocab = [ + '', + '', + '', + 'a', + ] + self._AssertInvalidVocab(vocab, ' is not found in the vocab.') + + def testVocabTokenToIdRejectsMixedSpecialTokenCasing(self): + vocab = [ + '', + '', + '', + '', + '', + 'a', + ] + self._AssertInvalidVocab( + vocab, 'Mixed lower-case and upper-case special tokens') + + def testVocabTokenToIdRejectsDuplicateSpecialTokenIds(self): + vocab = [ + ' 3', + ' 3', + ' 7', + ' 9', + 'a 2', + ] + self._AssertInvalidVocab( + vocab, + 'Special tokens and must have different ids', + load_token_ids_from_vocab=True) + if __name__ == '__main__': test_utils.main()