Commit 6d63531

DOC: LazyAdamOptimizer example (#66)
* DOC: optimizer example
1 parent c295523 commit 6d63531

File tree

2 files changed: +92 -30 lines changed


tensorflow_addons/examples/demo.py

Lines changed: 0 additions & 30 deletions
This file was deleted.
Lines changed: 92 additions & 0 deletions (new file)
@@ -0,0 +1,92 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST example utilizing an optimizer from TensorFlow Addons."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_addons as tfa

VALIDATION_SAMPLES = 10000


def build_mnist_model():
    """Build a simple dense network for processing MNIST data.

    :return: Keras `Model`
    """
    inputs = tf.keras.Input(shape=(784,), name='digits')
    net = tf.keras.layers.Dense(64, activation='relu', name='dense_1')(inputs)
    net = tf.keras.layers.Dense(64, activation='relu', name='dense_2')(net)
    net = tf.keras.layers.Dense(
        10, activation='softmax', name='predictions')(net)

    return tf.keras.Model(inputs=inputs, outputs=net)


def generate_data(num_validation):
    """Download and preprocess the MNIST dataset.

    :param num_validation: Number of samples to use in the validation set
    :return: Dictionary of data split into train/test/val
    """
    dataset = {}

    # Load MNIST dataset as NumPy arrays
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Preprocess the data: flatten images and scale pixel values to [0, 1]
    x_train = x_train.reshape(-1, 784).astype('float32') / 255
    x_test = x_test.reshape(-1, 784).astype('float32') / 255

    # Hold out the last `num_validation` training samples as a validation set
    dataset['x_train'] = x_train[:-num_validation]
    dataset['y_train'] = y_train[:-num_validation]
    dataset['x_val'] = x_train[-num_validation:]
    dataset['y_val'] = y_train[-num_validation:]

    dataset['x_test'] = x_test
    dataset['y_test'] = y_test

    return dataset


def train_and_eval():
    """Train and evaluate a simple MNIST model using LazyAdamOptimizer."""
    data = generate_data(num_validation=VALIDATION_SAMPLES)
    dense_net = build_mnist_model()
    dense_net.compile(
        optimizer=tfa.optimizers.LazyAdamOptimizer(0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'])

    # Train the network
    history = dense_net.fit(
        data['x_train'],
        data['y_train'],
        batch_size=64,
        epochs=10,
        validation_data=(data['x_val'], data['y_val']))

    # Evaluate the network
    print('Evaluate on test data:')
    results = dense_net.evaluate(
        data['x_test'], data['y_test'], batch_size=128)
    print('Test loss, Test acc:', results)


if __name__ == "__main__":
    train_and_eval()
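
The demo above trains a fully dense network, where LazyAdamOptimizer behaves much like the standard Adam optimizer. Its intended use case is models with sparse gradient updates, such as embedding lookups, where it only updates the optimizer slots for the rows actually used in a batch. Below is a minimal sketch of that case; the model, layer sizes, and data names are illustrative assumptions and are not part of this commit.

# Illustrative sketch (not part of this commit): LazyAdamOptimizer with an
# embedding model, where gradients are sparse and only the embedding rows
# looked up in a batch receive slot updates.
import tensorflow as tf
import tensorflow_addons as tfa

embedding_model = tf.keras.Sequential([
    # Vocabulary and embedding sizes are arbitrary example values.
    tf.keras.layers.Embedding(input_dim=10000, output_dim=32),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
embedding_model.compile(
    optimizer=tfa.optimizers.LazyAdamOptimizer(0.001),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['accuracy'])
# embedding_model.fit(padded_token_ids, labels, batch_size=64, epochs=10)
# (padded_token_ids and labels are hypothetical integer-encoded inputs.)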
