Skip to content

Commit 4f88ac5

Browse files
seanpmorgan and WindQAQ
authored and committed
Remove tf.function where unnecessary (#823)
* Remove decorators wrapping simple custom op calls
* Remove decorator from functions that are likely to be re-traced
* Checkpoint
* Revert
* Modify test cases
* Add unknown test
* Lint
* Lint
1 parent d541fbb commit 4f88ac5

22 files changed

+61
-95
lines changed

tensorflow_addons/activations/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#### Standard API
2929
In order to conform with the current API standard, all activations
3030
must:
31-
* Be a `tf.function`.
31+
* Be a `tf.function` unless it is a straightforward call to a custom op or likely to be retraced.
3232
* Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')`
3333
* Add the addon to the `py_library` in this sub-package's BUILD file.
3434

tensorflow_addons/activations/gelu.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626

2727
@tf.keras.utils.register_keras_serializable(package='Addons')
28-
@tf.function
2928
def gelu(x, approximate=True):
3029
"""Gaussian Error Linear Unit.
3130

tensorflow_addons/activations/gelu_test.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,6 @@ def test_theoretical_gradients(self, dtype):
5454
self.assertAllCloseAccordingToType(
5555
theoretical, numerical, atol=1e-4)
5656

57-
def test_unknown_shape(self):
58-
fn = gelu.get_concrete_function(
59-
tf.TensorSpec(shape=None, dtype=tf.float32))
60-
61-
for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
62-
x = tf.ones(shape=shape, dtype=tf.float32)
63-
self.assertAllClose(fn(x), gelu(x))
64-
6557

6658
if __name__ == "__main__":
6759
tf.test.main()

tensorflow_addons/activations/hardshrink.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626

2727
@tf.keras.utils.register_keras_serializable(package='Addons')
28-
@tf.function
2928
def hardshrink(x, lower=-0.5, upper=0.5):
3029
"""Hard shrink function.
3130

tensorflow_addons/activations/hardshrink_test.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,6 @@ def test_theoretical_gradients(self, dtype):
5858
theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
5959
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
6060

61-
def test_unknown_shape(self):
62-
fn = hardshrink.get_concrete_function(
63-
tf.TensorSpec(shape=None, dtype=tf.float32))
64-
65-
for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
66-
x = tf.ones(shape=shape, dtype=tf.float32)
67-
self.assertAllClose(fn(x), hardshrink(x))
68-
6961

7062
if __name__ == "__main__":
7163
tf.test.main()

tensorflow_addons/activations/lisht.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626

2727
@tf.keras.utils.register_keras_serializable(package='Addons')
28-
@tf.function
2928
def lisht(x):
3029
"""LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function.
3130

tensorflow_addons/activations/lisht_test.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,6 @@ def test_theoretical_gradients(self, dtype):
4747
self.assertAllCloseAccordingToType(
4848
theoretical, numerical, rtol=5e-4, atol=5e-4)
4949

50-
def test_unknown_shape(self):
51-
fn = lisht.get_concrete_function(
52-
tf.TensorSpec(shape=None, dtype=tf.float32))
53-
54-
for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
55-
x = tf.ones(shape=shape, dtype=tf.float32)
56-
self.assertAllClose(fn(x), lisht(x))
57-
5850

5951
if __name__ == "__main__":
6052
tf.test.main()

tensorflow_addons/activations/mish.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626

2727
@tf.keras.utils.register_keras_serializable(package='Addons')
28-
@tf.function
2928
def mish(x):
3029
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function.
3130

tensorflow_addons/activations/mish_test.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,6 @@ def test_theoretical_gradients(self, dtype):
4646
theoretical, numerical = tf.test.compute_gradient(mish, [x])
4747
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
4848

49-
def test_unknown_shape(self):
50-
fn = mish.get_concrete_function(
51-
tf.TensorSpec(shape=None, dtype=tf.float32))
52-
53-
for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
54-
x = tf.ones(shape=shape, dtype=tf.float32)
55-
self.assertAllClose(fn(x), mish(x))
56-
5749

5850
if __name__ == "__main__":
5951
tf.test.main()

tensorflow_addons/activations/softshrink.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626

2727
@tf.keras.utils.register_keras_serializable(package='Addons')
28-
@tf.function
2928
def softshrink(x, lower=-0.5, upper=0.5):
3029
"""Soft shrink function.
3130

0 commit comments

Comments
 (0)