Skip to content

Commit b8cab7f

Browse files
authored
Fix kwargs in crf_decode_forward (#2642)
* Address issue #2639
1 parent 3f2f7a1 commit b8cab7f

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

tensorflow_addons/text/crf.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,11 @@ def crf_decode_forward(
490490
mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
491491
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params, dtype=inputs.dtype)
492492
crf_fwd_layer = tf.keras.layers.RNN(
493-
crf_fwd_cell, return_sequences=True, return_state=True, dtype=inputs.dtype
493+
crf_fwd_cell,
494+
return_sequences=True,
495+
return_state=True,
496+
dtype=inputs.dtype,
497+
zero_output_for_mask=True, # See: https://github.com/tensorflow/addons/issues/2639
494498
)
495499
return crf_fwd_layer(inputs, state, mask=mask)
496500

tensorflow_addons/text/tests/crf_test.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from tensorflow_addons import text
2525
from tensorflow_addons.utils import test_utils
26+
from numpy.testing import assert_array_equal
2627

2728

2829
def calculate_sequence_score(inputs, transition_params, tag_indices, sequence_lengths):
@@ -559,3 +560,60 @@ def test_crf_decode_save_load(tmpdir):
559560
"seq_len": np.array([10]),
560561
}
561562
)
563+
564+
565+
@pytest.mark.parametrize(
566+
"potentials,sequence_length",
567+
[
568+
# performs masking
569+
pytest.param(
570+
tf.random.normal([2, 12, 3]),
571+
tf.constant([8, 10]),
572+
),
573+
# does not perform masking
574+
pytest.param(
575+
tf.random.normal([4, 8, 10]),
576+
tf.constant([8, 8, 8, 8]),
577+
),
578+
],
579+
)
580+
def test_crf_decode_forward_mask(potentials, sequence_length):
581+
# mimics setup of the `_multi_seq_fn` closure in `crf_decode`
582+
initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1])
583+
initial_state = tf.squeeze(initial_state, axis=[1])
584+
inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1])
585+
586+
sequence_length_less_one = tf.maximum(
587+
tf.constant(0, dtype=tf.int32), sequence_length - 1
588+
)
589+
590+
n_tags = potentials.shape[-1]
591+
transition_params = tf.random.normal([n_tags, n_tags])
592+
593+
backpointers, _ = text.crf_decode_forward(
594+
inputs, initial_state, transition_params, sequence_length_less_one
595+
)
596+
597+
# everything masked by `sequence_length_less_one` should be equal to 0.
598+
mask = tf.sequence_mask(sequence_length_less_one, tf.shape(inputs)[1])
599+
600+
# the indices that _should_ have been masked in the RNN operation
601+
masked_indices = tf.cast(tf.logical_not(mask), tf.int32)
602+
603+
# sum of each row in the mask should equal timedim - seq lens
604+
exp_mask_sums = (
605+
tf.repeat(inputs.shape[1], inputs.shape[0]) - sequence_length_less_one
606+
)
607+
mask_sums = tf.reduce_sum(masked_indices, axis=1)
608+
assert_array_equal(
609+
exp_mask_sums.numpy(),
610+
mask_sums.numpy(),
611+
)
612+
613+
# now apply the inverse mask to the backpointers and show that ALL are zeros. this is proof that
614+
# we appropriately masked timesteps
615+
masked_indices = tf.expand_dims(masked_indices, [2])
616+
zeros = masked_indices * backpointers
617+
assert tf.reduce_all(zeros == 0).numpy(), "Mask not applied correctly: {0}".format(
618+
zeros
619+
)

0 commit comments

Comments
 (0)