Skip to content

Commit c10c8bb

Browse files
authored
Add Keras model test with TF2 model (#315)
* Add MNIST test with tf2 model * Clean up model generating script for TF2 * Update README with information on Keras and TF2 * Improve TF 2 support description
1 parent 2bd2eed commit c10c8bb

File tree

4 files changed

+104
-0
lines changed

4 files changed

+104
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ RedisAI currently supports PyTorch (libtorch), Tensorflow (libtensorflow), Tenso
130130
| 0.9.0 | 1.3.1 | 1.14.0 | 2.0.0 | 1.0.0 |
131131
| master | 1.4.0 | 1.15.0 | 2.0.0 | 1.0.0 |
132132

133+
Note: Keras and TensorFlow 2.x are supported through graph freezing. See [this script](https://github.com/RedisAI/RedisAI/blob/master/test/test_data/tf2-minimal.py) to see how to export a frozen graph from Keras and TensorFlow 2.x. Note that a frozen graph will be executed using the TensorFlow 1.15 backend. Should any 2.0 ops be not supported on the 1.15 after freezing, please open an Issue.
134+
133135
## Documentation
134136

135137
Read the docs at [redisai.io](http://redisai.io). Checkout our [showcase repo](https://github.com/RedisAI/redisai-examples) for a lot of examples written using different client libraries.

test/test_data/graph_v2.pb

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:9970a47f2818a1137c1e39e96a8c43938abced0eb44f78f5beabe0add9449539
3+
size 408865

test/test_data/tf2-minimal.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import tensorflow as tf
2+
from tensorflow import keras
3+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
4+
import numpy as np
5+
6+
# From https://github.com/leimao/Frozen_Graph_TensorFlow
7+
8+
tf.random.set_seed(seed=0)
9+
10+
model = keras.Sequential(layers=[
11+
keras.layers.InputLayer(input_shape=(28, 28), name="input"),
12+
keras.layers.Flatten(input_shape=(28, 28), name="flatten"),
13+
keras.layers.Dense(128, activation="relu", name="dense"),
14+
keras.layers.Dense(10, activation="softmax", name="output")
15+
], name="FCN")
16+
17+
model.compile(optimizer="adam",
18+
loss="sparse_categorical_crossentropy",
19+
metrics=["accuracy"])
20+
21+
full_model = tf.function(lambda x: model(x))
22+
full_model = full_model.get_concrete_function(
23+
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
24+
25+
frozen_func = convert_variables_to_constants_v2(full_model)
26+
frozen_func.graph.as_graph_def()
27+
28+
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
29+
logdir=".",
30+
name="graph_v2.pb",
31+
as_text=False)
32+

test/tests_tensorflow.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,73 @@ def test_run_tf_model(env):
242242
env.assertFalse(con2.execute_command('EXISTS', 'm'))
243243

244244

245+
@skip_if_no_TF
246+
def test_run_tf2_model(env):
247+
con = env.getConnection()
248+
249+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
250+
model_filename = os.path.join(test_data_path, 'graph_v2.pb')
251+
252+
with open(model_filename, 'rb') as f:
253+
model_pb = f.read()
254+
255+
ret = con.execute_command('AI.MODELSET', 'm', 'TF', DEVICE,
256+
'INPUTS', 'x', 'OUTPUTS', 'Identity', model_pb)
257+
env.assertEqual(ret, b'OK')
258+
259+
ensureSlaveSynced(con, env)
260+
261+
ret = con.execute_command('AI.MODELGET', 'm')
262+
env.assertEqual(len(ret), 6)
263+
env.assertEqual(ret[-1], b'')
264+
265+
ret = con.execute_command('AI.MODELSET', 'm', 'TF', DEVICE, 'TAG', 'asdf',
266+
'INPUTS', 'x', 'OUTPUTS', 'Identity', model_pb)
267+
env.assertEqual(ret, b'OK')
268+
269+
ensureSlaveSynced(con, env)
270+
271+
ret = con.execute_command('AI.MODELGET', 'm')
272+
env.assertEqual(len(ret), 6)
273+
env.assertEqual(ret[-1], b'asdf')
274+
275+
zero_values = [0] * (28 * 28)
276+
277+
con.execute_command('AI.TENSORSET', 'x', 'FLOAT',
278+
1, 1, 28, 28, 'VALUES', *zero_values)
279+
280+
ensureSlaveSynced(con, env)
281+
282+
con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'x', 'OUTPUTS', 'y')
283+
284+
ensureSlaveSynced(con, env)
285+
286+
tensor = con.execute_command('AI.TENSORGET', 'y', 'VALUES')
287+
values = tensor[-1]
288+
for value in values:
289+
env.assertAlmostEqual(float(value), 0.1, 1E-4)
290+
291+
if env.useSlaves:
292+
con2 = env.getSlaveConnection()
293+
tensor2 = con2.execute_command('AI.TENSORGET', 'y', 'VALUES')
294+
env.assertEqual(tensor2, tensor)
295+
296+
for _ in env.reloadingIterator():
297+
env.assertExists('m')
298+
env.assertExists('x')
299+
env.assertExists('y')
300+
301+
con.execute_command('AI.MODELDEL', 'm')
302+
ensureSlaveSynced(con, env)
303+
304+
env.assertFalse(env.execute_command('EXISTS', 'm'))
305+
306+
ensureSlaveSynced(con, env)
307+
if env.useSlaves:
308+
con2 = env.getSlaveConnection()
309+
env.assertFalse(con2.execute_command('EXISTS', 'm'))
310+
311+
245312
@skip_if_no_TF
246313
def test_run_tf_model_errors(env):
247314
con = env.getConnection()

0 commit comments

Comments
 (0)