@@ -512,6 +512,10 @@ def call(self, potentials, transition_params, sequence_length):
512512 reason = "CRF Decode doesn't work in TF2.4, the issue was fixed in TF core, but didn't make the release" ,
513513)
514514def test_crf_decode_save_load (tmpdir ):
515+ class DummyLoss (tf .keras .losses .Loss ):
516+ def call (self , y_true , y_pred ):
517+ return tf .zeros (shape = ())
518+
515519 tf .keras .backend .clear_session ()
516520 input_tensor = tf .keras .Input (shape = (10 , 3 ), dtype = tf .float32 , name = "input_tensor" )
517521 seq_len = tf .keras .Input (shape = (), dtype = tf .int32 , name = "seq_len" )
@@ -523,7 +527,7 @@ def test_crf_decode_save_load(tmpdir):
523527 model = tf .keras .Model (
524528 inputs = [input_tensor , seq_len ], outputs = [output , decoded ], name = "example_model"
525529 )
526- model .compile (optimizer = "Adam" )
530+ model .compile (optimizer = "Adam" , loss = DummyLoss () )
527531
528532 x_data = {
529533 "input_tensor" : np .random .random_sample ((5 , 10 , 3 )).astype (dtype = np .float32 ),
@@ -551,7 +555,10 @@ def test_crf_decode_save_load(tmpdir):
551555 tf .keras .backend .clear_session ()
552556 model = tf .keras .models .load_model (
553557 temp_dir ,
554- custom_objects = {"CrfDecodeForwardRnnCell" : text .crf .CrfDecodeForwardRnnCell },
558+ custom_objects = {
559+ "CrfDecodeForwardRnnCell" : text .crf .CrfDecodeForwardRnnCell ,
560+ "DummyLoss" : DummyLoss ,
561+ },
555562 )
556563 model .fit (x_data , y_data )
557564 model .predict (
0 commit comments