Skip to content

Commit 9c67102

Browse files
Squadrickseanpmorgan
authored andcommitted
Make crf TF 2.0 compatible (#433)
* Replace tf.get_variable with tf.Variable * Add test case for above
1 parent 28acaf4 commit 9c67102

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

tensorflow_addons/text/crf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,9 @@ def crf_log_likelihood(inputs,
188188

189189
# Get the transition matrix if not provided.
190190
if transition_params is None:
191-
transition_params = tf.get_variable("transitions",
192-
[num_tags, num_tags])
191+
initializer = tf.initializers.GlorotUniform()
192+
transition_params = tf.Variable(
193+
initializer([num_tags, num_tags]), "transitions")
193194

194195
sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
195196
transition_params)

tensorflow_addons/text/crf_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,12 @@ def testCrfLogLikelihood(self):
229229
tf_total_log_likelihood = self.evaluate(total_log_likelihood)
230230
self.assertAllClose(tf_total_log_likelihood, 0.0)
231231

232+
# check if `transition_params = None` raises an error
233+
text.crf_log_likelihood(
234+
inputs=tf.expand_dims(inputs, 0),
235+
tag_indices=tf.expand_dims(tag_indices, 0),
236+
sequence_lengths=tf.expand_dims(sequence_lengths, 0))
237+
232238
def testViterbiDecode(self):
233239
inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
234240
dtype=np.float32)

0 commit comments

Comments
 (0)