Skip to content

Commit d63556e

Browse files
authored
Bugfix in the unit test for GPT2Decoder (#288)
1 parent d9aab89 commit d63556e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

texar/torch/modules/decoders/gpt2_decoder_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,9 @@ def test_beam_search(self):
191191
self.assertEqual(
192192
outputs['log_prob'].shape,
193193
torch.Size([self.batch_size, self.beam_width]))
194-
self.assertEqual(
195-
outputs['sample_id'].shape,
196-
torch.Size([self.batch_size, self.max_length, self.beam_width]))
194+
self.assertEqual(outputs['sample_id'].shape[0], self.batch_size)
195+
self.assertLessEqual(outputs['sample_id'].shape[1], self.max_length)
196+
self.assertEqual(outputs['sample_id'].shape[2], self.beam_width)
197197

198198
def test_greedy_embedding_helper(self):
199199
r"""Tests with tf.contrib.seq2seq.GreedyEmbeddingHelper

0 commit comments

Comments
 (0)