Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions init2winit/mt_eval/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def init_offline_evaluator(self,
self.mt_eval_config = mt_eval_config
self.dataset = dataset
params_rng, dropout_rng = jax.random.split(rng, num=2)
self.encoder = self.load_tokenizer(hps.vocab_path)
if not hps.vocab_path:
self.encoder = None
self.use_test_vocab = True
else:
self.encoder = self.load_tokenizer(hps.vocab_path)
self.use_test_vocab = False
self.initialize_model(model_cls, dataset_meta_data, dropout_rng, params_rng)

def init_online_evaluator(self,
Expand All @@ -110,7 +115,12 @@ def init_online_evaluator(self,
self.eval_num_batches = mt_eval_config.get('eval_num_batches')
self.mt_eval_config = mt_eval_config
self.dataset = dataset
self.encoder = self.load_tokenizer(hps.vocab_path)
if not hps.vocab_path:
self.encoder = None
self.use_test_vocab = True
else:
self.encoder = self.load_tokenizer(hps.vocab_path)
self.use_test_vocab = False
params_rng, dropout_rng = jax.random.split(rng, num=2)
self.initialize_model(model_cls, dataset_metadata, params_rng, dropout_rng)

Expand Down Expand Up @@ -189,8 +199,15 @@ def initialize_cache(self, inputs, max_length, params_rng, dropout_rng):
return init_dict['cache']

def decode_tokens(self, toks):
valid_toks = toks[:np.argmax(toks == self.eos_id) + 1].astype(np.int32)
return self.encoder.detokenize(valid_toks).numpy().decode('utf-8')
print('DEBUG: toke shape : ', toks.shape)
print('DEBUG: toks: ', toks)
if not self.use_test_vocab:
valid_toks = toks[:np.argmax(toks == self.eos_id) + 1].astype(np.int32)
return self.encoder.detokenize(valid_toks).numpy().decode('utf-8')
else:
valid_tok = toks[0]
print('DEBUG: valid_toks: ', valid_tok)
return TEST_VOCAB[valid_tok]

def current_batch_size(self, batch):
# we assume first token is non-zero in each target side example.
Expand Down Expand Up @@ -275,6 +292,9 @@ def translate_and_calculate_bleu_single_model(self, params, eval_split):
len(decode_output.reference_list),
len(decode_output.source_list))
if self.mt_eval_config.get('decoding_type') == 'beam_search':
print('DEBUG: ', decode_output.translation_list)
print('DEBUG: ', decode_output.reference_list)

bleu_score = eval_utils.compute_bleu_from_predictions(
decode_output.translation_list,
decode_output.reference_list,
Expand Down