
Commit 33baec0

fsx950223 authored and WindQAQ committed
Add tanhshrink (#493)
* Add tanhshrink
* Update readme
* Add name
* Add gradient
* Fix build bug
* Migrate api style
* Fix docs
* Remove name argument
* Remove name scope
1 parent 8064035 commit 33baec0

File tree

tensorflow_addons/activations/BUILD
tensorflow_addons/activations/README.md
tensorflow_addons/activations/__init__.py
tensorflow_addons/activations/tanhshrink.py
tensorflow_addons/activations/tanhshrink_test.py
tensorflow_addons/custom_ops/activations/BUILD
tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.cc
tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h
tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op_gpu.cu.cc
tensorflow_addons/custom_ops/activations/cc/ops/tanhshrink_op.cc

10 files changed: +401 -0 lines changed


tensorflow_addons/activations/BUILD

File mode changed: 100644 → 100755
Lines changed: 14 additions & 0 deletions
@@ -9,6 +9,7 @@ py_library(
         "gelu.py",
         "hardshrink.py",
         "sparsemax.py",
+        "tanhshrink.py",
     ],
     data = [
         "//tensorflow_addons/custom_ops/activations:_activation_ops.so",
@@ -55,3 +56,16 @@ py_test(
         ":activations",
     ],
 )
+
+py_test(
+    name = "tanhshrink_test",
+    size = "medium",
+    srcs = [
+        "tanhshrink_test.py",
+    ],
+    main = "tanhshrink_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
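
With this target checked in, the new test should be runnable from the repository root with `bazel test //tensorflow_addons/activations:tanhshrink_test` (the label is inferred from the package path; the commit itself does not spell out the invocation).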

tensorflow_addons/activations/README.md

Lines changed: 2 additions & 0 deletions
@@ -6,13 +6,15 @@
 | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
 | hardshrink| @WindQAQ | windqaq@gmail.com
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
+| tanhshrink | @fsx950223 | fsx950223@gmail.com |
 
 ## Contents
 | Submodule | Activation | Reference |
 |:----------|:-----------|:---------------------------------|
 | gelu | gelu | https://arxiv.org/abs/1606.08415 |
 | hardshrink| hardshrink | |
 | sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
+| tanhshrink | Tanhshrink | |
 
 
 ## Contribution Guidelines

tensorflow_addons/activations/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,3 +21,4 @@
 from tensorflow_addons.activations.gelu import gelu
 from tensorflow_addons.activations.hardshrink import hardshrink
 from tensorflow_addons.activations.sparsemax import sparsemax
+from tensorflow_addons.activations.tanhshrink import tanhshrink

tensorflow_addons/activations/tanhshrink.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def tanhshrink(x):
+    """Applies the element-wise function: x - tanh(x)
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+    Returns:
+        A `Tensor`. Has the same type as `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    return _activation_ops_so.addons_tanhshrink(x)
+
+
+@tf.RegisterGradient("Addons>Tanhshrink")
+def _tanhshrink_grad(op, grad):
+    return _activation_ops_so.addons_tanhshrink_grad(grad, op.inputs[0])
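
For reference, the forward computation wrapped by the custom op can be reproduced with stock TensorFlow ops. A minimal sketch, assuming only `tensorflow` is installed; `tanhshrink_ref` is a hypothetical name that mirrors the `_ref_tanhshrink` helper in the test file below, not part of this commit:

    import tensorflow as tf

    def tanhshrink_ref(x):
        # Same element-wise function as the custom kernel: x - tanh(x).
        x = tf.convert_to_tensor(x)
        return x - tf.tanh(x)

    # tanh saturates at +/-1, so the output approaches x - 1 for large positive x.
    print(tanhshrink_ref(tf.constant([1.0, 2.0, 3.0])))
    # approximately [0.2384, 1.0360, 2.0049]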

tensorflow_addons/activations/tanhshrink_test.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.activations import tanhshrink
+from tensorflow_addons.utils import test_utils
+
+
+def _ref_tanhshrink(x):
+    return x - tf.tanh(x)
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_tanhshrink(self, dtype):
+        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+        self.assertAllCloseAccordingToType(tanhshrink(x), _ref_tanhshrink(x))
+
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_gradients(self, dtype):
+        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+        with tf.GradientTape(persistent=True) as tape:
+            tape.watch(x)
+            y_ref = _ref_tanhshrink(x)
+            y = tanhshrink(x)
+        grad_ref = tape.gradient(y_ref, x)
+        grad = tape.gradient(y, x)
+        self.assertAllCloseAccordingToType(grad, grad_ref)
+
+    def test_serialization(self):
+        ref_fn = tanhshrink
+        config = tf.keras.activations.serialize(ref_fn)
+        fn = tf.keras.activations.deserialize(config)
+        self.assertEqual(fn, ref_fn)
+
+
+if __name__ == "__main__":
+    tf.test.main()
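
Because `tanhshrink` is registered as a Keras custom object (the behavior the serialization test above exercises), it can be passed to layers like any built-in activation. A hypothetical usage sketch, not part of this commit:

    import tensorflow as tf
    from tensorflow_addons.activations import tanhshrink

    # Hypothetical two-layer regression model with tanhshrink as the
    # hidden activation.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation=tanhshrink, input_shape=(4,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")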

tensorflow_addons/custom_ops/activations/BUILD

File mode changed: 100644 → 100755
Lines changed: 26 additions & 0 deletions
@@ -49,15 +49,40 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "tanhshrink_op_gpu",
+    srcs = [
+        "cc/kernels/tanhshrink_op.h",
+        "cc/kernels/tanhshrink_op_gpu.cu.cc",
+    ],
+    copts = if_cuda_is_configured([
+        "-DGOOGLE_CUDA=1",
+        "-x cuda",
+        "-nvcc_options=relaxed-constexpr",
+        "-nvcc_options=ftz=true",
+    ]),
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_libs",
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+    alwayslink = 1,
+)
+
 cc_binary(
     name = "_activation_ops.so",
     srcs = [
         "cc/kernels/gelu_op.cc",
         "cc/kernels/gelu_op.h",
         "cc/kernels/hardshrink_op.cc",
         "cc/kernels/hardshrink_op.h",
+        "cc/kernels/tanhshrink_op.cc",
+        "cc/kernels/tanhshrink_op.h",
         "cc/ops/gelu_op.cc",
         "cc/ops/hardshrink_op.cc",
+        "cc/ops/tanhshrink_op.cc",
     ],
     copts = [
         "-pthread",
@@ -71,5 +96,6 @@ cc_binary(
     ] + if_cuda_is_configured([
         ":gelu_op_gpu",
         ":hardshrink_op_gpu",
+        ":tanhshrink_op_gpu",
     ]),
 )

tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.cc

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+#define REGISTER_TANHSHRINK_KERNELS(type)                                     \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>Tanhshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      TanhshrinkOp<CPUDevice, type>);                                         \
+  REGISTER_KERNEL_BUILDER(Name("Addons>TanhshrinkGrad")                       \
+                              .Device(DEVICE_CPU)                             \
+                              .TypeConstraint<type>("T"),                     \
+                          TanhshrinkGradOp<CPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_TANHSHRINK_KERNELS);
+#undef REGISTER_TANHSHRINK_KERNELS
+
+#if GOOGLE_CUDA
+
+using GPUDevice = Eigen::GpuDevice;
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                          \
+  template <>                                                        \
+  void Tanhshrink<GPUDevice, T>::operator()(                         \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features,  \
+      typename TTypes<T>::Tensor activations);                       \
+  extern template struct Tanhshrink<GPUDevice, T>;                   \
+                                                                     \
+  template <>                                                        \
+  void TanhshrinkGrad<GPUDevice, T>::operator()(                     \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+      typename TTypes<T>::ConstTensor features,                      \
+      typename TTypes<T>::Tensor backprops);                         \
+  extern template struct TanhshrinkGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_TANHSHRINK_GPU_KERNELS(type)                                 \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>Tanhshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      TanhshrinkOp<GPUDevice, type>);                                         \
+  REGISTER_KERNEL_BUILDER(Name("Addons>TanhshrinkGrad")                       \
+                              .Device(DEVICE_GPU)                             \
+                              .TypeConstraint<type>("T"),                     \
+                          TanhshrinkGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_TANHSHRINK_GPU_KERNELS);
+#undef REGISTER_TANHSHRINK_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace addons
+}  // namespace tensorflow

tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
+#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+namespace functor {
+
+template <typename Device, typename T>
+struct Tanhshrink {
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor activations) {
+    activations.device(d) = features - features.tanh();
+  }
+};
+
+template <typename Device, typename T>
+struct TanhshrinkGrad {
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor backprops) {
+    backprops.device(d) = gradients * features.tanh().square();
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class TanhshrinkOp : public UnaryElementWiseOp<T, TanhshrinkOp<Device, T>> {
+ public:
+  using UnaryElementWiseOp<T, TanhshrinkOp<Device, T>>::UnaryElementWiseOp;
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::Tanhshrink<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(),
+            output->flat<T>());
+  }
+};
+
+template <typename Device, typename T>
+class TanhshrinkGradOp
+    : public BinaryElementWiseOp<T, TanhshrinkGradOp<Device, T>> {
+ public:
+  using BinaryElementWiseOp<T,
+                            TanhshrinkGradOp<Device, T>>::BinaryElementWiseOp;
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, Tensor* output);
+
+  // INPUTS:
+  //   g (gradients): backpropagated gradients
+  //   a (inputs): the inputs that were passed to the Tanhshrink op.
+  // OUTPUT:
+  //   gradients to backprop
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, output);
+  }
+};
+
+template <typename Device, typename T>
+void TanhshrinkGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                                    const Tensor& g,
+                                                    const Tensor& a,
+                                                    Tensor* output) {
+  functor::TanhshrinkGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+          output->flat<T>());
+}
+}  // namespace addons
+}  // namespace tensorflow
+
+#undef EIGEN_USE_THREADS
+
+#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
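
As a sanity check on the `TanhshrinkGrad` functor above (this derivation is not part of the commit): using d tanh(x)/dx = 1 - tanh^2(x),

    \frac{d}{dx}\left(x - \tanh(x)\right) = 1 - \left(1 - \tanh^{2}(x)\right) = \tanh^{2}(x)

which is why the backward pass simply scales the incoming gradient by `features.tanh().square()`.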
