@@ -76,7 +76,22 @@ def __init__(
7676 self .preds_max = preds_max
7777 self .actual_cutpoints = actual_cutpoints
7878 self .preds_cutpoints = preds_cutpoints
79- self .reset_state ()
79+ self .actual_cuts = tf .linspace (
80+ tf .cast (self .actual_min , tf .float32 ),
81+ tf .cast (self .actual_max , tf .float32 ),
82+ self .actual_cutpoints - 1 ,
83+ )
84+ self .preds_cuts = tf .linspace (
85+ tf .cast (self .preds_min , tf .float32 ),
86+ tf .cast (self .preds_max , tf .float32 ),
87+ self .preds_cutpoints - 1 ,
88+ )
89+ self .m = self .add_weight (
90+ "m" , (self .actual_cutpoints , self .preds_cutpoints ), dtype = tf .int64
91+ )
92+ self .nrow = self .add_weight ("nrow" , (self .actual_cutpoints ), dtype = tf .int64 )
93+ self .ncol = self .add_weight ("ncol" , (self .preds_cutpoints ), dtype = tf .int64 )
94+ self .n = self .add_weight ("n" , (), dtype = tf .int64 )
8095
8196 def update_state (self , y_true , y_pred , sample_weight = None ):
8297 """Accumulates ranks.
@@ -89,75 +104,69 @@ def update_state(self, y_true, y_pred, sample_weight=None):
89104 Returns:
90105 Update op.
91106 """
92- if y_true .shape and y_true .shape [0 ]:
93- i = tf .searchsorted (
94- self .actual_cuts ,
95- tf .cast (tf .reshape (y_true , - 1 ), self .actual_cuts .dtype ),
107+ i = tf .searchsorted (
108+ self .actual_cuts ,
109+ tf .cast (tf .reshape (y_true , [- 1 ]), self .actual_cuts .dtype ),
110+ )
111+ j = tf .searchsorted (
112+ self .preds_cuts , tf .cast (tf .reshape (y_pred , [- 1 ]), self .preds_cuts .dtype )
113+ )
114+
115+ m = tf .sparse .from_dense (self .m )
116+ nrow = tf .sparse .from_dense (self .nrow )
117+ ncol = tf .sparse .from_dense (self .ncol )
118+
119+ k = 0
120+ while k < tf .shape (i )[0 ]:
121+ m = tf .sparse .add (
122+ m ,
123+ tf .SparseTensor (
124+ [[i [k ], j [k ]]],
125+ tf .cast ([1 ], dtype = m .dtype ),
126+ self .m .shape ,
127+ ),
96128 )
97- j = tf .searchsorted (
98- self .preds_cuts , tf .cast (tf .reshape (y_pred , - 1 ), self .preds_cuts .dtype )
129+ nrow = tf .sparse .add (
130+ nrow ,
131+ tf .SparseTensor (
132+ [[i [k ]]],
133+ tf .cast ([1 ], dtype = nrow .dtype ),
134+ self .nrow .shape ,
135+ ),
99136 )
100-
101- def body (k , n , m , nrow , ncol ):
102- return (
103- k + 1 ,
104- n + 1 ,
105- tf .sparse .add (
106- m ,
107- tf .SparseTensor (
108- [[i [k ], j [k ]]],
109- tf .cast ([1 ], dtype = self .m .dtype ),
110- self .m .shape ,
111- ),
112- ),
113- tf .sparse .add (
114- nrow ,
115- tf .SparseTensor (
116- [[i [k ]]],
117- tf .cast ([1 ], dtype = self .nrow .dtype ),
118- self .nrow .shape ,
119- ),
120- ),
121- tf .sparse .add (
122- ncol ,
123- tf .SparseTensor (
124- [[j [k ]]],
125- tf .cast ([1 ], dtype = self .ncol .dtype ),
126- self .ncol .shape ,
127- ),
128- ),
129- )
130-
131- _ , self .n , self .m , self .nrow , self .ncol = tf .while_loop (
132- lambda k , n , m , nrow , ncol : k < i .shape [0 ],
133- body = body ,
134- loop_vars = (0 , self .n , self .m , self .nrow , self .ncol ),
137+ ncol = tf .sparse .add (
138+ ncol ,
139+ tf .SparseTensor (
140+ [[j [k ]]],
141+ tf .cast ([1 ], dtype = ncol .dtype ),
142+ self .ncol .shape ,
143+ ),
135144 )
145+ k += 1
146+
147+ self .n .assign_add (tf .cast (k , tf .int64 ))
148+ self .m .assign (tf .sparse .to_dense (m ))
149+ self .nrow .assign (tf .sparse .to_dense (nrow ))
150+ self .ncol .assign (tf .sparse .to_dense (ncol ))
136151
137152 def result (self ):
138- m_dense = tf .sparse .to_dense (tf .cast (self .m , tf .float32 ))
139- n_cap = tf .cumsum (
140- tf .cumsum (
141- tf .slice (tf .pad (m_dense , [[1 , 0 ], [1 , 0 ]]), [0 , 0 ], self .m .shape ),
142- axis = 0 ,
143- ),
144- axis = 1 ,
145- )
153+ m = tf .cast (self .m , tf .float32 )
154+ n_cap = tf .cumsum (tf .cumsum (m , axis = 0 ), axis = 1 )
146155 # Number of concordant pairs.
147- p = tf .math .reduce_sum (tf .multiply (n_cap , m_dense ))
148- sum_m_squard = tf .math .reduce_sum (tf .math .square (m_dense ))
156+ p = tf .math .reduce_sum (tf .multiply (n_cap [: - 1 , : - 1 ], m [ 1 :, 1 :] ))
157+ sum_m_squard = tf .math .reduce_sum (tf .math .square (m ))
149158 # Ties in x.
150159 t = (
151- tf .math .reduce_sum (tf .math .square (tf . sparse . to_dense ( self .nrow )))
160+ tf .cast ( tf . math .reduce_sum (tf .math .square (self .nrow )), tf . float32 )
152161 - sum_m_squard
153162 ) / 2.0
154163 # Ties in y.
155164 u = (
156- tf .math .reduce_sum (tf .math .square (tf . sparse . to_dense ( self .ncol )))
165+ tf .cast ( tf . math .reduce_sum (tf .math .square (self .ncol )), tf . float32 )
157166 - sum_m_squard
158167 ) / 2.0
159168 # Ties in both.
160- b = tf .math .reduce_sum (tf .multiply (m_dense , (m_dense - 1.0 ))) / 2.0
169+ b = tf .math .reduce_sum (tf .multiply (m , (m - 1.0 ))) / 2.0
161170 # Number of discordant pairs.
162171 n = tf .cast (self .n , tf .float32 )
163172 q = (n - 1.0 ) * n / 2.0 - p - t - u - b
@@ -179,28 +188,11 @@ def get_config(self):
179188
180189 def reset_state (self ):
181190 """Resets all of the metric state variables."""
182- self .actual_cuts = tf .linspace (
183- tf .cast (self .actual_min , tf .float32 ),
184- tf .cast (self .actual_max , tf .float32 ),
185- self .actual_cutpoints - 1 ,
186- )
187- self .preds_cuts = tf .linspace (
188- tf .cast (self .preds_min , tf .float32 ),
189- tf .cast (self .preds_max , tf .float32 ),
190- self .preds_cutpoints - 1 ,
191- )
192- self .m = tf .SparseTensor (
193- tf .zeros ((0 , 2 ), tf .int64 ),
194- [],
195- [self .actual_cutpoints , self .preds_cutpoints ],
196- )
197- self .nrow = tf .SparseTensor (
198- tf .zeros ((0 , 1 ), dtype = tf .int64 ), [], [self .actual_cutpoints ]
199- )
200- self .ncol = tf .SparseTensor (
201- tf .zeros ((0 , 1 ), dtype = tf .int64 ), [], [self .preds_cutpoints ]
202- )
203- self .n = 0
191+
192+ self .m .assign (tf .zeros ((self .actual_cutpoints , self .preds_cutpoints ), tf .int64 ))
193+ self .nrow .assign (tf .zeros ((self .actual_cutpoints ), tf .int64 ))
194+ self .ncol .assign (tf .zeros ((self .preds_cutpoints ), tf .int64 ))
195+ self .n .assign (0 )
204196
205197 def reset_states (self ):
206198 # Backwards compatibility alias of `reset_state`. New classes should
0 commit comments