
Commit 1d6b245

supercontracts authored and seanpmorgan committed
Implement lifted_struct_loss (#30)
* lifted_struct_loss function added
1 parent d47387a commit 1d6b245

File tree

tensorflow_addons/losses/BUILD
tensorflow_addons/losses/python/lifted.py
tensorflow_addons/losses/python/lifted_test.py
tensorflow_addons/losses/python/triplet.py

4 files changed: +258 -1 lines changed

tensorflow_addons/losses/BUILD

Lines changed: 15 additions & 1 deletion
@@ -7,8 +7,9 @@ py_library(
     srcs = [
         "__init__.py",
         "python/__init__.py",
-        "python/triplet.py",
+        "python/lifted.py",
         "python/metric_learning.py",
+        "python/triplet.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -28,3 +29,16 @@ py_test(
     ],
     srcs_version = "PY2AND3",
 )
+
+py_test(
+    name = "lifted_py_test",
+    size = "small",
+    srcs = [
+        "python/lifted_test.py",
+    ],
+    main = "python/lifted_test.py",
+    deps = [
+        ":losses_py",
+    ],
+    srcs_version = "PY2AND3",
+)
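
With this target in place, the new test can presumably be run through Bazel from the repository root, assuming the repository's standard workspace setup:

    bazel test //tensorflow_addons/losses:lifted_py_test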

tensorflow_addons/losses/python/lifted.py

Lines changed: 132 additions & 0 deletions

@@ -0,0 +1,132 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements lifted_struct_loss."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import losses
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow_addons.losses.python import metric_learning
from tensorflow_addons.utils.python import keras_utils


@keras_utils.register_keras_custom_object
@tf.function
def lifted_struct_loss(labels, embeddings, margin=1.0):
  """Computes the lifted structured loss.

  Args:
    labels: 1-D tf.int32 `Tensor` with shape [batch_size] of
      multiclass integer labels.
    embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should
      not be l2 normalized.
    margin: Float, margin term in the loss definition.

  Returns:
    lifted_loss: tf.float32 scalar.
  """
  # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor.
  lshape = array_ops.shape(labels)
  assert lshape.shape == 1
  labels = array_ops.reshape(labels, [lshape[0], 1])

  # Build pairwise squared distance matrix.
  pairwise_distances = metric_learning.pairwise_distance(embeddings)

  # Build pairwise binary adjacency matrix.
  adjacency = math_ops.equal(labels, array_ops.transpose(labels))
  # Invert so we can select negatives only.
  adjacency_not = math_ops.logical_not(adjacency)

  batch_size = array_ops.size(labels)

  diff = margin - pairwise_distances
  mask = math_ops.cast(adjacency_not, dtype=dtypes.float32)
  # Safe maximum: Temporarily shift negative distances
  # above zero before taking max.
  # This is to take the max only among negatives.
  row_minimums = math_ops.reduce_min(diff, 1, keepdims=True)
  row_negative_maximums = math_ops.reduce_max(
      math_ops.multiply(diff - row_minimums, mask), 1,
      keepdims=True) + row_minimums

  # Compute the loss.
  # Keep track of matrix of maximums where M_ij = max(m_i, m_j)
  # where m_i is the max of alpha - negative D_i's.
  # This matches the Caffe loss layer implementation at:
  # https://github.com/rksltnl/Caffe-Deep-Metric-Learning-CVPR16/blob/0efd7544a9846f58df923c8b992198ba5c355454/src/caffe/layers/lifted_struct_similarity_softmax_layer.cpp  # pylint: disable=line-too-long

  max_elements = math_ops.maximum(row_negative_maximums,
                                  array_ops.transpose(row_negative_maximums))
  diff_tiled = array_ops.tile(diff, [batch_size, 1])
  mask_tiled = array_ops.tile(mask, [batch_size, 1])
  max_elements_vect = array_ops.reshape(
      array_ops.transpose(max_elements), [-1, 1])

  loss_exp_left = array_ops.reshape(
      math_ops.reduce_sum(
          math_ops.multiply(
              math_ops.exp(diff_tiled - max_elements_vect), mask_tiled),
          1,
          keepdims=True), [batch_size, batch_size])

  loss_mat = max_elements + math_ops.log(loss_exp_left +
                                         array_ops.transpose(loss_exp_left))
  # Add the positive distance.
  loss_mat += pairwise_distances

  mask_positives = math_ops.cast(
      adjacency, dtype=dtypes.float32) - array_ops.diag(
          array_ops.ones([batch_size]))

  # *0.5 for upper triangular, and another *0.5 for 1/2 factor for loss^2.
  num_positives = math_ops.reduce_sum(mask_positives) / 2.0

  lifted_loss = math_ops.truediv(
      0.25 * math_ops.reduce_sum(
          math_ops.square(
              math_ops.maximum(
                  math_ops.multiply(loss_mat, mask_positives), 0.0))),
      num_positives)
  return lifted_loss


@keras_utils.register_keras_custom_object
class LiftedStructLoss(losses.LossFunctionWrapper):
  """Computes the lifted structured loss.

  The loss encourages the positive distances (between a pair of embeddings
  with the same labels) to be smaller than any negative distances (between
  a pair of embeddings with different labels) in the mini-batch in a way
  that is differentiable with respect to the embedding vectors.
  See: https://arxiv.org/abs/1511.06452.

  Args:
    margin: Float, margin term in the loss definition.
    name: Optional name for the op.
  """

  def __init__(self, margin=1.0, name=None):
    super(LiftedStructLoss, self).__init__(
        lifted_struct_loss,
        name=name,
        reduction=losses_utils.ReductionV2.NONE,
        margin=margin)
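
A minimal usage sketch (not part of the commit), showing the functional loss and its Keras wrapper on a small random batch; the import path follows the test file below, and margin=1.0 mirrors the default:

import numpy as np
import tensorflow as tf

from tensorflow_addons.losses.python import lifted

labels = tf.constant([0, 1, 0, 1, 2, 2], dtype=tf.int32)          # class ids, shape [6]
embeddings = tf.constant(np.random.rand(6, 4), dtype=tf.float32)  # unnormalized embeddings, shape [6, 4]

# Functional form: returns a tf.float32 scalar.
loss_value = lifted.lifted_struct_loss(labels, embeddings, margin=1.0)

# Loss-object form; both names are registered as Keras custom objects, so the
# wrapper can also be handed to Keras APIs that accept a loss object.
loss_obj = lifted.LiftedStructLoss(margin=1.0)
same_value = loss_obj(labels, embeddings)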

tensorflow_addons/losses/python/lifted_test.py

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lifted loss."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow_addons.losses.python import lifted


def pairwise_distance_np(feature, squared=False):
  """Computes the pairwise distance matrix in numpy.

  Args:
    feature: 2-D numpy array of size [number of data, feature dimension]
    squared: Boolean. If true, output is the pairwise squared euclidean
      distance matrix; else, output is the pairwise euclidean distance
      matrix.

  Returns:
    pairwise_distances: 2-D numpy array of size
      [number of data, number of data].
  """
  triu = np.triu_indices(feature.shape[0], 1)
  upper_tri_pdists = np.linalg.norm(
      feature[triu[1]] - feature[triu[0]], axis=1)
  if squared:
    upper_tri_pdists **= 2.
  num_data = feature.shape[0]
  pairwise_distances = np.zeros((num_data, num_data))
  pairwise_distances[np.triu_indices(num_data, 1)] = upper_tri_pdists
  # Make symmetrical.
  pairwise_distances = pairwise_distances + pairwise_distances.T - np.diag(
      pairwise_distances.diagonal())
  return pairwise_distances


class LiftedStructLossTest(test.TestCase):

  @test_util.run_all_in_graph_and_eager_modes
  def testLiftedStruct(self):
    num_data = 10
    feat_dim = 6
    margin = 1.0
    num_classes = 4

    embedding = np.random.rand(num_data, feat_dim).astype(np.float32)
    labels = np.random.randint(
        0, num_classes, size=(num_data)).astype(np.float32)
    # Reshape labels to compute adjacency matrix.
    labels_reshaped = np.reshape(labels, (labels.shape[0], 1))

    # Compute the loss in NP.
    adjacency = np.equal(labels_reshaped, labels_reshaped.T)
    pdist_matrix = pairwise_distance_np(embedding)
    loss_np = 0.0
    num_constraints = 0.0
    for i in range(num_data):
      for j in range(num_data):
        if adjacency[i][j] > 0.0 and i != j:
          d_pos = pdist_matrix[i][j]
          negs = []
          for k in range(num_data):
            if not adjacency[i][k]:
              negs.append(margin - pdist_matrix[i][k])
          for l in range(num_data):
            if not adjacency[j][l]:
              negs.append(margin - pdist_matrix[j][l])

          negs = np.array(negs)
          max_elem = np.max(negs)
          negs -= max_elem
          negs = np.exp(negs)
          soft_maximum = np.log(np.sum(negs)) + max_elem

          num_constraints += 1.0
          this_loss = max(soft_maximum + d_pos, 0)
          loss_np += this_loss * this_loss

    loss_np = loss_np / num_constraints / 2.0

    # Compute the loss in TF.
    y_true = constant_op.constant(labels)
    y_pred = constant_op.constant(embedding)
    cce_obj = lifted.LiftedStructLoss()
    loss = cce_obj(y_true, y_pred)
    self.assertAlmostEqual(self.evaluate(loss), loss_np, 3)


if __name__ == '__main__':
  test.main()
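
Because the BUILD target sets main = "python/lifted_test.py" and the file ends with test.main(), the test can presumably also be run standalone, assuming TensorFlow and this checkout of tensorflow_addons are importable:

    python tensorflow_addons/losses/python/lifted_test.py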

tensorflow_addons/losses/python/triplet.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@ def _masked_maximum(data, mask, dim=1):
     data: 2-D float `Tensor` of size [n, m].
     mask: 2-D Boolean `Tensor` of size [n, m].
     dim: The dimension over which to compute the maximum.
+
   Returns:
     masked_maximums: N-D `Tensor`.
       The maximized dimension is of size 1 after the operation.
@@ -53,6 +54,7 @@ def _masked_minimum(data, mask, dim=1):
     data: 2-D float `Tensor` of size [n, m].
     mask: 2-D Boolean `Tensor` of size [n, m].
     dim: The dimension over which to compute the minimum.
+
   Returns:
     masked_minimums: N-D `Tensor`.
       The minimized dimension is of size 1 after the operation.
