Skip to content

Commit aaf3064

Browse files
james-martensKfacJaxDev
authored andcommitted
- Fixing bug with step rejection where reject_damping_increase_factor was applied when step was *not* rejected.
- Minor internal changes to optimizer code. PiperOrigin-RevId: 702353481
1 parent cf3acc0 commit aaf3064

File tree

1 file changed

+29
-21
lines changed

1 file changed

+29
-21
lines changed

kfac_jax/_src/optimizer.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def _maybe_apply_norm_constraint(
872872

873873
return precon_grad, sq_norm_scaled_grads
874874

875-
def _compute_quad_change_for_damping(
875+
def _compute_quad_change_for_damping_adapt(
876876
self,
877877
state: State,
878878
delta: Params,
@@ -910,32 +910,38 @@ def _coefficients_and_quad_change(
910910
# we multiply the gradients, while the momentum is the coefficient by
911911
# which we multiply the velocities.
912912
neg_learning_rate = -learning_rate if learning_rate is not None else None
913-
coefficients = (neg_learning_rate, momentum)
913+
fixed_coefficients = (neg_learning_rate, momentum)
914914

915915
if self._use_adaptive_learning_rate or self._use_adaptive_momentum:
916916

917+
assert fixed_coefficients[0] is None or fixed_coefficients[1] is None
918+
917919
quad_model = self.compute_exact_quad_model_filtered(
918-
vectors, grads, func_args, state=state, coefficients=coefficients)
919-
return self._solve_quad_model(quad_model, damping, vectors, coefficients)
920+
vectors, grads, func_args, state=state,
921+
fixed_coefficients=fixed_coefficients)
922+
923+
return self._solve_quad_model(quad_model, damping, vectors,
924+
fixed_coefficients)
920925

921926
else:
922-
assert all(c is not None for c in coefficients)
923-
coefficients: tuple[Numeric, Numeric]
927+
assert all(c is not None for c in fixed_coefficients)
928+
fixed_coefficients: tuple[Numeric, Numeric]
924929

925930
if self._use_adaptive_damping:
926-
delta = self.weighted_sum_of_objects(vectors, coefficients)
931+
932+
delta = self.weighted_sum_of_objects(vectors, fixed_coefficients)
927933

928934
quad_change = lax.cond(
929935
self.should_update_damping(state),
930-
lambda args: self._compute_quad_change_for_damping(*args),
936+
lambda args: self._compute_quad_change_for_damping_adapt(*args),
931937
lambda args: self._invalid_metric_value,
932938
(state, delta, grads, damping, func_args),
933939
)
934940

935941
else:
936942
quad_change = self._invalid_metric_value
937943

938-
return coefficients, quad_change
944+
return fixed_coefficients, quad_change
939945

940946
@utils.staged
941947
def compute_loss_from_registrations(
@@ -1217,9 +1223,9 @@ def _step(
12171223

12181224
params, state.velocities, state.damping = lax.cond(
12191225
reject_step,
1220-
lambda: (params, state.velocities, state.damping),
1221-
lambda: (new_params, delta,
1222-
self._reject_damping_increase_factor * state.damping))
1226+
lambda: (params, state.velocities,
1227+
self._reject_damping_increase_factor * state.damping),
1228+
lambda: (new_params, delta, state.damping))
12231229

12241230
else:
12251231
# stop the linter from complaining about uninitialized variable
@@ -1404,20 +1410,21 @@ def compute_exact_quad_model_filtered(
14041410
grads: Params,
14051411
func_args: FuncArgsVariants,
14061412
state: State | None = None,
1407-
coefficients: Sequence[Numeric | None] | None = None,
1413+
fixed_coefficients: Sequence[Numeric | None] | None = None,
14081414
) -> tuple[Array, Array, Array]:
14091415
"""Computes the components of the exact quadratic model."""
14101416

1411-
# We check the coefficients for zeros to save computing the expensive matrix
1412-
# vector products for vectors that will eventually be multiplied by zero.
1417+
# We check the fixed_coefficients for zeros to save computing the expensive
1418+
# matrix vector products for vectors that will eventually be multiplied by
1419+
# zero. If fixed_coefficients is None, we assume that all coefficients are
1420+
# free and compute the full model.
14131421

1414-
if coefficients is None:
1422+
if fixed_coefficients is None:
14151423
return self.compute_exact_quad_model(
14161424
vectors, grads, func_args, state=state)
14171425

1418-
assert len(vectors) == len(coefficients)
1419-
# only deal with the two vector case
1420-
assert len(vectors) == 2
1426+
assert len(vectors) == len(fixed_coefficients)
1427+
assert len(vectors) == 2 # only deal with the two vector case
14211428

14221429
def if_momentum_coeff_zero():
14231430
# only pass in the vectors that won't be multiplied by zero
@@ -1431,11 +1438,12 @@ def if_momentum_coeff_zero():
14311438
for arr in quad_model
14321439
)
14331440
# add a check here to save compiling both branches in the static case
1434-
if isinstance(coefficients[1], float) and coefficients[1] == 0.0:
1441+
if (isinstance(fixed_coefficients[1], float)
1442+
and fixed_coefficients[1] == 0.0):
14351443
return if_momentum_coeff_zero()
14361444

14371445
return jax.lax.cond(
1438-
coefficients[1] == 0.0,
1446+
fixed_coefficients[1] == 0.0,
14391447
if_momentum_coeff_zero,
14401448
lambda: self.compute_exact_quad_model(
14411449
vectors, grads, func_args, state=state),

0 commit comments

Comments
 (0)