
Commit 594e183

WindQAQ authored and seanpmorgan committed
[activations] fused gelu kernel (#427)
* add CPU and GPU kernel for gelu
* add some documentations
* format codes
* support original (non-approximate) gelu
* GPUDevice is super fast
* fix typo
* format codes
* python API for gelu
* unittests for gelu
* update BUILD file
* lint
* update init and README
* alphabetical order
* update docs
* update docs
* test gradients on non-approximate gelu
* change test name
1 parent 75e847b commit 594e183

File tree

10 files changed, +528 -8 lines changed


tensorflow_addons/activations/BUILD

Lines changed: 17 additions & 2 deletions
@@ -6,12 +6,14 @@ py_library(
     name = "activations",
     srcs = [
         "__init__.py",
+        "gelu.py",
         "sparsemax.py",
     ],
-    srcs_version = "PY2AND3",
-    deps = [
+    data = [
+        "//tensorflow_addons/custom_ops/activations:_activation_ops.so",
         "//tensorflow_addons/utils",
     ],
+    srcs_version = "PY2AND3",
 )
 
 py_test(
@@ -26,3 +28,16 @@ py_test(
         ":activations",
     ],
 )
+
+py_test(
+    name = "gelu_test",
+    size = "large",
+    srcs = [
+        "gelu_test.py",
+    ],
+    main = "gelu_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)

tensorflow_addons/activations/README.md

Lines changed: 8 additions & 6 deletions
@@ -1,14 +1,16 @@
 # Addons - Activations
 
 ## Maintainers
-| Submodule | Maintainers | Contact Info |
-|:---------- |:------------- |:--------------|
-| sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
+| Submodule | Maintainers               | Contact Info                             |
+|:----------|:--------------------------|:-----------------------------------------|
+| gelu      | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
+| sparsemax | @AndreasMadsen            | amwwebdk+github@gmail.com                |
 
 ## Contents
-| Submodule | Activation | Reference |
-|:----------------------- |:-------------------|:---------------|
-| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
+| Submodule | Activation | Reference                         |
+|:----------|:-----------|:----------------------------------|
+| gelu      | gelu       | https://arxiv.org/abs/1606.08415  |
+| sparsemax | Sparsemax  | https://arxiv.org/abs/1602.02068  |
 
 
 ## Contribution Guidelines

tensorflow_addons/activations/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,4 +18,5 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow_addons.activations.gelu import gelu
 from tensorflow_addons.activations.sparsemax import sparsemax
tensorflow_addons/activations/gelu.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def gelu(x, approximate=True):
+    """Gaussian Error Linear Unit.
+
+    Computes gaussian error linear:
+    `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))` or
+    `x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`, where P(X) ~ N(0, 1),
+    depending on whether approximation is enabled.
+
+    See [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)
+    and [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+        approximate: bool, whether to enable approximation.
+    Returns:
+        A `Tensor`. Has the same type as `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    return _activation_ops_so.gelu(x, approximate)
+
+
+@tf.RegisterGradient("Gelu")
+def _gelu_grad(op, grad):
+    return _activation_ops_so.gelu_grad(grad, op.inputs[0],
+                                        op.get_attr("approximate"))
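
For context, here is a minimal usage sketch of the new Python API (not part of the commit). It assumes tensorflow_addons is installed with the compiled _activation_ops.so available to the resource loader:

import tensorflow as tf
from tensorflow_addons.activations import gelu

x = tf.constant([-1.0, 0.0, 1.0], dtype=tf.float32)
y_approx = gelu(x)                    # tanh-based approximation (default)
y_exact = gelu(x, approximate=False)  # exact erf-based form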
tensorflow_addons/activations/gelu_test.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import math
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.activations import gelu
+from tensorflow_addons.utils import test_utils
+
+
+def _ref_gelu(x, approximate=True):
+    x = tf.convert_to_tensor(x)
+    if approximate:
+        pi = tf.cast(math.pi, x.dtype)
+        coeff = tf.cast(0.044715, x.dtype)
+        return 0.5 * x * (
+            1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
+    else:
+        return 0.5 * x * (
+            1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class GeluTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_gelu(self, dtype):
+        x = np.random.rand(2, 3, 4).astype(dtype)
+        self.assertAllCloseAccordingToType(gelu(x), _ref_gelu(x))
+        self.assertAllCloseAccordingToType(gelu(x, False), _ref_gelu(x, False))
+
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_gradients(self, dtype):
+        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+
+        for approximate in [True, False]:
+            with self.subTest(approximate=approximate):
+                with tf.GradientTape(persistent=True) as tape:
+                    tape.watch(x)
+                    y_ref = _ref_gelu(x, approximate)
+                    y = gelu(x, approximate)
+                grad_ref = tape.gradient(y_ref, x)
+                grad = tape.gradient(y, x)
+                self.assertAllCloseAccordingToType(grad, grad_ref)
+
+    @parameterized.named_parameters(("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_theoretical_gradients(self, dtype):
+        # Only test theoretical gradients for float32 and float64
+        # because of the instability of float16 while computing jacobian
+        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+
+        for approximate in [True, False]:
+            with self.subTest(approximate=approximate):
+                theoretical, numerical = tf.test.compute_gradient(
+                    lambda x: gelu(x, approximate=approximate), [x])
+                self.assertAllCloseAccordingToType(
+                    theoretical, numerical, atol=1e-4)
+
+    def test_unknown_shape(self):
+        fn = gelu.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32))
+
+        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
+            x = tf.ones(shape=shape, dtype=tf.float32)
+            self.assertAllClose(fn(x), gelu(x))
+
+    def test_serialization(self):
+        ref_fn = gelu
+        config = tf.keras.activations.serialize(ref_fn)
+        fn = tf.keras.activations.deserialize(config)
+        self.assertEqual(fn, ref_fn)
+
+    def test_serialization_with_layers(self):
+        layer = tf.keras.layers.Dense(3, activation=gelu)
+        config = tf.keras.layers.serialize(layer)
+        deserialized_layer = tf.keras.layers.deserialize(config)
+        self.assertEqual(deserialized_layer.__class__.__name__,
+                         layer.__class__.__name__)
+        self.assertEqual(deserialized_layer.activation.__name__, "gelu")
+
+
+if __name__ == "__main__":
+    tf.test.main()
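
For reference (not part of the commit), the closed-form derivative of the exact, erf-based GELU, which the registered GeluGrad kernel is expected to reproduce up to the test tolerances, is the standard result

$$\frac{d}{dx}\,\mathrm{gelu}(x) = \frac{1}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right) + \frac{x}{\sqrt{2\pi}}\,e^{-x^{2}/2}.$$

The tanh approximation has a correspondingly longer derivative; the tests above avoid spelling out either expression by differentiating _ref_gelu with tf.GradientTape and by cross-checking theoretical against numerical Jacobians via tf.test.compute_gradient.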
tensorflow_addons/custom_ops/activations/BUILD

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured", "if_cuda")
+
+cc_library(
+    name = "gelu_op_gpu",
+    srcs = [
+        "cc/kernels/gelu_op.h",
+        "cc/kernels/gelu_op_gpu.cu.cc",
+    ],
+    copts = if_cuda_is_configured([
+        "-DGOOGLE_CUDA=1",
+        "-x cuda",
+        "-nvcc_options=relaxed-constexpr",
+        "-nvcc_options=ftz=true",
+    ]),
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_libs",
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+    alwayslink = 1,
+)
+
+cc_binary(
+    name = "_activation_ops.so",
+    srcs = [
+        "cc/kernels/gelu_op.cc",
+        "cc/kernels/gelu_op.h",
+        "cc/ops/gelu_op.cc",
+    ],
+    copts = [
+        "-pthread",
+        "-std=c++11",
+        D_GLIBCXX_USE_CXX11_ABI,
+    ] + if_cuda(["-DGOOGLE_CUDA=1"]),
+    linkshared = 1,
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([":gelu_op_gpu"]),
+)
tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+#define REGISTER_GELU_KERNELS(type)                                   \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("Gelu").Device(DEVICE_CPU).TypeConstraint<type>("T"),      \
+      GeluOp<CPUDevice, type>);                                       \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
+      GeluGradOp<CPUDevice, type>);
+
+// Gelu only makes sense with floating points.
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_KERNELS);
+#undef REGISTER_GELU_KERNELS
+
+#ifdef GOOGLE_CUDA
+
+using GPUDevice = Eigen::GpuDevice;
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                           \
+  template <>                                                         \
+  void Gelu<GPUDevice, T>::operator()(                                \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features,  \
+      bool approximate, typename TTypes<T>::Tensor activations);     \
+  extern template struct Gelu<GPUDevice, T>;                         \
+                                                                      \
+  template <>                                                         \
+  void GeluGrad<GPUDevice, T>::operator()(                           \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+      typename TTypes<T>::ConstTensor features, bool approximate,    \
+      typename TTypes<T>::Tensor backprops);                         \
+  extern template struct GeluGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GELU_GPU_KERNELS(type)                               \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("Gelu").Device(DEVICE_GPU).TypeConstraint<type>("T"),      \
+      GeluOp<GPUDevice, type>);                                       \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
+      GeluGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_GPU_KERNELS);
+#undef REGISTER_GELU_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
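
Because the same op names are registered for both DEVICE_CPU and DEVICE_GPU, a quick consistency spot-check can be run from Python. A minimal sketch (not part of the commit), assuming a CUDA-enabled build of _activation_ops.so and at least one visible GPU:

import tensorflow as tf
from tensorflow_addons.activations import gelu

x = tf.random.uniform((4, 8), dtype=tf.float32)

with tf.device("/CPU:0"):
    y_cpu = gelu(x, approximate=False)

if tf.config.experimental.list_physical_devices("GPU"):
    with tf.device("/GPU:0"):
        y_gpu = gelu(x, approximate=False)
    # The CPU and GPU kernels implement the same op, so the results
    # should agree to within floating-point tolerance.
    print(tf.reduce_max(tf.abs(y_cpu - y_gpu)).numpy())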
