
Commit bbec769

AakashKumarNain authored and WindQAQ committed
GeLU activation as a layer (#424)
* add gelu activation
* add tests for gelu activation
* add gelu to imports
* include gelu in build file
* update tests and refactor
* refactor
* make compatible with every fp dtype and fulfill layer requirements
* add dummy model test
* code format
* code format and sanity check pass
* code format
* auto code format
* use fused gelu activation
* remove redundant test cases
1 parent 7e0f343 · commit bbec769

File tree

4 files changed: +112 −1 lines


tensorflow_addons/layers/BUILD

14 additions & 0 deletions

@@ -6,6 +6,7 @@ py_library(
     name = "layers",
     srcs = [
         "__init__.py",
+        "gelu.py",
         "maxout.py",
         "normalizations.py",
         "optical_flow.py",
@@ -23,6 +24,19 @@ py_library(
     ],
 )
 
+py_test(
+    name = "gelu_test",
+    size = "small",
+    srcs = [
+        "gelu_test.py",
+    ],
+    main = "gelu_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":layers",
+    ],
+)
+
 py_test(
     name = "layers_wrappers_test",
     size = "small",

tensorflow_addons/layers/__init__.py

2 additions & 1 deletion

@@ -18,10 +18,11 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow_addons.layers.gelu import GeLU
 from tensorflow_addons.layers.maxout import Maxout
 from tensorflow_addons.layers.normalizations import GroupNormalization
 from tensorflow_addons.layers.normalizations import InstanceNormalization
 from tensorflow_addons.layers.optical_flow import CorrelationCost
 from tensorflow_addons.layers.poincare import PoincareNormalize
 from tensorflow_addons.layers.sparsemax import Sparsemax
-from tensorflow_addons.layers.wrappers import WeightNormalization
+from tensorflow_addons.layers.wrappers import WeightNormalization
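
With this export in place, the layer becomes importable from the package's layers namespace. A minimal sketch of the intended import path after this commit (not part of the diff itself):

    from tensorflow_addons.layers import GeLU

    layer = GeLU()  # uses the default approximate=True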

tensorflow_addons/layers/gelu.py

57 additions & 0 deletions

@@ -0,0 +1,57 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements GeLU activation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.activations import gelu
+
+
+@keras_utils.register_keras_custom_object
+class GeLU(tf.keras.layers.Layer):
+    """Gaussian Error Linear Unit.
+
+    A smoother version of ReLU generally used
+    in the BERT or BERT architecture based models.
+    Original paper: https://arxiv.org/abs/1606.08415
+
+    Input shape:
+        Arbitrary. Use the keyword argument `input_shape`
+        (tuple of integers, does not include the samples axis)
+        when using this layer as the first layer in a model.
+
+    Output shape:
+        Same shape as the input.
+    """
+
+    def __init__(self, approximate=True, **kwargs):
+        super(GeLU, self).__init__(**kwargs)
+        self.approximate = approximate
+        self.supports_masking = True
+
+    def call(self, inputs):
+        return gelu(inputs, approximate=self.approximate)
+
+    def get_config(self):
+        config = {'approximate': self.approximate}
+        base_config = super(GeLU, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
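
To show how the new layer is meant to slot into a model, here is a minimal sketch of a toy Keras model in the spirit of the "dummy model test" mentioned in the commit message. It is not part of the commit; the layer widths and input shape are invented for the example:

    import tensorflow as tf
    from tensorflow_addons.layers import GeLU

    # Toy model: Dense -> GeLU -> Dense. Sizes here are illustrative only.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, input_shape=(8,)),
        GeLU(approximate=True),  # tanh-based approximation (the default)
        tf.keras.layers.Dense(1),
    ])

    # get_config() plus the register_keras_custom_object decorator let the
    # layer round-trip through model (de)serialization.
    restored = tf.keras.Sequential.from_config(model.get_config())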
tensorflow_addons/layers/gelu_test.py

39 additions & 0 deletions

@@ -0,0 +1,39 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for GeLU activation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow_addons.layers.gelu import GeLU
+from tensorflow_addons.utils import test_utils
+
+
+@parameterized.parameters([np.float16, np.float32, np.float64])
+@test_utils.run_all_in_graph_and_eager_modes
+class TestGeLU(tf.test.TestCase):
+    def test_random(self, dtype):
+        x = np.array([[0.5, 1.2, -0.3]]).astype(dtype)
+        val = np.array([[0.345714, 1.0617027, -0.11462909]]).astype(dtype)
+        test_utils.layer_test(
+            GeLU, kwargs={'dtype': dtype}, input_data=x, expected_output=val)
+
+
+if __name__ == '__main__':
+    tf.test.main()
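
For reference, the expected values in test_random follow from the tanh approximation of GeLU described in the paper linked from the layer's docstring. A small standalone NumPy check, not part of the commit, reproduces them:

    import numpy as np

    def gelu_approx(x):
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

    print(gelu_approx(np.array([0.5, 1.2, -0.3])))
    # -> approximately [ 0.345714  1.061703 -0.114629], matching the test's expected output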
