From 60850ebfda7ce048b3a2fd35057cc452a3108d6c Mon Sep 17 00:00:00 2001 From: Sourabh Medapati Date: Mon, 28 Apr 2025 12:38:04 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 752408679 --- init2winit/mt_eval/inference.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/init2winit/mt_eval/inference.py b/init2winit/mt_eval/inference.py index 6c4003cd..9b38f837 100644 --- a/init2winit/mt_eval/inference.py +++ b/init2winit/mt_eval/inference.py @@ -93,7 +93,12 @@ def init_offline_evaluator(self, self.mt_eval_config = mt_eval_config self.dataset = dataset params_rng, dropout_rng = jax.random.split(rng, num=2) - self.encoder = self.load_tokenizer(hps.vocab_path) + if not hps.vocab_path: + self.encoder = None + self.use_test_vocab = True + else: + self.encoder = self.load_tokenizer(hps.vocab_path) + self.use_test_vocab = False self.initialize_model(model_cls, dataset_meta_data, dropout_rng, params_rng) def init_online_evaluator(self, @@ -110,7 +115,12 @@ def init_online_evaluator(self, self.eval_num_batches = mt_eval_config.get('eval_num_batches') self.mt_eval_config = mt_eval_config self.dataset = dataset - self.encoder = self.load_tokenizer(hps.vocab_path) + if not hps.vocab_path: + self.encoder = None + self.use_test_vocab = True + else: + self.encoder = self.load_tokenizer(hps.vocab_path) + self.use_test_vocab = False params_rng, dropout_rng = jax.random.split(rng, num=2) self.initialize_model(model_cls, dataset_metadata, params_rng, dropout_rng) @@ -189,8 +199,15 @@ def initialize_cache(self, inputs, max_length, params_rng, dropout_rng): return init_dict['cache'] def decode_tokens(self, toks): - valid_toks = toks[:np.argmax(toks == self.eos_id) + 1].astype(np.int32) - return self.encoder.detokenize(valid_toks).numpy().decode('utf-8') + print('DEBUG: toke shape : ', toks.shape) + print('DEBUG: toks: ', toks) + if not self.use_test_vocab: + valid_toks = toks[:np.argmax(toks == self.eos_id) + 1].astype(np.int32) + return self.encoder.detokenize(valid_toks).numpy().decode('utf-8') + else: + valid_tok = toks[0] + print('DEBUG: valid_toks: ', valid_tok) + return TEST_VOCAB[valid_tok] def current_batch_size(self, batch): # we assume first token is non-zero in each target side example. @@ -275,6 +292,9 @@ def translate_and_calculate_bleu_single_model(self, params, eval_split): len(decode_output.reference_list), len(decode_output.source_list)) if self.mt_eval_config.get('decoding_type') == 'beam_search': + print('DEBUG: ', decode_output.translation_list) + print('DEBUG: ', decode_output.reference_list) + bleu_score = eval_utils.compute_bleu_from_predictions( decode_output.translation_list, decode_output.reference_list,