Skip to content

Commit b756950

Browse files
authored
Correlation metrics optimization (#2747)
* remove unnecessary loop * tensor type fix
1 parent 38e6ec2 commit b756950

File tree

1 file changed

+14
-35
lines changed

1 file changed

+14
-35
lines changed

tensorflow_addons/metrics/streaming_correlations.py

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def update_state(self, y_true, y_pred, sample_weight=None):
102102
self.actual_cuts,
103103
tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype),
104104
side="right",
105+
out_type=tf.int64,
105106
)
106107
- 1
107108
)
@@ -110,46 +111,24 @@ def update_state(self, y_true, y_pred, sample_weight=None):
110111
self.preds_cuts,
111112
tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype),
112113
side="right",
114+
out_type=tf.int64,
113115
)
114116
- 1
115117
)
116118

117-
m = tf.sparse.from_dense(self.m)
118-
nrow = tf.sparse.from_dense(self.nrow)
119-
ncol = tf.sparse.from_dense(self.ncol)
120-
121-
k = 0
122-
while k < tf.shape(i)[0]:
123-
m = tf.sparse.add(
124-
m,
125-
tf.SparseTensor(
126-
[[i[k], j[k]]],
127-
tf.cast([1], dtype=m.dtype),
128-
self.m.shape,
129-
),
130-
)
131-
nrow = tf.sparse.add(
132-
nrow,
133-
tf.SparseTensor(
134-
[[i[k]]],
135-
tf.cast([1], dtype=nrow.dtype),
136-
self.nrow.shape,
137-
),
138-
)
139-
ncol = tf.sparse.add(
140-
ncol,
141-
tf.SparseTensor(
142-
[[j[k]]],
143-
tf.cast([1], dtype=ncol.dtype),
144-
self.ncol.shape,
145-
),
146-
)
147-
k += 1
119+
nrow = tf.tensor_scatter_nd_add(
120+
self.nrow, tf.expand_dims(i, axis=-1), tf.ones_like(i)
121+
)
122+
ncol = tf.tensor_scatter_nd_add(
123+
self.ncol, tf.expand_dims(j, axis=-1), tf.ones_like(j)
124+
)
125+
ij = tf.stack([i, j], axis=1)
126+
m = tf.tensor_scatter_nd_add(self.m, ij, tf.ones_like(i))
148127

149-
self.n.assign_add(tf.cast(k, tf.int64))
150-
self.m.assign(tf.sparse.to_dense(m))
151-
self.nrow.assign(tf.sparse.to_dense(nrow))
152-
self.ncol.assign(tf.sparse.to_dense(ncol))
128+
self.n.assign_add(tf.shape(i, out_type=tf.int64)[0])
129+
self.m.assign(m)
130+
self.nrow.assign(nrow)
131+
self.ncol.assign(ncol)
153132

154133
@abstractmethod
155134
def result(self):

0 commit comments

Comments
 (0)