diff --git a/include/flucoma/algorithms/public/NMF.hpp b/include/flucoma/algorithms/public/NMF.hpp index 9cc6e1e0..6c6a4d16 100644 --- a/include/flucoma/algorithms/public/NMF.hpp +++ b/include/flucoma/algorithms/public/NMF.hpp @@ -11,10 +11,11 @@ under the European Union’s Horizon 2020 research and innovation programme #pragma once #include "../util/AlgorithmUtils.hpp" +#include "../util/EigenRandom.hpp" #include "../util/FluidEigenMappings.hpp" #include "../../data/FluidIndex.hpp" -#include "../../data/TensorTypes.hpp" #include "../../data/FluidMemory.hpp" +#include "../../data/TensorTypes.hpp" #include #include @@ -42,17 +43,16 @@ class NMF // processFrame computes activations of a dictionary W in a given frame void processFrame(const RealVectorView x, const RealMatrixView W0, - RealVectorView out, index nIterations, - RealVectorView v, Allocator& alloc) + RealVectorView out, index nIterations, RealVectorView v, + index randomSeed, Allocator& alloc) { using namespace Eigen; using namespace _impl; index rank = W0.extent(0); FluidEigenMap W = asEigen(W0); - + ScopedEigenMap h(rank, alloc); - h = VectorXd::Random(rank) * 0.5 + VectorXd::Constant(rank, 0.5); - + h = EigenRandom(rank, RandomSeed{randomSeed}, Range{0.0, 1.0}); ScopedEigenMap v0(x.size(), alloc); v0 = asEigen(x); W = W.array().max(epsilon).matrix(); @@ -90,7 +90,7 @@ class NMF void process(const RealMatrixView X, RealMatrixView W1, RealMatrixView H1, RealMatrixView V1, index rank, index nIterations, bool updateW, - bool updateH = false, + bool updateH = false, index randomSeed = -1, RealMatrixView W0 = RealMatrixView(nullptr, 0, 0, 0), RealMatrixView H0 = RealMatrixView(nullptr, 0, 0, 0)) { @@ -101,8 +101,8 @@ class NMF MatrixXd W; if (W0.extent(0) == 0 && W0.extent(1) == 0) { - W = MatrixXd::Random(nBins, rank) * 0.5 + - MatrixXd::Constant(nBins, rank, 0.5); + W = EigenRandom(nBins, rank, RandomSeed{randomSeed}, + Range{0.0, 1.0}); } else { @@ -113,8 +113,8 @@ class NMF MatrixXd H; if (H0.extent(0) == 0 && H0.extent(1) == 0) { - H = MatrixXd::Random(rank, nFrames) * 0.5 + - MatrixXd::Constant(rank, nFrames, 0.5); + H = EigenRandom(rank, nFrames, RandomSeed{randomSeed}, + Range{0.0, 1.0}); } else { diff --git a/include/flucoma/clients/nrt/NMFClient.hpp b/include/flucoma/clients/nrt/NMFClient.hpp index 0abca69c..4feaaa94 100644 --- a/include/flucoma/clients/nrt/NMFClient.hpp +++ b/include/flucoma/clients/nrt/NMFClient.hpp @@ -18,8 +18,8 @@ under the European Union’s Horizon 2020 research and innovation programme #include "../../algorithms/public/NMF.hpp" #include "../../algorithms/public/RatioMask.hpp" #include "../../algorithms/public/STFT.hpp" -#include "../../data/FluidTensor.hpp" #include "../../data/FluidMemory.hpp" +#include "../../data/FluidTensor.hpp" #include //for max_element #include #include //for ostringstream @@ -47,6 +47,7 @@ enum NMFParamIndex { kEnvelopesUpdate, kRank, kIterations, + kRandomSeed, kFFT }; @@ -57,7 +58,7 @@ constexpr auto BufNMFParams = defineParameters( LongParam("startChan", "Start Channel", 0, Min(0)), LongParam("numChans", "Number Channels", -1), BufferParam("resynth", "Resynthesis Buffer"), - LongParam("resynthMode","Resynthesise components", 0,Min(0),Max(1)), + LongParam("resynthMode", "Resynthesise components", 0, Min(0), Max(1)), BufferParam("bases", "Bases Buffer"), EnumParam("basesMode", "Bases Buffer Update Mode", 0, "None", "Seed", "Fixed"), @@ -66,6 +67,7 @@ constexpr auto BufNMFParams = defineParameters( "Fixed"), LongParam("components", "Number of Components", 1, Min(1)), LongParam("iterations", "Number of Iterations", 100, Min(1)), + LongParam("seed", "Random Seed", -1), FFTParam("fftSettings", "FFT Settings", 1024, -1, -1)); class NMFClient : public FluidBaseClient, public OfflineIn, public OfflineOut @@ -98,7 +100,7 @@ class NMFClient : public FluidBaseClient, public OfflineIn, public OfflineOut index nFrames = get(); index nChannels = get(); auto rangeCheck = bufferRangeCheck(get().get(), get(), - nFrames, get(), nChannels); + nFrames, get(), nChannels); if (!rangeCheck.ok()) return rangeCheck; @@ -264,8 +266,9 @@ class NMFClient : public FluidBaseClient, public OfflineIn, public OfflineOut : true; }); nmf.process(magnitude, outputFilters, outputEnvelopes, outputMags, - get(), get() * needsAnalysis, !fixFilters, !fixEnvelopes, - seededFilters, seededEnvelopes); + get(), get() * needsAnalysis, !fixFilters, + !fixEnvelopes, get(), seededFilters, + seededEnvelopes); if (c.task() && c.task()->cancelled()) return {Result::Status::kCancelled, ""}; diff --git a/include/flucoma/clients/rt/NMFFilterClient.hpp b/include/flucoma/clients/rt/NMFFilterClient.hpp index 85c8c998..6a1a455f 100644 --- a/include/flucoma/clients/rt/NMFFilterClient.hpp +++ b/include/flucoma/clients/rt/NMFFilterClient.hpp @@ -22,12 +22,13 @@ namespace fluid { namespace client { namespace nmffilter { -enum NMFFilterIndex { kFilterbuf, kMaxRank, kIterations, kFFT }; +enum NMFFilterIndex { kFilterbuf, kMaxRank, kIterations, kRandomSeed, kFFT }; constexpr auto NMFFilterParams = defineParameters( InputBufferParam("bases", "Bases Buffer"), LongParamRuntimeMax("maxComponents", "Maximum Number of Components", 20, Min(1)), LongParam("iterations", "Number of Iterations", 10, Min(1)), + LongParam("seed", "Random Seed", -1), FFTParam("fftSettings", "FFT Settings", 1024, -1, -1)); class NMFFilterClient : public FluidBaseClient, public AudioIn, public AudioOut @@ -103,7 +104,8 @@ class NMFFilterClient : public FluidBaseClient, public AudioIn, public AudioOut [&](ComplexMatrixView in, ComplexMatrixView out) { algorithm::STFT::magnitude(in, tmpMagnitude); mNMF.processFrame(tmpMagnitude.row(0), tmpFilt, tmpOut, - get(), tmpEstimate.row(0), c.allocator()); + get(), tmpEstimate.row(0), + get(), c.allocator()); mMask.init(tmpEstimate); for (index i = 0; i < rank; ++i) { diff --git a/include/flucoma/clients/rt/NMFMatchClient.hpp b/include/flucoma/clients/rt/NMFMatchClient.hpp index 0d577472..89f3d7a6 100644 --- a/include/flucoma/clients/rt/NMFMatchClient.hpp +++ b/include/flucoma/clients/rt/NMFMatchClient.hpp @@ -25,6 +25,7 @@ enum NMFMatchParamIndex { kFilterbuf, kMaxRank, kIterations, + kRandomSeed, kFFT }; @@ -33,6 +34,7 @@ constexpr auto NMFMatchParams = defineParameters( LongParamRuntimeMax("maxComponents", "Maximum Number of Components", 20, Min(1)), LongParam("iterations", "Number of Iterations", 10, Min(1)), + LongParam("seed", "Random Seed", -1), FFTParam("fftSettings", "FFT Settings", 1024, -1, -1)); class NMFMatchClient : public FluidBaseClient, public AudioIn, public ControlOut @@ -104,14 +106,16 @@ class NMFMatchClient : public FluidBaseClient, public AudioIn, public ControlOut for (index i = 0; i < filter.rows(); ++i) filter.row(i) <<= filterBuffer.samps(i); - mSTFTProcessor.processInput(get(), input, c, [&](ComplexMatrixView in) { - algorithm::STFT::magnitude(in, mags); - mNMF.processFrame(mags.row(0), filter, activations, - 10, FluidTensorView{nullptr, 0, 0}, c.allocator()); - }); output[0](Slice(0,rank)) <<= activations; output[0](Slice(rank,get().max() - rank)).fill(0); + mSTFTProcessor.processInput( + get(), input, c, [&](ComplexMatrixView in) { + algorithm::STFT::magnitude(in, mags); + mNMF.processFrame(mags.row(0), filter, activations, 10, + FluidTensorView{nullptr, 0, 0}, + get(), c.allocator()); + }); } } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1d6a50b8..80c3d9e0 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -116,6 +116,7 @@ add_test_executable(TestTransientSlice algorithms/public/TestTransientSlice.cpp) add_test_executable(TestMLP algorithms/public/TestMLP.cpp) add_test_executable(TestKMeans algorithms/public/TestKMeans.cpp) +add_test_executable(TestNMF algorithms/public/TestNMF.cpp) add_test_executable(TestUMAP algorithms/public/TestUMAP.cpp) add_test_executable(TestDataSampler data/detail/TestDataSampler.cpp) @@ -157,6 +158,7 @@ catch_discover_tests(TestBufferedProcess WORKING_DIRECTORY "${CMAKE_BINARY_DIR}" catch_discover_tests(TestMLP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}") catch_discover_tests(TestKMeans WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) catch_discover_tests(TestEigenRandom WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) +catch_discover_tests(TestNMF WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) catch_discover_tests(TestRTPGHI WORKING_DIRECTORY "${CMAKE_BINARY_DIR}") catch_discover_tests(TestUMAP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}") diff --git a/tests/algorithms/public/TestNMF.cpp b/tests/algorithms/public/TestNMF.cpp new file mode 100644 index 00000000..4c139e70 --- /dev/null +++ b/tests/algorithms/public/TestNMF.cpp @@ -0,0 +1,74 @@ +#define CATCH_CONFIG_MAIN +#include +#include +#include +#include +#include +#include + +namespace fluid { + +TEST_CASE("NMF is repeatable with user-supplied random seed") +{ + + using algorithm::NMF; + using Tensor = FluidTensor; + NMF algo; + + Tensor input{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; + + std::vector Vs(4, Tensor(3, 3)); + std::vector Ws(4, Tensor(2, 3)); + std::vector Hs(4, Tensor(3, 2)); + + algo.process(input, Ws[0], Hs[0], Vs[0], 2, 1, true, true, 42); + algo.process(input, Ws[1], Hs[1], Vs[1], 2, 1, true, true, 42); + algo.process(input, Ws[2], Hs[2], Vs[2], 2, 1, true, true, 5063); + algo.process(input, Ws[3], Hs[3], Vs[3], 2, 1, true, true, 5063); + + using Catch::Matchers::RangeEquals; + + SECTION("Calls with the same seed have the same output") + { + REQUIRE_THAT(Ws[1], RangeEquals(Ws[0])); + REQUIRE_THAT(Hs[1], RangeEquals(Hs[0])); + REQUIRE_THAT(Vs[1], RangeEquals(Vs[0])); + REQUIRE_THAT(Ws[3], RangeEquals(Ws[2])); + REQUIRE_THAT(Hs[3], RangeEquals(Hs[2])); + REQUIRE_THAT(Vs[3], RangeEquals(Vs[2])); + } + SECTION("Calls with different seeds have different outputs") + { + REQUIRE_THAT(Ws[1], !RangeEquals(Ws[2])); + REQUIRE_THAT(Hs[1], !RangeEquals(Hs[2])); + REQUIRE_THAT(Vs[1], !RangeEquals(Vs[2])); + } +} + +TEST_CASE("NMF processFrame() is repeatable with user-supplied random seed") +{ + using fluid::algorithm::NMF; + using Tensor = fluid::FluidTensor; + using Vector = fluid::FluidTensor; + NMF algo; + + Vector input{{1, 0, 1, 0}}; + Tensor bases{{0, 0, 1, 0}, {1, 0, 0, 0}}; + Vector v(4); + + std::vector outputs(3, Vector(2)); + + index nIter{0}; + algo.processFrame(input, bases, outputs[0], nIter, v, 42, + FluidDefaultAllocator()); + algo.processFrame(input, bases, outputs[1], nIter, v, 42, + FluidDefaultAllocator()); + algo.processFrame(input, bases, outputs[2], nIter, v, 7863, + FluidDefaultAllocator()); + + using Catch::Matchers::RangeEquals; + + REQUIRE_THAT(outputs[1], RangeEquals(outputs[0])); + REQUIRE_THAT(outputs[1], !RangeEquals(outputs[2])); +} +} // namespace fluid \ No newline at end of file