@@ -872,7 +872,7 @@ def _maybe_apply_norm_constraint(
872872
873873 return precon_grad , sq_norm_scaled_grads
874874
875- def _compute_quad_change_for_damping (
875+ def _compute_quad_change_for_damping_adapt (
876876 self ,
877877 state : State ,
878878 delta : Params ,
@@ -910,32 +910,38 @@ def _coefficients_and_quad_change(
910910 # we multiply the gradients, while the momentum is the coefficient by
911911 # which we multiply the velocities.
912912 neg_learning_rate = - learning_rate if learning_rate is not None else None
913- coefficients = (neg_learning_rate , momentum )
913+ fixed_coefficients = (neg_learning_rate , momentum )
914914
915915 if self ._use_adaptive_learning_rate or self ._use_adaptive_momentum :
916916
917+ assert fixed_coefficients [0 ] is None or fixed_coefficients [1 ] is None
918+
917919 quad_model = self .compute_exact_quad_model_filtered (
918- vectors , grads , func_args , state = state , coefficients = coefficients )
919- return self ._solve_quad_model (quad_model , damping , vectors , coefficients )
920+ vectors , grads , func_args , state = state ,
921+ fixed_coefficients = fixed_coefficients )
922+
923+ return self ._solve_quad_model (quad_model , damping , vectors ,
924+ fixed_coefficients )
920925
921926 else :
922- assert all (c is not None for c in coefficients )
923- coefficients : tuple [Numeric , Numeric ]
927+ assert all (c is not None for c in fixed_coefficients )
928+ fixed_coefficients : tuple [Numeric , Numeric ]
924929
925930 if self ._use_adaptive_damping :
926- delta = self .weighted_sum_of_objects (vectors , coefficients )
931+
932+ delta = self .weighted_sum_of_objects (vectors , fixed_coefficients )
927933
928934 quad_change = lax .cond (
929935 self .should_update_damping (state ),
930- lambda args : self ._compute_quad_change_for_damping (* args ),
936+ lambda args : self ._compute_quad_change_for_damping_adapt (* args ),
931937 lambda args : self ._invalid_metric_value ,
932938 (state , delta , grads , damping , func_args ),
933939 )
934940
935941 else :
936942 quad_change = self ._invalid_metric_value
937943
938- return coefficients , quad_change
944+ return fixed_coefficients , quad_change
939945
940946 @utils .staged
941947 def compute_loss_from_registrations (
@@ -1217,9 +1223,9 @@ def _step(
12171223
12181224 params , state .velocities , state .damping = lax .cond (
12191225 reject_step ,
1220- lambda : (params , state .velocities , state . damping ),
1221- lambda : ( new_params , delta ,
1222- self . _reject_damping_increase_factor * state .damping ))
1226+ lambda : (params , state .velocities ,
1227+ self . _reject_damping_increase_factor * state . damping ) ,
1228+ lambda : ( new_params , delta , state .damping ))
12231229
12241230 else :
12251231 # stop the linter from complaining about uninitialized variable
@@ -1404,20 +1410,21 @@ def compute_exact_quad_model_filtered(
14041410 grads : Params ,
14051411 func_args : FuncArgsVariants ,
14061412 state : State | None = None ,
1407- coefficients : Sequence [Numeric | None ] | None = None ,
1413+ fixed_coefficients : Sequence [Numeric | None ] | None = None ,
14081414 ) -> tuple [Array , Array , Array ]:
14091415 """Computes the components of the exact quadratic model."""
14101416
1411- # We check the coefficients for zeros to save computing the expensive matrix
1412- # vector products for vectors that will eventually be multiplied by zero.
1417+ # We check the fixed_coefficients for zeros to save computing the expensive
1418+ # matrix vector products for vectors that will eventually be multiplied by
1419+ # zero. If fixed_coefficients is None, we assume that all coefficients are
1420+ # free and compute the full model.
14131421
1414- if coefficients is None :
1422+ if fixed_coefficients is None :
14151423 return self .compute_exact_quad_model (
14161424 vectors , grads , func_args , state = state )
14171425
1418- assert len (vectors ) == len (coefficients )
1419- # only deal with the two vector case
1420- assert len (vectors ) == 2
1426+ assert len (vectors ) == len (fixed_coefficients )
1427+ assert len (vectors ) == 2 # only deal with the two vector case
14211428
14221429 def if_momentum_coeff_zero ():
14231430 # only pass in the vectors that won't be multiplied by zero
@@ -1431,11 +1438,12 @@ def if_momentum_coeff_zero():
14311438 for arr in quad_model
14321439 )
14331440 # add a check here to save compiling both branches in the static case
1434- if isinstance (coefficients [1 ], float ) and coefficients [1 ] == 0.0 :
1441+ if (isinstance (fixed_coefficients [1 ], float )
1442+ and fixed_coefficients [1 ] == 0.0 ):
14351443 return if_momentum_coeff_zero ()
14361444
14371445 return jax .lax .cond (
1438- coefficients [1 ] == 0.0 ,
1446+ fixed_coefficients [1 ] == 0.0 ,
14391447 if_momentum_coeff_zero ,
14401448 lambda : self .compute_exact_quad_model (
14411449 vectors , grads , func_args , state = state ),
0 commit comments