@@ -165,6 +165,7 @@ def loss(self):
165165 return self ._get_unsafe ("loss" )
166166
167167 def _train_to_convergence (self ,
168+ loss ,
168169 train_op ,
169170 feed_dict ,
170171 loss_window_size ,
@@ -199,7 +200,7 @@ def should_stop(step):
199200
200201 while True :
201202 train_step , global_loss , _ = self .session .run (
202- (self .model .global_step , self . model . loss , train_op ),
203+ (self .model .global_step , loss , train_op ),
203204 feed_dict = feed_dict
204205 )
205206
@@ -223,6 +224,7 @@ def train(self, *args,
223224 convergence_criteria = "t_test" ,
224225 loss_window_size = None ,
225226 stop_at_loss_change = None ,
227+ loss = None ,
226228 train_op = None ,
227229 ** kwargs ):
228230 """
@@ -248,6 +250,7 @@ def train(self, *args,
248250
249251 See parameter `convergence_criteria` for exact meaning
250252 :param loss_window_size: specifies `N` in `convergence_criteria`.
253+ :param loss: uses this loss tensor if specified
251254 :param train_op: uses this training operation if specified
252255 """
253256 # feed_dict = dict() if feed_dict is None else feed_dict.copy()
@@ -260,11 +263,15 @@ def train(self, *args,
260263 stop_at_loss_change = 1e-5
261264 else :
262265 stop_at_loss_change = 0.05
263-
266+
267+ if loss is None :
268+ loss = self .model .loss
269+
264270 if train_op is None :
265271 train_op = self .model .train_op
266272
267273 self ._train_to_convergence (
274+ loss = loss ,
268275 train_op = train_op ,
269276 convergence_criteria = convergence_criteria ,
270277 loss_window_size = loss_window_size ,