diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 9e26dfe..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1 +0,0 @@ -{} \ No newline at end of file diff --git a/Project.toml b/Project.toml index 8dee87a..7b0ff8e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,20 @@ name = "GsvdInitialization" uuid = "2ac24108-be9c-42b8-8d78-6a4f62a87e7d" -authors = ["youdongguo <1010705897@qq.com> and contributors"] version = "1.0.0" +authors = ["youdongguo <1010705897@qq.com> and contributors"] [deps] +Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" NonNegLeastSquares = "b7351bd1-99d9-5c5d-8786-f205a815c4d7" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" [compat] +FileIO = "1.18" +JLD2 = "0.6" +Kronecker = "0.5" LinearAlgebra = "1" NMF = "1" NonNegLeastSquares = "0.4" @@ -17,8 +22,10 @@ TSVD = "0.4" julia = "1.10" [extras] +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["NMF", "Test"] +test = ["NMF", "Test", "FileIO", "JLD2"] diff --git a/demo/generate_ground_truth.jl b/demo/generate_ground_truth.jl index 409eb0f..901a797 100644 --- a/demo/generate_ground_truth.jl +++ b/demo/generate_ground_truth.jl @@ -42,3 +42,7 @@ function gaussiantemplate(T::Type, r::Real) end gaussiantemplate(r::Real) = gaussiantemplate(Float64, r) +load_svd_of_gt() = load(joinpath(dirname(@__DIR__), "demo/svd_of_GT.jld2"))["svdX"] + + + diff --git a/demo/svd_of_GT.jld2 b/demo/svd_of_GT.jld2 new file mode 100644 index 0000000..9f44f9a Binary files /dev/null and b/demo/svd_of_GT.jld2 differ diff --git a/src/GsvdInitialization.jl b/src/GsvdInitialization.jl index 7c6e6b3..0171df0 100644 --- a/src/GsvdInitialization.jl +++ b/src/GsvdInitialization.jl @@ -2,6 +2,7 @@ module GsvdInitialization using LinearAlgebra, NMF, TSVD using NonNegLeastSquares +using Kronecker, SparseArrays export gsvdnmf, gsvdrecover @@ -33,6 +34,9 @@ Other keyword arguments are passed to `NMF.nnmf`. function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f; n2 = size(first(f), 2), tol_nmf=1e-4, + alg = :cd, + initW = :standard, + truncmult = 1e-5, kwargs...) n1 = size(W, 2) kadd = n2 - n1 @@ -42,9 +46,12 @@ function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f; if kadd == 0 return W, H else - W_recover, H_recover = gsvdrecover(X, copy(W), copy(H), kadd, f) - result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=W_recover, H0=H_recover) - return result_recover.W, result_recover.H + W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f; initW=initW) + if alg == :multmse + W_recover, H_recover = max.(W_recover, truncmult), max.(H_recover, truncmult) + end + result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=copy(W_recover), H0=copy(H_recover)) + return result_recover, result_recover.W, result_recover.H end end gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, n2::Int; kwargs...) = gsvdnmf(X, W, H, tsvd(X, n2); kwargs...) @@ -73,13 +80,13 @@ Keyword arguments: Other keyword arguments are passed to `NMF.nnmf`. """ -function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, kwargs...) +function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, initW = :standard, kwargs...) n1, n2 = ncomponents f = tsvd(X, n2) W0, H0 = NMF.nndsvd(X, n1; initdata = (U = f[1], S = f[2], V = f[3])) result_initial_nmf = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0)) W_initial_nmf, H_initial_nmf = result_initial_nmf.W, result_initial_nmf.H - return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final) + return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final, initW=initW) end gsvdnmf(X::AbstractMatrix, ncomponents_final::Integer; kwargs...) = gsvdnmf(X, ncomponents_final-1 => ncomponents_final; kwargs...) @@ -108,7 +115,7 @@ Arguments: `f`: SVD (or Truncated SVD) of `X` """ -function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple) +function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple; initW::Symbol = :standard, kwargs...) m, n = size(W0) kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components")) if kadd == 0 @@ -117,15 +124,29 @@ function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kad U0, S0, V0 = f U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n] Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd) - Wadd, a = init_W(X, W0, H0, Hadd) - Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd')) - W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn] - cs = Wcols_modification(X, W0_1, H0_1) - W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1 + if initW == :standard + Wadd, a = init_W(X, W0, H0, Hadd) + Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd')) + W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn] + cs = Wcols_modification(X, W0_1, H0_1) + W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1 + elseif initW == :joint + W0_2, H0_2 = gsvdrecover_Wa(X, W0, H0, Hadd; kwargs...) + else + throw(ArgumentError("Unknown initW method: $initW")) + end return abs.(W0_2), abs.(H0_2), Λ end end +function gsvdrecover_Wa(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, Hadd::AbstractArray) + m = size(W0, 1) + Hadd_nn = truncatepos(Hadd', X, W0, H0)' + Wadd, a = init_Wa(X, W0, H0, Hadd_nn) + W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd], [H0; Hadd_nn] + return abs.(W0_1), abs.(H0_1) +end + function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int) _, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), (U0'*W0)*(H0*V0)); inv_RQt = inv(R*Q') @@ -138,7 +159,41 @@ function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::Abs H_index = sortperm(Λ, rev = true)[1:kadd] Hadd = inv_RQt[:, H_index] Hadd_1 = V0*Hadd - return Hadd_1', Λ[H_index] + return Hadd_1', Λ +end + +function init_Wa(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}) where T + m = size(X, 1) + kadd = size(Hadd, 1) + G = gram_sp_C(W0, H0, Hadd)[1] + b = gram_b(X, W0, H0, Hadd) + θ = nonneg_lsq(G, b; alg=:fnnls, gram=true) + Wadd = reshape(θ[1:m*kadd], m, kadd) + α = θ[m*kadd+1:end] + return Wadd, α +end + +function gram_sp_C(W0, H0, Hadd) + m, r0 = size(W0) + k = size(Hadd, 1) + mk = m*k + W0W0, H0H0 = W0'*W0, H0*H0' + P = Hadd*H0' + HH = Hadd*Hadd' + G22 = sparse(W0W0.*H0H0) + G12 = zeros(Float64, mk, r0) + for j in 1:r0 + G12[:,j] .= vec(W0[:,j] * P[:,j]') + end + G12 = sparse(G12) + G11 = kronecker(HH, sparse(I, m, m)) + G = [G11 G12; G12' G22] + return G, G11, G12, G22 +end + +function gram_b(X, W0, H0, Hadd) + b = vcat(vec(X * Hadd'), diag(W0' * X * H0')) + return b end function init_W(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}; α = nothing) where T @@ -176,4 +231,21 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac return β[:] end +function truncatepos(Y, X, W, H) + ΔX = max.(zero(eltype(X)), X - W*H) + Yout = similar(Y) + for j in axes(Y, 2) + y = view(Y, :, j) + yp = max.(y, zero(eltype(y))) + ym = max.(-y, zero(eltype(y))) + if sum(ΔX * yp) >= sum(ΔX * ym) + Yout[:, j] = yp + else + Yout[:, j] = ym + end + end + return Yout +end + + end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e6d0f7e..ac145d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,31 +1,31 @@ using GsvdInitialization using Test -using LinearAlgebra, NMF +using LinearAlgebra, NMF, FileIO include(joinpath(dirname(@__DIR__), "demo/generate_ground_truth.jl")) W_GT, H_GT = generate_ground_truth() +svdX = load_svd_of_gt() @testset "test top wrapper" begin W = W_GT H = H_GT X = W*H - standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, initdata = svd(float(X))) - W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4); - img_tol_int = sum(abs2, X) + standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, maxiter = 10^5, initdata = svdX) + _, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4); @test size(W_gsvd, 2) == 10 - @test sum(abs2, X-standard_nmf.W*standard_nmf.H)/sum(abs2, X) > sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) @test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10 + @test sum(abs2, X-standard_nmf.W*standard_nmf.H)/sum(abs2, X) > sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) X = rand(30, 20) - W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd) - W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd) + _, W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd) + _, W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd) @test sum(abs2, W_gsvd_1-W_gsvd_2) <= 1e-12 @test sum(abs2, H_gsvd_1-H_gsvd_2) <= 1e-12 end -@testset "GsvdInitialization.jl" begin +@testset "GsvdInitialization" begin W, H = rand(10, 3), rand(3, 8) X = W*H U, S, V = svd(X) @@ -54,3 +54,31 @@ end @test β.*β0 ≈ ones(3) end + +@testset "joint optimize W and alpha" begin + W = W_GT + H = H_GT + X = W*H + _, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4, initW=:joint); + @test size(W_gsvd, 2) == 10 + @test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10 + + W, H = rand(10, 3), rand(3, 8) + X = W*H + U, S, V = svd(X) + + W0, H0 = copy(W), copy(H) + Hadd = rand(2, 8) + Wadd, a = GsvdInitialization.init_Wa(X, W0, H0, Hadd) + @test a ≈ ones(size(W0, 2)) + @test norm(Wadd) <= 1e-8 + + G = GsvdInitialization.gram_sp_C(W0, H0, Hadd)[1] + b = GsvdInitialization.gram_b(X, W0, H0, Hadd) + Wadd = rand(10, 2) + α = rand(3) + θ = vcat(vec(Wadd), α) + E = θ'*G*θ-2*b'*θ+sum(abs2, X) + @test abs(E-sum(abs2, X-[repeat(α', size(W0, 1)).*W0 Wadd]*[H0;Hadd])) <= 1e-12 + +end