1313# limitations under the License.
1414# ==============================================================================
1515"""Implements R^2 scores."""
16- from typing import Tuple
16+ import warnings
1717
1818import numpy as np
1919import tensorflow as tf
@@ -86,13 +86,18 @@ def __init__(
8686 self ,
8787 name : str = "r_square" ,
8888 dtype : AcceptableDTypes = None ,
89- y_shape : Tuple [int , ...] = (),
9089 multioutput : str = "uniform_average" ,
9190 num_regressors : tf .int32 = 0 ,
9291 ** kwargs ,
9392 ):
9493 super ().__init__ (name = name , dtype = dtype , ** kwargs )
95- self .y_shape = y_shape
94+
95+ if "y_shape" in kwargs :
96+ warnings .warn (
97+ "y_shape has been removed, because it's automatically derived,"
98+ "and will be deprecated in Addons 0.18." ,
99+ DeprecationWarning ,
100+ )
96101
97102 if multioutput not in _VALID_MULTIOUTPUT :
98103 raise ValueError (
@@ -102,21 +107,38 @@ def __init__(
102107 )
103108 self .multioutput = multioutput
104109 self .num_regressors = num_regressors
105- self .squared_sum = self .add_weight (
106- name = "squared_sum" , shape = y_shape , initializer = "zeros" , dtype = dtype
107- )
108- self .sum = self .add_weight (
109- name = "sum" , shape = y_shape , initializer = "zeros" , dtype = dtype
110- )
111- self .res = self .add_weight (
112- name = "residual" , shape = y_shape , initializer = "zeros" , dtype = dtype
113- )
114- self .count = self .add_weight (
115- name = "count" , shape = y_shape , initializer = "zeros" , dtype = dtype
116- )
117110 self .num_samples = self .add_weight (name = "num_samples" , dtype = tf .int32 )
118111
119112 def update_state (self , y_true , y_pred , sample_weight = None ) -> None :
113+ if not hasattr (self , "squared_sum" ):
114+ self .squared_sum = self .add_weight (
115+ name = "squared_sum" ,
116+ shape = y_true .shape [1 :],
117+ initializer = "zeros" ,
118+ dtype = self ._dtype ,
119+ )
120+ if not hasattr (self , "sum" ):
121+ self .sum = self .add_weight (
122+ name = "sum" ,
123+ shape = y_true .shape [1 :],
124+ initializer = "zeros" ,
125+ dtype = self ._dtype ,
126+ )
127+ if not hasattr (self , "res" ):
128+ self .res = self .add_weight (
129+ name = "residual" ,
130+ shape = y_true .shape [1 :],
131+ initializer = "zeros" ,
132+ dtype = self ._dtype ,
133+ )
134+ if not hasattr (self , "count" ):
135+ self .count = self .add_weight (
136+ name = "count" ,
137+ shape = y_true .shape [1 :],
138+ initializer = "zeros" ,
139+ dtype = self ._dtype ,
140+ )
141+
120142 y_true = tf .cast (y_true , dtype = self ._dtype )
121143 y_pred = tf .cast (y_pred , dtype = self ._dtype )
122144 if sample_weight is None :
@@ -191,7 +213,6 @@ def reset_states(self):
191213
192214 def get_config (self ):
193215 config = {
194- "y_shape" : self .y_shape ,
195216 "multioutput" : self .multioutput ,
196217 }
197218 base_config = super ().get_config ()
0 commit comments