@@ -102,6 +102,7 @@ def update_state(self, y_true, y_pred, sample_weight=None):
102102 self .actual_cuts ,
103103 tf .cast (tf .reshape (y_true , [- 1 ]), self .actual_cuts .dtype ),
104104 side = "right" ,
105+ out_type = tf .int64 ,
105106 )
106107 - 1
107108 )
@@ -110,46 +111,24 @@ def update_state(self, y_true, y_pred, sample_weight=None):
110111 self .preds_cuts ,
111112 tf .cast (tf .reshape (y_pred , [- 1 ]), self .preds_cuts .dtype ),
112113 side = "right" ,
114+ out_type = tf .int64 ,
113115 )
114116 - 1
115117 )
116118
117- m = tf .sparse .from_dense (self .m )
118- nrow = tf .sparse .from_dense (self .nrow )
119- ncol = tf .sparse .from_dense (self .ncol )
120-
121- k = 0
122- while k < tf .shape (i )[0 ]:
123- m = tf .sparse .add (
124- m ,
125- tf .SparseTensor (
126- [[i [k ], j [k ]]],
127- tf .cast ([1 ], dtype = m .dtype ),
128- self .m .shape ,
129- ),
130- )
131- nrow = tf .sparse .add (
132- nrow ,
133- tf .SparseTensor (
134- [[i [k ]]],
135- tf .cast ([1 ], dtype = nrow .dtype ),
136- self .nrow .shape ,
137- ),
138- )
139- ncol = tf .sparse .add (
140- ncol ,
141- tf .SparseTensor (
142- [[j [k ]]],
143- tf .cast ([1 ], dtype = ncol .dtype ),
144- self .ncol .shape ,
145- ),
146- )
147- k += 1
119+ nrow = tf .tensor_scatter_nd_add (
120+ self .nrow , tf .expand_dims (i , axis = - 1 ), tf .ones_like (i )
121+ )
122+ ncol = tf .tensor_scatter_nd_add (
123+ self .ncol , tf .expand_dims (j , axis = - 1 ), tf .ones_like (j )
124+ )
125+ ij = tf .stack ([i , j ], axis = 1 )
126+ m = tf .tensor_scatter_nd_add (self .m , ij , tf .ones_like (i ))
148127
149- self .n .assign_add (tf .cast ( k , tf .int64 ))
150- self .m .assign (tf . sparse . to_dense ( m ) )
151- self .nrow .assign (tf . sparse . to_dense ( nrow ) )
152- self .ncol .assign (tf . sparse . to_dense ( ncol ) )
128+ self .n .assign_add (tf .shape ( i , out_type = tf .int64 )[ 0 ] )
129+ self .m .assign (m )
130+ self .nrow .assign (nrow )
131+ self .ncol .assign (ncol )
153132
154133 @abstractmethod
155134 def result (self ):
0 commit comments