Commit a4379ea

WindQAQ authored and facaiy committed

add lisht kernel (#529)

* add lisht kernel
* update README
* format code
* fix tolerance
* reorder the computation
* unify namespace
* clean up testcase
* format code
* fix typo
* fix namespace comment
* remove extra the
* change test size to small

1 parent d227cbd commit a4379ea

10 files changed, +433 -9 lines changed

tensorflow_addons/activations/BUILD

Lines changed: 18 additions & 4 deletions

@@ -8,6 +8,7 @@ py_library(
         "__init__.py",
         "gelu.py",
         "hardshrink.py",
+        "lisht.py",
         "sparsemax.py",
         "tanhshrink.py",
     ],
@@ -20,7 +21,7 @@ py_library(
 
 py_test(
     name = "sparsemax_test",
-    size = "medium",
+    size = "small",
     srcs = [
         "sparsemax_test.py",
     ],
@@ -33,7 +34,7 @@ py_test(
 
 py_test(
     name = "gelu_test",
-    size = "medium",
+    size = "small",
     srcs = [
         "gelu_test.py",
     ],
@@ -46,7 +47,7 @@ py_test(
 
 py_test(
     name = "hardshrink_test",
-    size = "medium",
+    size = "small",
     srcs = [
         "hardshrink_test.py",
     ],
@@ -57,9 +58,22 @@ py_test(
     ],
 )
 
+py_test(
+    name = "lisht_test",
+    size = "small",
+    srcs = [
+        "lisht_test.py",
+    ],
+    main = "lisht_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
+
 py_test(
     name = "tanhshrink_test",
-    size = "medium",
+    size = "small",
     srcs = [
         "tanhshrink_test.py",
     ],
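
With the new rule in place, the test can be run on its own; assuming a standard tensorflow/addons checkout, an invocation along these lines should work (the target name comes straight from the rule above):

    bazel test //tensorflow_addons/activations:lisht_test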

tensorflow_addons/activations/README.md

Lines changed: 6 additions & 5 deletions

@@ -4,25 +4,26 @@
 | Submodule | Maintainers | Contact Info |
 |:----------|:--------------------------|:-----------------------------------------|
 | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
-| hardshrink| @WindQAQ | windqaq@gmail.com
+| hardshrink| @WindQAQ | windqaq@gmail.com |
+| lisht | @WindQAQ | windqaq@gmail.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
-| tanhshrink | @fsx950223 | fsx950223@gmail.com |
+| tanhshrink| @fsx950223 | fsx950223@gmail.com |
 
 ## Contents
 | Submodule | Activation | Reference |
 |:----------|:-----------|:---------------------------------|
 | gelu | gelu | https://arxiv.org/abs/1606.08415 |
 | hardshrink| hardshrink | |
-| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
-| tanhshrink | Tanhshrink | |
+| lisht | lisht | https://arxiv.org/abs/1901.05894 |
+| sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
+| tanhshrink| tanhshrink | |
 
 
 ## Contribution Guidelines
 #### Standard API
 In order to conform with the current API standard, all activations
 must:
  * Be a `tf.function`.
-* Have the signature `fn(input, axis=-1, name=None)`.
  * [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/python/keras_utils.py)
   so it can be serialized properly.
  * Add the addon to the `py_library` in this sub-package's BUILD file.

tensorflow_addons/activations/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -20,5 +20,6 @@
 
 from tensorflow_addons.activations.gelu import gelu
 from tensorflow_addons.activations.hardshrink import hardshrink
+from tensorflow_addons.activations.lisht import lisht
 from tensorflow_addons.activations.sparsemax import sparsemax
 from tensorflow_addons.activations.tanhshrink import tanhshrink
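
Once exported here, the activation can be used anywhere Keras accepts one. A minimal usage sketch (assuming a built tensorflow-addons; the Dense example mirrors what lisht_test.py exercises below):

    import tensorflow as tf
    from tensorflow_addons.activations import lisht

    x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])
    y = lisht(x)  # x * tanh(x), elementwise
    layer = tf.keras.layers.Dense(3, activation=lisht)  # serializable via the keras registry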
tensorflow_addons/activations/lisht.py

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def lisht(x):
+    """LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function.
+
+    Computes linearly scaled hyperbolic tangent (LiSHT): `x * tanh(x)`
+
+    See [LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function for Neural Networks](https://arxiv.org/abs/1901.05894).
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+    Returns:
+        A `Tensor`. Has the same type as `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    return _activation_ops_so.addons_lisht(x)
+
+
+@tf.RegisterGradient("Addons>Lisht")
+def _lisht_grad(op, grad):
+    return _activation_ops_so.addons_lisht_grad(grad, op.inputs[0])
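
For readers without the compiled custom op, a hedged pure-TensorFlow sketch of the same forward computation (illustrative only, not part of the commit; the real op defers the backward pass to the hand-written Addons>LishtGrad kernel rather than to autodiff):

    import tensorflow as tf

    def lisht_reference(x):
        # LiSHT forward pass: elementwise x * tanh(x).
        x = tf.convert_to_tensor(x)
        return x * tf.math.tanh(x)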
tensorflow_addons/activations/lisht_test.py

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.activations import lisht
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class LishtTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_lisht(self, dtype):
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant(
+            [1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552], dtype=dtype)
+        self.assertAllCloseAccordingToType(lisht(x), expected_result)
+
+    @parameterized.named_parameters(("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_theoretical_gradients(self, dtype):
+        # Only test theoretical gradients for float32 and float64
+        # because of the instability of float16 while computing jacobian
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+
+        theoretical, numerical = tf.test.compute_gradient(lisht, [x])
+        self.assertAllCloseAccordingToType(
+            theoretical, numerical, rtol=5e-4, atol=5e-4)
+
+    def test_unknown_shape(self):
+        fn = lisht.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32))
+
+        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
+            x = tf.ones(shape=shape, dtype=tf.float32)
+            self.assertAllClose(fn(x), lisht(x))
+
+    def test_serialization(self):
+        config = tf.keras.activations.serialize(lisht)
+        fn = tf.keras.activations.deserialize(config)
+        self.assertEqual(fn, lisht)
+
+    def test_serialization_with_layers(self):
+        layer = tf.keras.layers.Dense(3, activation=lisht)
+        config = tf.keras.layers.serialize(layer)
+        deserialized_layer = tf.keras.layers.deserialize(config)
+        self.assertEqual(deserialized_layer.__class__.__name__,
+                         layer.__class__.__name__)
+        self.assertEqual(deserialized_layer.activation.__name__, "lisht")
+
+
+if __name__ == "__main__":
+    tf.test.main()
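
As a quick sanity check on the expected values in test_lisht (not part of the commit), the constants can be reproduced with plain NumPy:

    import numpy as np

    x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
    print(x * np.tanh(x))
    # approximately [1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552]

Note that LiSHT is symmetric and non-negative: lisht(-2) = -2 * tanh(-2) = 1.9280552, matching the first entry.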

tensorflow_addons/custom_ops/activations/BUILD

Lines changed: 26 additions & 0 deletions

@@ -49,6 +49,28 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "lisht_op_gpu",
+    srcs = [
+        "cc/kernels/lisht_op.h",
+        "cc/kernels/lisht_op_gpu.cu.cc",
+    ],
+    copts = if_cuda_is_configured([
+        "-DGOOGLE_CUDA=1",
+        "-x cuda",
+        "-nvcc_options=relaxed-constexpr",
+        "-nvcc_options=ftz=true",
+    ]),
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_libs",
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+    alwayslink = 1,
+)
+
 cc_library(
     name = "tanhshrink_op_gpu",
     srcs = [
@@ -78,10 +100,13 @@ cc_binary(
         "cc/kernels/gelu_op.h",
         "cc/kernels/hardshrink_op.cc",
         "cc/kernels/hardshrink_op.h",
+        "cc/kernels/lisht_op.cc",
+        "cc/kernels/lisht_op.h",
         "cc/kernels/tanhshrink_op.cc",
         "cc/kernels/tanhshrink_op.h",
         "cc/ops/gelu_op.cc",
         "cc/ops/hardshrink_op.cc",
+        "cc/ops/lisht_op.cc",
         "cc/ops/tanhshrink_op.cc",
     ],
     copts = [
@@ -96,6 +121,7 @@ cc_binary(
     ] + if_cuda_is_configured([
         ":gelu_op_gpu",
         ":hardshrink_op_gpu",
+        ":lisht_op_gpu",
         ":tanhshrink_op_gpu",
     ]),
 )
tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.cc

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+#define REGISTER_LISHT_KERNELS(type)                                         \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>Lisht").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
+      LishtOp<CPUDevice, type>);                                             \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>LishtGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      LishtGradOp<CPUDevice, type>);
+
+// Lisht only makes sense with floating points.
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_LISHT_KERNELS);
+#undef REGISTER_LISHT_KERNELS
+
+#if GOOGLE_CUDA
+
+using GPUDevice = Eigen::GpuDevice;
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                           \
+  template <>                                                         \
+  void Lisht<GPUDevice, T>::operator()(                               \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features,   \
+      typename TTypes<T>::Tensor activations);                        \
+  extern template struct Lisht<GPUDevice, T>;                         \
+                                                                      \
+  template <>                                                         \
+  void LishtGrad<GPUDevice, T>::operator()(                           \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,  \
+      typename TTypes<T>::ConstTensor features,                       \
+      typename TTypes<T>::Tensor backprops);                          \
+  extern template struct LishtGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_LISHT_GPU_KERNELS(type)                                     \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>Lisht").Device(DEVICE_GPU).TypeConstraint<type>("T"),     \
+      LishtOp<GPUDevice, type>);                                             \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>LishtGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      LishtGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_LISHT_GPU_KERNELS);
+#undef REGISTER_LISHT_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace addons
+}  // namespace tensorflow
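
The Lisht and LishtGrad functors registered above are defined in cc/kernels/lisht_op.h, which this view does not show. Based on the op's documented definition f(x) = x * tanh(x), a hedged NumPy sketch of the math the grad kernel presumably implements:

    import numpy as np

    def lisht_grad_reference(gradients, features):
        # Chain rule for f(x) = x * tanh(x):
        #   f'(x) = tanh(x) + x * (1 - tanh(x)**2)
        t = np.tanh(features)
        return gradients * (t + features * (1.0 - t * t))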
