diff --git a/lingvo/core/ops/BUILD b/lingvo/core/ops/BUILD index afa48f2ce..73c2571bc 100644 --- a/lingvo/core/ops/BUILD +++ b/lingvo/core/ops/BUILD @@ -107,6 +107,14 @@ lingvo_cc_library( hdrs = ["simple_vocab.h"], ) +lingvo_cc_test( + name = "simple_vocab_cc_test", + srcs = ["simple_vocab_cc_test.cc"], + deps = [ + ":simple_vocab", + ], +) + custom_kernel_library( name = "ml_perf_subword_op", srcs = ["ml_perf_subword_op.cc"], diff --git a/lingvo/core/ops/simple_vocab.cc b/lingvo/core/ops/simple_vocab.cc index 48b0e2cb3..e54d2f1c5 100644 --- a/lingvo/core/ops/simple_vocab.cc +++ b/lingvo/core/ops/simple_vocab.cc @@ -15,6 +15,8 @@ limitations under the License. #include "lingvo/core/ops/simple_vocab.h" +#include /* NOTE(review): the header name is missing from this text; the added code calls std::max, which the standard algorithm header declares -- confirm against the real patch. */ + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -80,6 +82,7 @@ Status Vocab::Load(const string& vocab_glob, bool load_token_ids) { Status Vocab::Load(const std::vector& lines, bool load_token_ids) { id_to_token_.clear(); token_to_id_.clear(); + max_token_id_ = -1; /* Reset alongside the two maps cleared above so every Load recomputes the maximum from scratch. */ int32 next_id = 0; for (StringPiece line : lines) { if (line.empty()) continue; @@ -89,12 +92,20 @@ Status Vocab::Load(const std::vector& lines, bool load_token_ids) { if (!load_token_ids) { token_to_id_[tok] = next_id; id_to_token_[next_id] = tok; + max_token_id_ = std::max(max_token_id_, next_id); /* Record the sequential id that was just assigned. */ next_id++; } else { CHECK_GE(parts.size(), 2); const int32 id = std::stoi(parts[1]); + if (id < 0 && str_util::StartsWith(tok, kBowStr)) { /* A token starting with kBowStr must not carry a negative id: Load returns InvalidArgument before anything is stored for it. */ + return errors::InvalidArgument("BOW token ", tok, + " has negative id ", id, "."); + } token_to_id_[tok] = id; id_to_token_[id] = tok; + if (id >= 0) { /* Negative ids were stored in both maps just above but never raise max_token_id_. */ + max_token_id_ = std::max(max_token_id_, id); + } } VLOG(2) << "Vocab " << token_to_id_[tok] << " " << tok; } diff --git a/lingvo/core/ops/simple_vocab.h b/lingvo/core/ops/simple_vocab.h index 41bc8f6f3..ae3fa15bc 100644 --- a/lingvo/core/ops/simple_vocab.h +++ b/lingvo/core/ops/simple_vocab.h @@ -17,6 +17,8 @@ 
limitations under the License. #define LINGVO_CORE_OPS_SIMPLE_VOCAB_H_ // TODO(zhifengc): Add comments for this class. +#include +#include /* NOTE(review): both added header names are missing from this text; the code below uses size_t and strlen -- confirm the intended headers. */ #include #include #include @@ -110,10 +112,12 @@ class Vocab { } std::vector GetBowTokenIds() const { - std::vector is_bow_token_id(id_to_token_.size(), false); + const size_t num_token_ids = + max_token_id_ < 0 ? 0 : static_cast(max_token_id_) + 1; /* Table length is max_token_id_ + 1 (0 when max_token_id_ is negative), replacing the removed sizing by id_to_token_.size(). NOTE(review): the static_cast target type is missing from this text. */ + std::vector is_bow_token_id(num_token_ids, false); static const int bowStrLen = strlen(kBowStr); for (auto const& kv : id_to_token_) { - if (kv.second.substr(0, bowStrLen) == kBowStr) { + if (kv.first >= 0 && kv.second.substr(0, bowStrLen) == kBowStr) { /* Negative ids are skipped because kv.first is used as the table index on the next line. */ is_bow_token_id[kv.first] = true; } } @@ -126,6 +130,7 @@ class Vocab { int32 unk_id_ = -1; int32 sow_id_ = -1; int32 eow_id_ = -1; + int32 max_token_id_ = -1; /* Starts at -1; GetBowTokenIds() treats a negative value as an empty table. */ bool use_upper_token_symbols_ = false; std::unordered_map id_to_token_; std::unordered_map token_to_id_; diff --git a/lingvo/core/ops/simple_vocab_cc_test.cc b/lingvo/core/ops/simple_vocab_cc_test.cc new file mode 100644 index 000000000..3319dec00 --- /dev/null +++ b/lingvo/core/ops/simple_vocab_cc_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "lingvo/core/ops/simple_vocab.h" +/* NOTE(review): the header names of the three #include directives below are missing from this text -- confirm against the real file. */ +#include +#include + +#include + +namespace tensorflow { +namespace lingvo { +namespace { +/* Loads explicit ids 3, 5, 7, -1 and 100 (only the id-100 token is prefixed with kBowStr) and expects GetBowTokenIds() to return 101 entries with entry 100 set and entries 3 and 7 unset. NOTE(review): the first three entries show an empty token before the tab in this text -- confirm against the real test. */ +TEST(SimpleVocabTest, GetBowTokenIdsHandlesSparseIds) { + Vocab vocab; + ASSERT_TRUE(vocab + .Load({"\t3", "\t5", "\t7", "plain\t-1", + std::string(kBowStr) + "hello\t100"}, + true) + .ok()); + + const std::vector bow_token_ids = vocab.GetBowTokenIds(); + + ASSERT_EQ(101, bow_token_ids.size()); + EXPECT_TRUE(bow_token_ids[100]); + EXPECT_FALSE(bow_token_ids[3]); + EXPECT_FALSE(bow_token_ids[7]); +} +/* Load() must return a non-OK status when a kBowStr-prefixed token is given the id -1. */ +TEST(SimpleVocabTest, LoadRejectsNegativeBowTokenIds) { + Vocab vocab; + EXPECT_FALSE(vocab + .Load({"\t3", "\t5", "\t7", + std::string(kBowStr) + "hello\t-1"}, + true) + .ok()); +} + +} // namespace +} // namespace lingvo +} // namespace tensorflow