|
23 | 23 |
|
24 | 24 | from tensorflow_addons import text |
25 | 25 | from tensorflow_addons.utils import test_utils |
| 26 | +from numpy.testing import assert_array_equal |
26 | 27 |
|
27 | 28 |
|
28 | 29 | def calculate_sequence_score(inputs, transition_params, tag_indices, sequence_lengths): |
@@ -559,3 +560,60 @@ def test_crf_decode_save_load(tmpdir): |
559 | 560 | "seq_len": np.array([10]), |
560 | 561 | } |
561 | 562 | ) |
| 563 | + |
| 564 | + |
@pytest.mark.parametrize(
    "potentials,sequence_length",
    [
        # Inputs are plain numpy arrays / lists generated from fixed seeds so
        # that no TensorFlow tensor is created while pytest is still collecting
        # tests, and so that every run sees the same values.
        # performs masking
        pytest.param(
            np.random.RandomState(0).standard_normal((2, 12, 3)).astype(np.float32),
            [8, 10],
            id="with_masking",
        ),
        # does not perform masking
        pytest.param(
            np.random.RandomState(1).standard_normal((4, 8, 10)).astype(np.float32),
            [8, 8, 8, 8],
            id="without_masking",
        ),
    ],
)
def test_crf_decode_forward_mask(potentials, sequence_length):
    """Check that `crf_decode_forward` zeroes backpointers past each sequence end.

    Args:
        potentials: float32 array of shape (batch, time, n_tags).
        sequence_length: per-example sequence lengths, one int per batch row.
    """
    # Convert at run time rather than inside the parametrize list.
    potentials = tf.convert_to_tensor(potentials, dtype=tf.float32)
    sequence_length = tf.constant(sequence_length, dtype=tf.int32)

    # mimics setup of the `_multi_seq_fn` closure in `crf_decode`:
    # the first timestep seeds the state, the remaining ones are the inputs.
    first_step = tf.slice(potentials, [0, 0, 0], [-1, 1, -1])
    initial_state = tf.squeeze(first_step, axis=[1])
    inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1])

    # One step is consumed by the initial state; clamp at 0.
    sequence_length_less_one = tf.maximum(
        tf.constant(0, dtype=tf.int32), sequence_length - 1
    )

    n_tags = potentials.shape[-1]
    transition_params = tf.random.normal([n_tags, n_tags])

    backpointers, _ = text.crf_decode_forward(
        inputs, initial_state, transition_params, sequence_length_less_one
    )

    # everything masked by `sequence_length_less_one` should be equal to 0.
    n_steps = inputs.shape[1]
    mask = tf.sequence_mask(sequence_length_less_one, n_steps)

    # 1 at the positions that _should_ have been masked in the RNN operation
    masked_indices = tf.cast(tf.logical_not(mask), tf.int32)

    # sanity check on the mask itself: each row must flag exactly
    # (time dim - sequence length) positions. Broadcasting the scalar
    # `n_steps` gives one expected count per batch row.
    expected_masked_counts = n_steps - sequence_length_less_one
    actual_masked_counts = tf.reduce_sum(masked_indices, axis=1)
    assert_array_equal(
        actual_masked_counts.numpy(),
        expected_masked_counts.numpy(),
    )

    # now apply the inverse mask to the backpointers and show that ALL are
    # zeros. this is proof that we appropriately masked timesteps.
    # NOTE(review): relies on `backpointers` being int32 with shape
    # (batch, time - 1, n_tags) so it multiplies against the mask — confirm.
    leaked = tf.expand_dims(masked_indices, axis=-1) * backpointers
    assert tf.reduce_all(leaked == 0).numpy(), "Mask not applied correctly: {0}".format(
        leaked
    )
0 commit comments