We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d9aab89 commit d63556eCopy full SHA for d63556e
texar/torch/modules/decoders/gpt2_decoder_test.py
@@ -191,9 +191,9 @@ def test_beam_search(self):
191
self.assertEqual(
192
outputs['log_prob'].shape,
193
torch.Size([self.batch_size, self.beam_width]))
194
- self.assertEqual(
195
- outputs['sample_id'].shape,
196
- torch.Size([self.batch_size, self.max_length, self.beam_width]))
+ self.assertEqual(outputs['sample_id'].shape[0], self.batch_size)
+ self.assertLessEqual(outputs['sample_id'].shape[1], self.max_length)
+ self.assertEqual(outputs['sample_id'].shape[2], self.beam_width)
197
198
def test_greedy_embedding_helper(self):
199
r"""Tests with tf.contrib.seq2seq.GreedyEmbeddingHelper
0 commit comments