Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions lingvo/core/ops/simple_vocab.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,32 +99,66 @@ Status Vocab::Load(const std::vector<string>& lines, bool load_token_ids) {
VLOG(2) << "Vocab " << token_to_id_[tok] << " " << tok;
}
use_upper_token_symbols_ = false;
std::vector<string> expected_tokens = {kSosToken, kEosToken, kUnkToken};
std::vector<string> unexpected_tokens = {kSosTokenUpper, kEosTokenUpper,
kUnkTokenUpper};
if (token_to_id_.find(sos_token()) == token_to_id_.end()) {
const std::vector<string> lower_tokens = {kSosToken, kEosToken, kUnkToken};
const std::vector<string> upper_tokens = {kSosTokenUpper, kEosTokenUpper,
kUnkTokenUpper};
const auto token_id = [this](const string& token) {
return token_to_id_.find(token)->second;
};
const auto token_exists = [this](const string& token) {
return token_to_id_.find(token) != token_to_id_.end();
};
const auto any_token_exists =
[&token_exists](const std::vector<string>& tokens) {
for (const auto& token : tokens) {
if (token_exists(token)) {
return true;
}
}
return false;
};
const bool has_lower_tokens = any_token_exists(lower_tokens);
const bool has_upper_tokens = any_token_exists(upper_tokens);
if (has_lower_tokens && has_upper_tokens) {
return errors::InvalidArgument(
"Mixed lower-case and upper-case special tokens are found in the "
"vocab.");
}
std::vector<string> expected_tokens = lower_tokens;
std::vector<string> unexpected_tokens = upper_tokens;
if (!has_lower_tokens && has_upper_tokens) {
use_upper_token_symbols_ = true;
expected_tokens.swap(unexpected_tokens);
}
sos_id_ = token_to_id_[sos_token()];

for (const auto& token : expected_tokens) {
if (token_to_id_.find(token) == token_to_id_.end()) {
if (!token_exists(token)) {
return errors::InvalidArgument(token, " is not found in the vocab.");
}
}
for (const auto& token : unexpected_tokens) {
if (token_to_id_.find(token) != token_to_id_.end()) {
if (token_exists(token)) {
return errors::InvalidArgument("Invalid token ", token,
" is found in the vocab.");
}
}
for (size_t i = 0; i < expected_tokens.size(); ++i) {
const int32 id = token_id(expected_tokens[i]);
for (size_t j = 0; j < i; ++j) {
const int32 previous_id = token_id(expected_tokens[j]);
if (id == previous_id) {
return errors::InvalidArgument(
"Special tokens ", expected_tokens[j], " and ", expected_tokens[i],
" must have different ids, but both are ", id, ".");
}
}
}
unk_id_ = -1;
sos_id_ = TokenToId(sos_token());
eos_id_ = TokenToId(eos_token());
sos_id_ = token_id(sos_token());
eos_id_ = token_id(eos_token());
sow_id_ = TokenToId(sow_token());
eow_id_ = TokenToId(eow_token());
unk_id_ = TokenToId(unk_token());
unk_id_ = token_id(unk_token());
return Status();
}

Expand Down
56 changes: 56 additions & 0 deletions lingvo/core/ops/simple_vocab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,25 @@
# ==============================================================================
"""Tests for simple_vocab."""

from lingvo import compat as tf
from lingvo.core import ops
from lingvo.core import test_utils


class VocabOpsTest(test_utils.TestCase):

def _AssertInvalidVocab(self,
vocab,
error_message,
load_token_ids_from_vocab=False):
with self.session(use_gpu=False):
with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
error_message):
ops.vocab_token_to_id(
'<S>',
vocab=vocab,
load_token_ids_from_vocab=load_token_ids_from_vocab).eval()

def testVocabTokenToId(self):
with self.session(use_gpu=False):
vocab = [
Expand Down Expand Up @@ -156,6 +169,49 @@ def testTokenInVocab(self):
ops.token_in_vocab(['b c d e', '♣'], vocab=vocab).eval().all())
self.assertFalse(ops.token_in_vocab('unknown', vocab=vocab).eval())

def testVocabTokenToIdRejectsMissingLowercaseSos(self):
vocab = [
'</s>',
'<unk>',
'<epsilon>',
'a',
]
self._AssertInvalidVocab(vocab, '<s> is not found in the vocab.')

def testVocabTokenToIdRejectsMissingUppercaseSos(self):
vocab = [
'</S>',
'<UNK>',
'<epsilon>',
'a',
]
self._AssertInvalidVocab(vocab, '<S> is not found in the vocab.')

def testVocabTokenToIdRejectsMixedSpecialTokenCasing(self):
vocab = [
'<S>',
'</S>',
'<UNK>',
'<s>',
'<epsilon>',
'a',
]
self._AssertInvalidVocab(
vocab, 'Mixed lower-case and upper-case special tokens')

def testVocabTokenToIdRejectsDuplicateSpecialTokenIds(self):
vocab = [
'<S> 3',
'</S> 3',
'<UNK> 7',
'<epsilon> 9',
'a 2',
]
self._AssertInvalidVocab(
vocab,
'Special tokens <S> and </S> must have different ids',
load_token_ids_from_vocab=True)


if __name__ == '__main__':
test_utils.main()