/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
namespace addons {
namespace functor {

// Computes the Tanhshrink activation: Tanhshrink(x) = x - tanh(x).
template <typename Device, typename T>
struct Tanhshrink {
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    activations.device(d) = features - features.tanh();
  }
};
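
// Gradient functor for Tanhshrink: since
//   d/dx [x - tanh(x)] = 1 - (1 - tanh^2(x)) = tanh^2(x),
// the backprop is gradients * tanh(features)^2.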
template <typename Device, typename T>
struct TanhshrinkGrad {
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor backprops) {
    backprops.device(d) = gradients * features.tanh().square();
  }
};

}  // namespace functor
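
// Forward op kernel: applies the Tanhshrink functor element-wise over the
// flattened input.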
template <typename Device, typename T>
class TanhshrinkOp : public UnaryElementWiseOp<T, TanhshrinkOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, TanhshrinkOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Tanhshrink<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
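
// Gradient op kernel. Operate() delegates to the non-templated
// OperateNoTemplate() so the kernel body is not re-instantiated for every
// NDIMS value.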
template <typename Device, typename T>
class TanhshrinkGradOp
    : public BinaryElementWiseOp<T, TanhshrinkGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T,
                            TanhshrinkGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients.
  //   a (inputs): the inputs that were passed to the Tanhshrink op.
  // OUTPUT:
  //   output: the gradients to backprop to the inputs.
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void TanhshrinkGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                                    const Tensor& g,
                                                    const Tensor& a,
                                                    Tensor* output) {
  functor::TanhshrinkGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}
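
// Note: kernel registration is expected to live in a companion .cc file. A
// minimal sketch, assuming an op registered as "Addons>Tanhshrink" and a CPU
// float kernel (both are illustrative assumptions, not confirmed by this
// header):
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Addons>Tanhshrink").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       TanhshrinkOp<Eigen::ThreadPoolDevice, float>);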

}  // namespace addons
}  // namespace tensorflow

#undef EIGEN_USE_THREADS

#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_