@@ -43,6 +43,21 @@ function fenchel_young_loss_and_grad(
4343 return l, g
4444end
4545
"""
    fenchel_young_loss_and_grad(
        fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs...
    ) where {O<:AbstractRegularized{<:GeneralizedMaximizer}}

Compute the Fenchel-Young loss and its gradient with respect to `θ` for a
regularized layer whose inner maximizer is a `GeneralizedMaximizer`.

Returns a tuple `(l, g)` where `l` is the scalar loss value and `g` has the
same shape as `maximizer.g(⋅)`'s output.

All `kwargs` are forwarded to the layer call, to `objective_value`, and to
`maximizer.g` (instance-dependent information such as problem data).
"""
function fenchel_young_loss_and_grad(
    fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs...
) where {O<:AbstractRegularized{<:GeneralizedMaximizer}}
    (; optimization_layer) = fyl
    # Solve the regularized problem at θ to obtain the current prediction.
    ŷ = optimization_layer(θ; kwargs...)
    Ωy_true = compute_regularization(optimization_layer, y_true)
    Ωŷ = compute_regularization(optimization_layer, ŷ)
    maximizer = get_maximizer(optimization_layer)
    # Loss: [Ω(y_true) - obj(θ, y_true)] - [Ω(ŷ) - obj(θ, ŷ)],
    # i.e. F(θ) + Ω(y_true) - objective at the ground truth.
    l =
        (Ωy_true - objective_value(maximizer, θ, y_true; kwargs...)) -
        (Ωŷ - objective_value(maximizer, θ, ŷ; kwargs...))
    # Gradient w.r.t. θ: the objective enters linearly through g, so the
    # gradient is g(ŷ) - g(y_true). Note kwargs must reach g as well, for
    # consistency with the AbstractPerturbed method in this file.
    g = maximizer.g(ŷ; kwargs...) - maximizer.g(y_true; kwargs...)
    return l, g
end
60+
4661function fenchel_young_loss_and_grad (
4762 fyl:: FenchelYoungLoss{O} , θ:: AbstractArray , y_true:: AbstractArray ; kwargs...
4863) where {O<: AbstractPerturbed }
@@ -61,7 +76,7 @@ function fenchel_young_loss_and_grad(
6176 optimization_layer, θ; kwargs...
6277 )
6378 l = F - objective_value (optimization_layer. oracle, θ, y_true; kwargs... )
64- g = almost_g_of_ŷ - optimization_layer. oracle. g (y_true)
79+ g = almost_g_of_ŷ - optimization_layer. oracle. g (y_true; kwargs ... )
6580 return l, g
6681end
6782
0 commit comments