Skip to content

Commit cb3b716

Browse files
committed
use full loss to check convergence if batching is disabled
1 parent 5cfbe46 commit cb3b716

File tree

2 files changed

+137
-126
lines changed

2 files changed

+137
-126
lines changed

batchglm/train/tf/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def loss(self):
165165
return self._get_unsafe("loss")
166166

167167
def _train_to_convergence(self,
168+
loss,
168169
train_op,
169170
feed_dict,
170171
loss_window_size,
@@ -199,7 +200,7 @@ def should_stop(step):
199200

200201
while True:
201202
train_step, global_loss, _ = self.session.run(
202-
(self.model.global_step, self.model.loss, train_op),
203+
(self.model.global_step, loss, train_op),
203204
feed_dict=feed_dict
204205
)
205206

@@ -223,6 +224,7 @@ def train(self, *args,
223224
convergence_criteria="t_test",
224225
loss_window_size=None,
225226
stop_at_loss_change=None,
227+
loss=None,
226228
train_op=None,
227229
**kwargs):
228230
"""
@@ -248,6 +250,7 @@ def train(self, *args,
248250
249251
See parameter `convergence_criteria` for exact meaning
250252
:param loss_window_size: specifies `N` in `convergence_criteria`.
253+
:param loss: uses this loss tensor if specified
251254
:param train_op: uses this training operation if specified
252255
"""
253256
# feed_dict = dict() if feed_dict is None else feed_dict.copy()
@@ -260,11 +263,15 @@ def train(self, *args,
260263
stop_at_loss_change = 1e-5
261264
else:
262265
stop_at_loss_change = 0.05
263-
266+
267+
if loss is None:
268+
loss = self.model.loss
269+
264270
if train_op is None:
265271
train_op = self.model.train_op
266272

267273
self._train_to_convergence(
274+
loss=loss,
268275
train_op=train_op,
269276
convergence_criteria=convergence_criteria,
270277
loss_window_size=loss_window_size,

0 commit comments

Comments
 (0)