This is my code:
self.deq = get_deq(
    ift=True,
    f_solver=deq_solver,
    b_solver=deq_solver,
    f_max_iter=deq_f_max_iter,
    b_max_iter=deq_b_max_iter,
    f_tol=deq_f_tol,
    b_tol=deq_b_tol,
    stop_mode='rel',
    f_anderson_m=deq_anderson_m if deq_solver == 'anderson' else None,
    b_anderson_m=deq_anderson_m if deq_solver == 'anderson' else None,
)
solve, deq_info = self.deq(f, y)
Then:
new_z_star = f(z_star.requires_grad_())
jac_loss = jac_reg(new_z_star, z_star)
I then add jac_loss to my loss, but it does not work.
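
For reference, here is a minimal sketch in plain PyTorch of what I expect the Jacobian term to do. It is not torchdeq's jac_reg. The toy layer f, the weight W, the input x, and the helper hutchinson_jac_loss are all hypothetical names made up for illustration, and a simple fixed-point loop stands in for the DEQ solver. The idea, as far as I understand it, is to detach the fixed point, mark it as requiring grad before calling f one more time, and build the penalty with create_graph=True so the gradient can reach the parameters.

```python
import torch

torch.manual_seed(0)

# Toy implicit layer (hypothetical): f(z) = tanh(z @ W.T + x).
W = (0.1 * torch.randn(4, 4)).requires_grad_()
x = torch.randn(8, 4)


def f(z):
    return torch.tanh(z @ W.T + x)


def hutchinson_jac_loss(f_out, z_in):
    """One-sample Hutchinson estimate of ||df/dz||_F^2, averaged over the batch."""
    v = torch.randn_like(f_out)
    # Vector-Jacobian product v^T J; create_graph=True keeps it differentiable
    # so that the penalty can be backpropagated into W.
    (vJ,) = torch.autograd.grad(f_out, z_in, grad_outputs=v, create_graph=True)
    return vJ.pow(2).sum() / z_in.shape[0]


# Plain fixed-point iteration standing in for the DEQ solver (assumption).
with torch.no_grad():
    z = torch.zeros(8, 4)
    for _ in range(50):
        z = f(z)

# Detach the fixed point, make it a leaf that requires grad, then call f once more.
z_star = z.detach().requires_grad_()
new_z_star = f(z_star)

jac_loss = hutchinson_jac_loss(new_z_star, z_star)
jac_loss.backward()
```

In this sketch the tensor passed to the penalty is the same z_star object that was fed into f, and W.grad is populated after backward(). If the real setup differs from this (for example, if z_star is not taken from the solver output, or requires_grad_() is applied to a tensor that f never sees), that may be where the gradient is lost.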