Skip to content

Commit 7bdc9f7

Browse files
committed
Add common interface to get proba distribution for perturbed & regularized
1 parent ffb5ad7 commit 7bdc9f7

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

src/InferOpt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ include("ssvm/isbaseloss.jl")
3838
include("ssvm/zeroone_baseloss.jl")
3939
include("ssvm/ssvm_loss.jl")
4040

41+
export get_probability_distribution
42+
4143
export Interpolation
4244

4345
export DifferentiableFrankWolfe

src/perturbed/abstract_perturbed.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,33 @@ function (perturbed::AbstractPerturbed)(θ::AbstractArray{<:Real}; kwargs...)
4545
return mean(y_samples)
4646
end
4747

48+
function get_probability_distribution(
49+
perturbed::AbstractPerturbed, θ::AbstractArray{<:Real}; atol=0, kwargs...
50+
)
51+
(; nb_samples) = perturbed
52+
Z_samples = sample_perturbations(perturbed, θ)
53+
y_samples = [perturbed(θ, Z; kwargs...) for Z in Z_samples]
54+
multiplicity = ones(Int, nb_samples)
55+
to_delete = Int[]
56+
for i in nb_samples:-1:1
57+
yi = y_samples[i]
58+
for j in 1:(i - 1)
59+
yj = y_samples[j]
60+
if isapprox(yi, yj; atol=atol)
61+
multiplicity[j] += 1
62+
push!(to_delete, i)
63+
break
64+
end
65+
end
66+
end
67+
sort!(to_delete)
68+
deleteat!(y_samples, to_delete)
69+
deleteat!(multiplicity, to_delete)
70+
weights = multiplicity ./ sum(multiplicity)
71+
y_mean = sum(w * a for (w, a) in zip(weights, y_samples))
72+
return ActiveSet(weights, y_samples, y_mean)
73+
end
74+
4875
"""
4976
compute_y_and_F(perturbed, θ, Z; kwargs...)
5077
"""

src/regularized/regularized_generic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343

4444
## Forward pass
4545

46-
function optimal_active_set(
46+
function get_probability_distribution(
4747
regularized::RegularizedGeneric,
4848
θ::AbstractArray{<:Real};
4949
maximizer_kwargs=(;),
@@ -59,7 +59,7 @@ end
5959
function (regularized::RegularizedGeneric)(
6060
θ::AbstractArray{<:Real}; maximizer_kwargs=(;), fw_kwargs=(;)
6161
)
62-
active_set = optimal_active_set(
62+
active_set = get_probability_distribution(
6363
regularized, θ; maximizer_kwargs=maximizer_kwargs, fw_kwargs=fw_kwargs
6464
)
6565
return active_set.x

0 commit comments

Comments
 (0)