Skip to content

Commit 8f7f103

Browse files
committed
add tests for new interface
1 parent cb8da5a commit 8f7f103

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

src/regularized/soft_argmax.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Corresponds to regularized prediction on the probability simplex with entropic p
88
struct SoftArgmax <: AbstractRegularized end
99

1010
(::SoftArgmax)(z::AbstractVector; kwargs...) = soft_argmax(z)
11-
compute_regularization(::SoftArgmax, y::AbstractVector) = soft_argmax_regularization(y)
11+
compute_regularization(::SoftArgmax, y::AbstractArray) = soft_argmax_regularization(y)
1212

1313
function soft_argmax(z::AbstractVector)
1414
s = exp.(z)

src/regularized/sparse_argmax.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Corresponds to regularized prediction on the probability simplex with square nor
88
struct SparseArgmax <: AbstractRegularized end
99

1010
(::SparseArgmax)(z::AbstractVector; kwargs...) = sparse_argmax(z)
11-
compute_regularization(::SparseArgmax, y::AbstractVector) = sparse_argmax_regularization(y)
11+
compute_regularization(::SparseArgmax, y::AbstractArray) = sparse_argmax_regularization(y)
1212

1313
function sparse_argmax(z::AbstractVector; kwargs...)
1414
p, _ = simplex_projection_and_support(z)

test/interface.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@testitem "Test interfaces are correctly implemented" default_imports = false begin
2+
using InferOpt, RequiredInterfaces, Test
3+
const RI = RequiredInterfaces
4+
5+
@test RI.check_interface_implemented(AbstractRegularized, RegularizedFrankWolfe)
6+
@test RI.check_interface_implemented(AbstractRegularized, SoftArgmax)
7+
@test RI.check_interface_implemented(AbstractRegularized, SparseArgmax)
8+
end

0 commit comments

Comments
 (0)