
Commit 8ae8116

fsx950223 authored and seanpmorgan committed
Rrelu fix (#820)
* refactor rrelu
1 parent 7e06c2f commit 8ae8116

File tree

8 files changed: +54 -350 lines changed


tensorflow_addons/activations/BUILD

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ py_test(
     srcs = [
         "rrelu_test.py",
     ],
-    flaky = True,
+    args = ["--benchmarks=all"],
     main = "rrelu_test.py",
     deps = [
         ":activations",

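With the compiled kernel gone, the test target no longer needs `flaky = True` (a workaround for the nondeterministic custom op); instead it passes `--benchmarks=all` to the test binary, which makes `tf.test.main()` run `tf.test.Benchmark` subclasses such as the new `RreluBenchmarks` below. A sketch of the resulting rule, assuming the attributes outside the hunk (e.g. `name`) follow the usual pattern in this BUILD file:

    py_test(
        name = "rrelu_test",  # assumed; not shown in the hunk
        srcs = [
            "rrelu_test.py",
        ],
        args = ["--benchmarks=all"],  # forwarded to the test binary at runtime
        main = "rrelu_test.py",
        deps = [
            ":activations",
        ],
    )
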
tensorflow_addons/activations/rrelu.py

Lines changed: 6 additions & 14 deletions
@@ -18,14 +18,9 @@
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow_addons.utils.resource_loader import get_path_to_datafile
-
-_activation_ops_so = tf.load_op_library(
-    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
 
 
 @tf.keras.utils.register_keras_serializable(package='Addons')
-@tf.function
 def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
     """rrelu function.
 
@@ -51,14 +46,11 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
     if training is None:
         training = tf.keras.backend.learning_phase()
     training = bool(tf.keras.backend.get_value(training))
-    # TODO: get rid of v1 API
-    seed1, seed2 = tf.compat.v1.random.get_seed(seed)
-    result, _ = _activation_ops_so.addons_rrelu(x, lower, upper, training,
-                                                seed1, seed2)
-    return result
 
+    if training:
+        alpha = tf.random.uniform(
+            tf.shape(x), minval=lower, maxval=upper, dtype=x.dtype, seed=seed)
+    else:
+        alpha = tf.cast((lower + upper) / 2, x.dtype)
 
-@tf.RegisterGradient("Addons>Rrelu")
-def _rrelu_grad(op, *grad):
-    return _activation_ops_so.addons_rrelu_grad(grad[0], op.inputs[0],
-                                                op.outputs[1])
+    return tf.where(x >= 0, x, alpha * x)

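The refactor replaces the compiled `Addons>Rrelu` op and its hand-registered gradient with plain TensorFlow ops, so the gradient now comes from automatic differentiation. RReLU computes f(x) = x for x >= 0 and f(x) = alpha * x otherwise, with alpha drawn from Uniform(lower, upper) during training and fixed at the midpoint (lower + upper) / 2 at inference. A minimal usage sketch of the new implementation (training output varies with the seed):

    import tensorflow as tf
    from tensorflow_addons.activations import rrelu

    x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])

    # Inference: negatives are scaled by the fixed midpoint slope,
    # (0.125 + 0.3333...) / 2 ~ 0.229 with the default bounds.
    y_eval = rrelu(x, training=False)

    # Training: each element draws its own alpha ~ Uniform(lower, upper).
    y_train = rrelu(x, lower=0.1, upper=0.2, training=True, seed=42)

Because the forward pass is now ordinary TensorFlow, `tf.where(x >= 0, x, alpha * x)` is differentiable end to end and the deleted `@tf.RegisterGradient` hook becomes unnecessary.
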
tensorflow_addons/activations/rrelu_test.py

Lines changed: 47 additions & 33 deletions
@@ -24,54 +24,68 @@
 from tensorflow_addons.activations import rrelu
 from tensorflow_addons.utils import test_utils
 
-
-def _ref_rrelu(x, lower, upper):
-    return tf.where(x >= 0, x, (lower + upper) * x / 2)
+SEED = 111111
 
 
 @test_utils.run_all_in_graph_and_eager_modes
 class RreluTest(tf.test.TestCase, parameterized.TestCase):
     @parameterized.named_parameters(("float16", np.float16),
                                     ("float32", np.float32),
                                     ("float64", np.float64))
-    @tf.function
     def test_rrelu(self, dtype):
         x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
         lower = 0.1
         upper = 0.2
-        result = rrelu(x, lower, upper, training=False)
-        expect_result = _ref_rrelu(x, lower, upper)
-        self.assertAllCloseAccordingToType(result, expect_result)
+
+        training_results = {
+            np.float16: [-0.288330078, -0.124206543, 0, 1, 2],
+            np.float32: [-0.26851666, -0.116421416, 0, 1, 2],
+            np.float64: [-0.3481333923206531, -0.17150176242558851, 0, 1, 2],
+        }
+        for training in [True, False]:
+            with self.subTest(training=training):
+                tf.random.set_seed(SEED)
+                result = rrelu(x, lower, upper, training=training, seed=SEED)
+                if training:
+                    expect_result = training_results.get(dtype)
+                else:
+                    expect_result = [
+                        -0.30000001192092896, -0.15000000596046448, 0, 1, 2
+                    ]
+                self.assertAllCloseAccordingToType(result, expect_result)
 
     @parameterized.named_parameters(("float32", np.float32),
                                     ("float64", np.float64))
     def test_theoretical_gradients(self, dtype):
-        x = tf.constant([-2.0, -1.0, -0.1, 0.1, 1.0, 2.0], dtype=dtype)
-        lower = 0.1
-        upper = 0.2
-        for training in [True, False]:
-            with self.subTest(training=training):
-                theoretical, numerical = tf.test.compute_gradient(
-                    lambda x: rrelu(
-                        x, lower, upper, training=training, seed=111111), [x])
-                # TODO: investigate the difference between CPU and GPU
-                if training is True and tf.test.is_gpu_available() is False:
-                    numerical = [[[0.134971, 0., 0., 0., 0., 0.],
-                                  [0., 0.15648358, 0., 0., 0., 0.],
-                                  [0., 0., 0.18776372, 0., 0., 0.],
-                                  [0., 0., 0., 1., 0., 0.],
-                                  [0., 0., 0., 0., 1., 0.],
-                                  [0., 0., 0., 0., 0., 1.]]]
-                self.assertAllCloseAccordingToType(
-                    theoretical, numerical, rtol=5e-4, atol=5e-4)
-
-    def test_unknown_shape(self):
-        fn = rrelu.get_concrete_function(
-            tf.TensorSpec(shape=None, dtype=tf.float32))
-
-        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
-            x = tf.ones(shape=shape, dtype=tf.float32)
-            self.assertAllClose(fn(x), rrelu(x))
+        if tf.executing_eagerly():
+
+            def rrelu_wrapper(lower, upper, training):
+                def inner(x):
+                    tf.random.set_seed(SEED)
+                    return rrelu(x, lower, upper, training=training, seed=SEED)
+
+                return inner
+
+            x = tf.constant([-2.0, -1.0, -0.1, 0.1, 1.0, 2.0], dtype=dtype)
+            lower = 0.1
+            upper = 0.2
+
+            for training in [True, False]:
+                with self.subTest(training=training):
+                    theoretical, numerical = tf.test.compute_gradient(
+                        rrelu_wrapper(lower, upper, training), [x])
+                    self.assertAllCloseAccordingToType(
+                        theoretical, numerical, rtol=5e-4, atol=5e-4)
+
+
+class RreluBenchmarks(tf.test.Benchmark):
+    def benchmarkRreluOp(self):
+        with tf.compat.v1.Session(config=tf.test.benchmark_config()) as sess:
+            x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32)
+            lower = 0.1
+            upper = 0.2
+            result = rrelu(x, lower, upper, training=True)
+            self.run_op_benchmark(sess, result.op, min_iters=25)
 
 
 if __name__ == "__main__":

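The odd-looking literals in the inference branch of `test_rrelu` are just float64 renderings of float32 midpoint arithmetic: with lower = 0.1 and upper = 0.2 the fixed slope is 0.15, so -2.0 maps to -0.3 and -1.0 to -0.15. A quick NumPy check of that arithmetic (illustration only, not part of the test):

    import numpy as np

    lower, upper = 0.1, 0.2
    mid = np.float32((lower + upper) / 2)  # stored as 0.15000000596046448
    x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32)
    expected = np.where(x >= 0, x, mid * x)
    # float(expected[0]) == -0.30000001192092896 and
    # float(expected[1]) == -0.15000000596046448,
    # exactly the literals in expect_result above.

When the Bazel target passes `--benchmarks=all` (see the BUILD change above), `tf.test.main()` runs `RreluBenchmarks.benchmarkRreluOp`, which times the op in a v1 Session via `run_op_benchmark`.
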
tensorflow_addons/custom_ops/activations/BUILD

Lines changed: 0 additions & 5 deletions
@@ -15,8 +15,6 @@ custom_op_library(
         "cc/kernels/lisht_op.h",
         "cc/kernels/mish_op.cc",
         "cc/kernels/mish_op.h",
-        "cc/kernels/rrelu_op.cc",
-        "cc/kernels/rrelu_op.h",
         "cc/kernels/softshrink_op.cc",
         "cc/kernels/softshrink_op.h",
         "cc/kernels/tanhshrink_op.cc",
@@ -25,7 +23,6 @@ custom_op_library(
         "cc/ops/hardshrink_op.cc",
         "cc/ops/lisht_op.cc",
         "cc/ops/mish_op.cc",
-        "cc/ops/rrelu_op.cc",
         "cc/ops/softshrink_op.cc",
         "cc/ops/tanhshrink_op.cc",
     ],
@@ -38,8 +35,6 @@ custom_op_library(
         "cc/kernels/lisht_op_gpu.cu.cc",
         "cc/kernels/mish_op.h",
         "cc/kernels/mish_op_gpu.cu.cc",
-        "cc/kernels/rrelu_op.h",
-        "cc/kernels/rrelu_op_gpu.cu.cc",
         "cc/kernels/softshrink_op.h",
         "cc/kernels/softshrink_op_gpu.cu.cc",
         "cc/kernels/tanhshrink_op.h",

tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.cc

Lines changed: 0 additions & 78 deletions
This file was deleted.

tensorflow_addons/custom_ops/activations/cc/kernels/rrelu_op.h

Lines changed: 0 additions & 138 deletions
This file was deleted.
