Skip to content

Commit b53d8a7

Browse files
AakashKumarNainseanpmorgan
authored andcommitted
focal loss implementation for tf keras (#32)
* ENH: Add focal loss
1 parent c338324 commit b53d8a7

File tree

5 files changed

+269
-2
lines changed

5 files changed

+269
-2
lines changed

tensorflow_addons/losses/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ py_library(
66
name = "losses",
77
srcs = [
88
"__init__.py",
9+
"focal_loss.py",
910
"lifted.py",
1011
"metric_learning.py",
1112
"sparsemax_loss.py",
@@ -18,6 +19,19 @@ py_library(
1819
],
1920
)
2021

22+
py_test(
23+
name = "focal_loss_test",
24+
size = "small",
25+
srcs = [
26+
"focal_loss_test.py",
27+
],
28+
main = "focal_loss_test.py",
29+
srcs_version = "PY2AND3",
30+
deps = [
31+
":losses",
32+
],
33+
)
34+
2135
py_test(
2236
name = "sparsemax_loss_test",
2337
size = "small",

tensorflow_addons/losses/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
## Maintainers
44
| Submodule | Maintainers | Contact Info |
55
|:---------- |:----------- |:------------- |
6+
| focal_loss | SIG-Addons | addons@tensorflow.org |
67
| lifted | SIG-Addons | addons@tensorflow.org |
78
| sparsemax_loss | SIG-Addons | addons@tensorflow.org |
89
| triplet | SIG-Addons | addons@tensorflow.org |
910

1011
## Components
1112
| Submodule | Loss | Reference |
1213
|:----------------------- |:---------------------|:--------------------------|
14+
| focal_loss | SigmoidFocalCrossEntropy | https://arxiv.org/abs/1708.02002 |
1315
| lifted | LiftedStructLoss | https://arxiv.org/abs/1511.06452 |
1416
| sparsemax_loss | SparsemaxLoss | https://arxiv.org/abs/1602.02068 |
1517
| triplet | TripletSemiHardLoss | https://arxiv.org/abs/1503.03832 |
@@ -34,4 +36,5 @@ must:
3436
* Add a `py_test` to this sub-package's BUILD file.
3537

3638
#### Documentation Requirements
39+
* Update the table of contents in the project's central README.
3740
* Update the table of contents in this sub-package's README.

tensorflow_addons/losses/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
from tensorflow_addons.losses.lifted import lifted_struct_loss
21+
from tensorflow_addons.losses.focal_loss import sigmoid_focal_crossentropy, SigmoidFocalCrossEntropy
22+
from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss
2223
from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss
23-
from tensorflow_addons.losses.triplet import triplet_semihard_loss
24+
from tensorflow_addons.losses.triplet import triplet_semihard_loss, TripletSemiHardLoss
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Implements Focal loss."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import tensorflow as tf
22+
import tensorflow.keras.backend as K
23+
from tensorflow_addons.utils import keras_utils
24+
25+
26+
@keras_utils.register_keras_custom_object
27+
class SigmoidFocalCrossEntropy(keras_utils.LossFunctionWrapper):
28+
"""Implements the focal loss function.
29+
30+
Focal loss was first introduced in the RetinaNet paper
31+
(https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
32+
classification when you have highly imbalanced classes. It down-weights
33+
well-classified examples and focuses on hard examples. The loss value is
34+
much high for a sample which is misclassified by the classifier as compared
35+
to the loss value corresponding to a well-classified example. One of the
36+
best use-cases of focal loss is its usage in object detection where the
37+
imbalance between the background class and other classes is extremely high.
38+
39+
Usage:
40+
41+
```python
42+
fl = tfa.losses.SigmoidFocalCrossEntropy()
43+
loss = fl(
44+
[[0.97], [0.91], [0.03]],
45+
[[1], [1], [0])
46+
print('Loss: ', loss.numpy()) # Loss: [[0.03045921]
47+
[0.09431068]
48+
[0.31471074]
49+
```
50+
Usage with tf.keras API:
51+
52+
```python
53+
model = tf.keras.Model(inputs, outputs)
54+
model.compile('sgd', loss=tf.keras.losses.SigmoidFocalCrossEntropy())
55+
```
56+
57+
Args
58+
alpha: balancing factor, default value is 0.25
59+
gamma: modulating factor, default value is 2.0
60+
61+
Returns:
62+
Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
63+
shape as `y_true`; otherwise, it is scalar.
64+
65+
Raises:
66+
ValueError: If the shape of `sample_weight` is invalid or value of
67+
`gamma` is less than zero
68+
"""
69+
70+
def __init__(self,
71+
from_logits=False,
72+
alpha=0.25,
73+
gamma=2.0,
74+
reduction=tf.keras.losses.Reduction.NONE,
75+
name='sigmoid_focal_crossentropy'):
76+
super(SigmoidFocalCrossEntropy, self).__init__(
77+
sigmoid_focal_crossentropy,
78+
name=name,
79+
reduction=reduction,
80+
from_logits=from_logits,
81+
alpha=alpha,
82+
gamma=gamma)
83+
84+
self.from_logits = from_logits
85+
self.alpha = alpha
86+
self.gamma = gamma
87+
88+
89+
@keras_utils.register_keras_custom_object
90+
@tf.function
91+
def sigmoid_focal_crossentropy(y_true,
92+
y_pred,
93+
alpha=0.25,
94+
gamma=2.0,
95+
from_logits=False):
96+
"""
97+
Args
98+
y_true: true targets tensor.
99+
y_pred: predictions tensor.
100+
alpha: balancing factor.
101+
gamma: modulating factor.
102+
103+
Returns:
104+
Weighted loss float `Tensor`. If `reduction` is `NONE`,this has the
105+
same shape as `y_true`; otherwise, it is scalar.
106+
"""
107+
if gamma and gamma < 0:
108+
raise ValueError(
109+
"Value of gamma should be greater than or equal to zero")
110+
111+
y_pred = tf.convert_to_tensor(y_pred)
112+
y_true = tf.cast(y_true, y_pred.dtype)
113+
114+
# Get the binary cross_entropy
115+
bce = K.binary_crossentropy(y_true, y_pred, from_logits=from_logits)
116+
117+
# If logits are provided then convert the predictions into probabilities
118+
if from_logits:
119+
y_pred = K.sigmoid(y_pred)
120+
else:
121+
y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
122+
123+
p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))
124+
alpha_factor = 1
125+
modulating_factor = 1
126+
127+
if alpha:
128+
alpha = tf.convert_to_tensor(alpha, dtype=K.floatx())
129+
alpha_factor = y_true * alpha + ((1 - alpha) * (1 - y_true))
130+
131+
if gamma:
132+
gamma = tf.convert_to_tensor(gamma, dtype=K.floatx())
133+
modulating_factor = K.pow((1 - p_t), gamma)
134+
135+
# compute the final loss and return
136+
return K.mean(
137+
alpha_factor * modulating_factor * bce, axis=-1, keepdims=True)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
## Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for focal loss."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import numpy as np
22+
import tensorflow as tf
23+
import tensorflow.keras.backend as K
24+
from tensorflow_addons.utils import test_utils
25+
from tensorflow_addons.losses import sigmoid_focal_crossentropy, SigmoidFocalCrossEntropy
26+
27+
28+
@test_utils.run_all_in_graph_and_eager_modes
29+
class SigmoidFocalCrossEntropyTest(tf.test.TestCase):
30+
def test_config(self):
31+
bce_obj = SigmoidFocalCrossEntropy(
32+
reduction=tf.keras.losses.Reduction.NONE,
33+
name='sigmoid_focal_crossentropy')
34+
self.assertEqual(bce_obj.name, 'sigmoid_focal_crossentropy')
35+
self.assertEqual(bce_obj.reduction, tf.keras.losses.Reduction.NONE)
36+
37+
def to_logit(self, prob):
38+
logit = np.log(prob / (1. - prob))
39+
return logit
40+
41+
def log10(self, x):
42+
numerator = tf.math.log(x)
43+
denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
44+
return numerator / denominator
45+
46+
# Test with logits
47+
def test_with_logits(self):
48+
# predictiions represented as logits
49+
prediction_tensor = tf.constant(
50+
[[self.to_logit(0.97)], [self.to_logit(0.91)],
51+
[self.to_logit(0.73)], [self.to_logit(0.27)],
52+
[self.to_logit(0.09)], [self.to_logit(0.03)]], tf.float32)
53+
# Ground truth
54+
target_tensor = tf.constant([[1], [1], [1], [0], [0], [0]], tf.float32)
55+
56+
fl = sigmoid_focal_crossentropy(
57+
y_true=target_tensor,
58+
y_pred=prediction_tensor,
59+
from_logits=True,
60+
alpha=None,
61+
gamma=None)
62+
bce = K.binary_crossentropy(
63+
target_tensor, prediction_tensor, from_logits=True)
64+
65+
# When alpha and gamma are None, it should be equal to BCE
66+
self.assertAllClose(fl, bce)
67+
68+
# When gamma==2.0
69+
fl = sigmoid_focal_crossentropy(
70+
y_true=target_tensor,
71+
y_pred=prediction_tensor,
72+
from_logits=True,
73+
alpha=None,
74+
gamma=2.0)
75+
76+
# order_of_ratio = np.power(10, np.floor(np.log10(bce/FL)))
77+
order_of_ratio = tf.pow(10.0, tf.math.floor(self.log10(bce / fl)))
78+
pow_values = tf.constant([[1000], [100], [10], [10], [100], [1000]])
79+
self.assertAllClose(order_of_ratio, pow_values)
80+
81+
# Test without logits
82+
def test_without_logits(self):
83+
# predictiions represented as logits
84+
prediction_tensor = tf.constant(
85+
[[0.97], [0.91], [0.73], [0.27], [0.09], [0.03]], tf.float32)
86+
# Ground truth
87+
target_tensor = tf.constant([[1], [1], [1], [0], [0], [0]], tf.float32)
88+
89+
fl = sigmoid_focal_crossentropy(
90+
y_true=target_tensor,
91+
y_pred=prediction_tensor,
92+
alpha=None,
93+
gamma=None)
94+
bce = K.binary_crossentropy(target_tensor, prediction_tensor)
95+
96+
# When alpha and gamma are None, it should be equal to BCE
97+
self.assertAllClose(fl, bce)
98+
99+
# When gamma==2.0
100+
fl = sigmoid_focal_crossentropy(
101+
y_true=target_tensor,
102+
y_pred=prediction_tensor,
103+
alpha=None,
104+
gamma=2.0)
105+
106+
order_of_ratio = tf.pow(10.0, tf.math.floor(self.log10(bce / fl)))
107+
pow_values = tf.constant([[1000], [100], [10], [10], [100], [1000]])
108+
self.assertAllClose(order_of_ratio, pow_values)
109+
110+
111+
if __name__ == '__main__':
112+
tf.test.main()

0 commit comments

Comments
 (0)