@@ -286,3 +286,38 @@ def test_serialization():
286286
287287 new_optimizer = tf .keras .optimizers .deserialize (config )
288288 assert new_optimizer .get_config () == optimizer .get_config ()
289+
290+
291+ def test_serialization_after_training (tmpdir ):
292+ x = np .array (np .ones ([100 ]))
293+ y = np .array (np .ones ([100 ]))
294+ model = tf .keras .Sequential (
295+ [tf .keras .Input (shape = [1 ]), tf .keras .layers .Dense (1 ), tf .keras .layers .Dense (1 )]
296+ )
297+
298+ opt1 = tf .keras .optimizers .Adam (learning_rate = 1e-3 )
299+ opt2 = tf .keras .optimizers .SGD (learning_rate = 0 )
300+
301+ opt_layer_pairs = [(opt1 , model .layers [0 ]), (opt2 , model .layers [1 ])]
302+
303+ optimizer = MultiOptimizer (opt_layer_pairs )
304+
305+ # Train the model for a few epochs.
306+ model .compile (loss = "categorical_crossentropy" , optimizer = optimizer )
307+ model .fit (x , y )
308+
309+ # Verify the optimizer can still be serialized (saved).
310+ model .save (str (tmpdir ))
311+ loaded_model = tf .keras .models .load_model (str (tmpdir ))
312+ old_config = model .optimizer .get_config ()
313+ new_config = loaded_model .optimizer .get_config ()
314+ # Verify the loaded model has the same optimizer as before.
315+ assert len (old_config ["optimizer_specs" ]) == len (new_config ["optimizer_specs" ])
316+ for old_optimizer_spec , new_optimizer_spec in zip (
317+ old_config ["optimizer_specs" ], new_config ["optimizer_specs" ]
318+ ):
319+ assert old_optimizer_spec ["weights" ] == new_optimizer_spec ["weights" ]
320+ assert (
321+ old_optimizer_spec ["optimizer" ].get_config ()
322+ == new_optimizer_spec ["optimizer" ].get_config ()
323+ )
0 commit comments