From 28a85af84282e20062f53bc326033b3c8a21975c Mon Sep 17 00:00:00 2001 From: bilaljo Date: Sun, 3 Mar 2024 20:01:44 +0100 Subject: [PATCH] Restructured entire code base to modern C++ --- CMakeLists.txt | 24 +-- setup.py | 3 - src/fcwt/api.cpp | 389 ++++++++++++++++++++++++++++++++++++++++++++ src/fcwt/api.h | 112 +++++++++++++ src/fcwt/morlet.cpp | 48 ++++++ src/fcwt/morlet.h | 38 +++++ src/fcwt/scales.cpp | 93 +++++++++++ src/fcwt/scales.h | 38 +++++ src/fcwt/wavelet.h | 43 +++++ src/main.cpp | 49 +++--- src/main.h | 7 +- 11 files changed, 800 insertions(+), 44 deletions(-) create mode 100644 src/fcwt/api.cpp create mode 100644 src/fcwt/api.h create mode 100644 src/fcwt/morlet.cpp create mode 100644 src/fcwt/morlet.h create mode 100644 src/fcwt/scales.cpp create mode 100644 src/fcwt/scales.h create mode 100644 src/fcwt/wavelet.h diff --git a/CMakeLists.txt b/CMakeLists.txt index f1405d9..afd3f6c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ endif() project(fCWT VERSION 2.0 DESCRIPTION "Highly optimized implementation of the Continuous Wavelet Transform") # specify the C++ standard -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED True) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}") @@ -53,9 +53,9 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}") if(UNIX) - SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -mavx -O2") + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -mavx -g3") elseif(WIN32) - SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /std:c++17 /arch:AVX /O2 /W1") + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /std:c++20 /arch:AVX /O2 /W1") endif() #set cmake linker flags @@ -64,27 +64,31 @@ link_libraries("-lm") link_directories(${PROJECT_SOURCE_DIR}) set(libraryheaders - "./src/fcwt/fcwt.h" + "src/fcwt/api.h" + "src/fcwt/morlet.h" + "src/fcwt/wavelet.h" + "src/fcwt/scales.h" ) set(implementationheaders "./src/main.h" ) 
set(benchmarkheaders "./src/benchmark.h" - "./src/fcwt/fcwt.h" + "src/fcwt/api.h" "./src/rwave-bench.h" "./src/wavelib-bench.h" ) set(librarysources - "./src/fcwt/fcwt.cpp" + "src/fcwt/api.cpp" + "src/fcwt/scales.cpp" + "src/fcwt/morlet.cpp" ) set(implementationsources - "./src/main.cpp" + "./src/main.cpp" ) set(benchmarksources "./src/benchmark.cpp" - "./src/fcwt/fcwt.cpp" "./src/rwave-bench.cpp" "./src/wavelib-bench.cpp" ) @@ -326,7 +330,7 @@ if(BUILD_MATLAB) matlab_add_mex( NAME fCWTmex - SRC src/MEX/fcwtmex.cpp src/fcwt/fcwt.h src/fcwt/fcwt.cpp + SRC src/MEX/fcwtmex.cpp src/fcwt/api.h src/fcwt/api.cpp OUTPUT_NAME ${FCWT_MATLAB_DIR}/fCWT LINK_TO ${FFTW} ${FFTW_OMP} R2018a @@ -337,7 +341,7 @@ if(BUILD_MATLAB) matlab_add_mex( NAME fCWTmexplan - SRC src/MEX/fcwtplan.cpp src/fcwt/fcwt.h src/fcwt/fcwt.cpp + SRC src/MEX/fcwtplan.cpp src/fcwt/api.h src/fcwt/api.cpp OUTPUT_NAME ${FCWT_MATLAB_DIR}/fCWT_create_plan LINK_TO ${FFTW} ${FFTW_OMP} R2018a diff --git a/setup.py b/setup.py index 84395cd..0db1358 100644 --- a/setup.py +++ b/setup.py @@ -5,11 +5,8 @@ """ from setuptools import Extension, setup, find_packages -import distutils.command.build import sysconfig import numpy -import os -import shutil # Obtain the numpy include directory. This logic works across numpy versions. diff --git a/src/fcwt/api.cpp b/src/fcwt/api.cpp new file mode 100644 index 0000000..be344f7 --- /dev/null +++ b/src/fcwt/api.cpp @@ -0,0 +1,389 @@ +// +// fcwt.cpp +// fCWT +// +// Created by Lukas Arts on 21/12/2020. +// Copyright © 2021 Lukas Arts. +/*Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +// +// fcwt.cpp +// fCWT-testing +// +// Created by Lukas Arts on 21/12/2020. +// Copyright © 2020 Lukas Arts. +/*Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "api.h" +#include +#include +#include +#include + + +void fcwt::API::daughter_wavelet_multiplication(fftwf_complex *input, fftwf_complex *output, float const *mother, + const float scale, int isize, bool imaginary, bool doublesided) const { + const auto isizef = static_cast(isize); + const float endpointf = std::min(isizef / 2.0f,((isizef * 2.0f / scale))); + const float step = static_cast(scale) / 2.0f; + int endpoint = static_cast(endpointf); + const int endpoint4 = endpoint >> 2; + + #ifdef AVX + //has avx instructions + __m256* O8 = (__m256*)output; + __m256* I8 = (__m256*)input; + __m256 step4 = _mm256_set1_ps(step); + __m256 offset = _mm256_set_ps(3,3,2,2,1,1,0,0); + __m256 maximum = _mm256_set1_ps(isizef-1); + + int athreads = std::min(threads, std::max(1,endpoint4/16)); + int batchsize = (endpoint4/athreads); + int s4 = (isize>>2)-1; + + #ifndef SINGLE_THREAD + #pragma omp parallel for + #endif + for(int i = 0; i(q4 * 4); + + __m256 qq = _mm256_set1_ps(q); + + U256f tmp = {_mm256_min_ps(maximum,_mm256_mul_ps(step4,_mm256_add_ps(qq,offset)))}; + //U256f tmp = {_mm256_mul_ps(step4,_mm256_add_ps(qq,offset))}; + + __m256 wav = _mm256_set_ps( + mother[static_cast(tmp.a[7])] * static_cast(1 - 2 * imaginary), 
+ mother[static_cast(tmp.a[6])], + mother[static_cast(tmp.a[5])] * static_cast(1 - 2 * imaginary), + mother[static_cast(tmp.a[4])], + mother[static_cast(tmp.a[3])] * static_cast(1 - 2 * imaginary), + mother[static_cast(tmp.a[2])], + mother[static_cast(tmp.a[1])] * static_cast(1 - 2 * imaginary), + mother[static_cast(tmp.a[0])]); + + if(imaginary) { + __m256 tmp2 = _mm256_mul_ps(I8[q4],wav); + O8[q4] = _mm256_shuffle_ps(tmp2, tmp2, 177); + } else { + O8[q4] = _mm256_mul_ps(I8[q4],wav); + } + } + + if (doublesided) { + for(int q4 = start; q4 < end; q4++) { + auto q = static_cast(q4 * 4); + + __m256 qq = _mm256_set1_ps(q); + U256f tmp = {_mm256_mul_ps(step4,_mm256_add_ps(qq,offset))}; + + __m256 wav = _mm256_set_ps( + mother[static_cast(tmp.a[0])] * static_cast(1 - 2 * imaginary), + mother[static_cast(tmp.a[1])], + mother[static_cast(tmp.a[2])] * static_cast(1 - 2 * imaginary), + mother[static_cast(tmp.a[3])], + mother[static_cast(tmp.a[4])] * static_cast(1 - 2 * imaginary), + mother[static_cast(tmp.a[5])], + mother[static_cast(tmp.a[6])] * static_cast(1 - 2 * imaginary), + mother[static_cast(tmp.a[7])]); + + if (imaginary) { + __m256 tmp2 = _mm256_mul_ps(I8[s4-q4],wav); + O8[s4-q4] = _mm256_shuffle_ps(tmp2, tmp2, 177); + } else { + O8[s4-q4] = _mm256_mul_ps(I8[s4-q4],wav); + } + } + } + } + #else + int athreads = min(threads,max(1,endpoint/16)); + int batchsize = (endpoint/athreads); + float maximum = isizef-1; + int s1 = isize-1; + + #ifndef SINGLE_THREAD + #pragma omp parallel for + #endif + for(int i=0; i *out, Wavelet *wav, int size, int newsize, float scale, bool lastscale) { + + if (lastscale) { + #ifdef _WIN32 + fftwf_complex *lastscalemem = (fftwf_complex*)_aligned_malloc(newsize*sizeof(fftwf_complex), 32); + #else + fftwf_complex *lastscalemem = (fftwf_complex*)aligned_alloc(32, newsize*sizeof(fftwf_complex)); + #endif + memset(lastscalemem,0,sizeof(fftwf_complex)*newsize); + + fftbased(p, Ihat, O1, (float*)lastscalemem, wav->mother.data(), newsize, scale, 
wav->imag_frequency, wav->doublesided); + if(use_normalization) fft_normalize((std::complex*)lastscalemem, newsize); + memcpy(out, (std::complex*)lastscalemem, sizeof(std::complex)*size); + } else { + if(!out) { + std::cout << "OUT NOT A POINTER" << std::endl; + } + fftbased(p, Ihat, O1, (float*)out, wav->mother.data(), newsize, scale, wav->imag_frequency, wav->doublesided); + if(use_normalization) fft_normalize(out, newsize); + } +} + +void fcwt::API::fftbased(fftwf_plan p, fftwf_complex *Ihat, fftwf_complex *O1, float *out, float* mother, int size, float scale, bool imaginary, bool doublesided) { + + void *pt = out; + + //Perform daughter wavelet generation and multiplication with the Fourier transformed input signal + daughter_wavelet_multiplication(Ihat,O1,mother,scale,size,imaginary,doublesided); + + std::size_t space = 16; + std::align(16,sizeof(fftwf_complex),pt,space); + + fftwf_execute_dft(p,O1,(fftwf_complex*)pt); +} + +void fcwt::API::fft_normalize(std::complex* out, int size) { + + int nbatch = threads; + int batchsize = (int)ceil((float)size/((float)threads)); + + //#pragma omp parallel for + for(int i=0; i* poutput, Scales *scales, bool complexinput) { + + fftwf_complex *Ihat, *O1; + size = psize; + + //Find nearest power of 2 + const int nt = find2power(size); + const int newsize = 1 << nt; + + //Initialize intermediate result + #ifdef _WIN32 + Ihat = (fftwf_complex*)_aligned_malloc(newsize*sizeof(fftwf_complex), 32); + O1 = (fftwf_complex*)_aligned_malloc(newsize*sizeof(fftwf_complex), 32); + #else + Ihat = (fftwf_complex*)aligned_alloc(32, newsize*sizeof(fftwf_complex)); + O1 = (fftwf_complex*)aligned_alloc(32, newsize*sizeof(fftwf_complex)); + #endif + + //Copy input to new input buffer + memset(Ihat,0,sizeof(fftwf_complex)*newsize); + memset(O1,0,sizeof(fftwf_complex)*newsize); + + #ifndef SINGLE_THREAD + //Initialize FFTW plans + omp_set_num_threads(threads); + + //Initialize FFTW plans + fftwf_init_threads(); + + 
fftwf_plan_with_nthreads(threads); + #endif + + fftwf_plan pinv; + fftwf_plan p; + + // //Load optimization schemes if necessary + load_FFT_optimization_plan(); + + // //Perform forward FFT on input signal + float *input; + if(complexinput) { + input = (float*)calloc(newsize,sizeof(std::complex)); + memcpy(input,pinput,sizeof(std::complex)*size); + p = fftwf_plan_dft_1d(newsize, (fftwf_complex*)input, Ihat, FFTW_FORWARD, FFTW_ESTIMATE); + } else { + input = (float*)malloc(newsize*sizeof(float)); + memset(input,0,newsize*sizeof(float)); + memcpy(input,pinput,sizeof(float)*size); + p = fftwf_plan_dft_r2c_1d(newsize, input, Ihat, FFTW_ESTIMATE); + } + + fftwf_execute(p); + fftwf_destroy_plan(p); + free(input); + + pinv = fftwf_plan_dft_1d(newsize, O1, (fftwf_complex*)poutput, FFTW_BACKWARD, FFTW_ESTIMATE); + + //Generate mother wavelet function + wavelet->generate(newsize); + + for(int i=1; i<(newsize>>1); i++) { + Ihat[newsize-i][0] = Ihat[i][0]; + Ihat[newsize-i][1] = -Ihat[i][1]; + } + + std::complex *out = poutput; + + for(int i = 0; i < scales->nscales; i++) { + //FFT-base convolution in the frequency domain + convolve(pinv, Ihat, O1, out, wavelet, size, newsize, scales->scales[i], i==(scales->nscales-1)); + out = out + size; + } + + // //Cleanup + fftwf_destroy_plan(pinv); + #ifdef _WIN32 + _aligned_free(Ihat); + _aligned_free(O1); + #else + free(Ihat); + free(O1); + #endif +} + +void fcwt::API::cwt(float *pinput, int psize, std::complex* poutput, Scales *scales) { + cwt(pinput,psize,poutput,scales,false); +} + +void fcwt::API::cwt(std::complex *pinput, int psize, std::complex* poutput, Scales *scales) { + cwt((float*)pinput,psize,poutput,scales,true); +} + +void fcwt::API::cwt(float *pinput, int psize, Scales *scales, std::complex* poutput, int pn1, int pn2) { + assert((psize*scales->nscales) == (pn1*pn2)); + cwt(pinput,psize,poutput,scales); +} + +void fcwt::API::cwt(std::complex *pinput, int psize, Scales* scales, std::complex* poutput, int pn1, int pn2) { + 
assert((psize*scales->nscales) == (pn1*pn2)); + cwt(pinput,psize,poutput,scales); +} diff --git a/src/fcwt/api.h b/src/fcwt/api.h new file mode 100644 index 0000000..d63af19 --- /dev/null +++ b/src/fcwt/api.h @@ -0,0 +1,112 @@ +// +// fcwt.h +// fCWT +// +// Created by Lukas Arts on 21/12/2020. +// Copyright © 2021 Lukas Arts. +/*Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef FCWT_H +#define FCWT_H + + + +#include +#include +#include +#include +#include +#include + +#include + +#ifndef SINGLE_THREAD + #include +#endif +#ifdef _WIN32 + #include +#else + #include +#endif +#include "fftw3.h" +#include +//check if avx is supported and include the header +#if defined(__AVX__) + #include + #define AVX + union U256f { + __m256 v; + float a[8]; + }; +#endif + +#define PI 3.14159265358979323846264338327950288419716939937510582097494459072381640628620899862803482534211706798f +#define sqrt2PI 2.50662827463100050241576528f +#define IPI4 0.75112554446f + + +#include "wavelet.h" +#include "scales.h" + +namespace fcwt { + class API { + public: + FCWT_LIBRARY_API API(Wavelet *pwav, int pthreads=1, bool puse_optimalization_schemes=false, bool puse_normalization=false): + wavelet(pwav), + threads(pthreads), + use_optimalization_schemes(puse_optimalization_schemes), + use_normalization(puse_normalization) {}; + + void FCWT_LIBRARY_API create_FFT_optimization_plan(int maxsize, int flags) const; + void FCWT_LIBRARY_API create_FFT_optimization_plan(int pmaxsize, std::string poptimizationflags); + 
void FCWT_LIBRARY_API cwt(float *pinput, int psize, std::complex* poutput, Scales *scales); + void FCWT_LIBRARY_API cwt(std::complex *pinput, int psize, std::complex* poutput, Scales *scales); + void FCWT_LIBRARY_API cwt(float *pinput, int psize, Scales *scales, std::complex* poutput, int pn1, int pn2); + void FCWT_LIBRARY_API cwt(float *pinput, int psize, std::complex* poutput, Scales *scales, bool complexinput); + void FCWT_LIBRARY_API cwt(std::complex *pinput, int psize, Scales* scales, std::complex* poutput, int pn1, int pn2); + Wavelet *wavelet; + + private: + + void convolve(fftwf_plan p, fftwf_complex *Ihat, fftwf_complex *O1, std::complex *out, Wavelet *wav, int size, int newsize, float scale, bool lastscale); + + void fftbased(fftwf_plan p, fftwf_complex *Ihat, fftwf_complex *O1, float *out, float* mother, int size, float scale, bool imaginary, bool doublesided); + + void fft_normalize(std::complex* out, int size); + + void load_FFT_optimization_plan(); + + void daughter_wavelet_multiplication(fftwf_complex *input, fftwf_complex *output, float const *mother, + float scale, int isize, bool imaginary, bool doublesided) const; + + static int find2power(const int n) + { + int m = 0; + int m2 = 1 << m; /* 2 to the power of m */ + while (m2 - n < 0) { + m++; + m2 <<= 1; /* m2 = m2*2 */ + } + return(m); + } + + int threads; + int size; + float fs, f0, f1, fn; + bool use_optimalization_schemes; + bool use_normalization; + }; + +} + +#endif \ No newline at end of file diff --git a/src/fcwt/morlet.cpp b/src/fcwt/morlet.cpp new file mode 100644 index 0000000..48cbe8c --- /dev/null +++ b/src/fcwt/morlet.cpp @@ -0,0 +1,48 @@ +#include "morlet.h" + +fcwt::Morlet::Morlet(const float bandwidth): fb(bandwidth) { + fb2 = 2.0f * fb * fb; + ifb = 1.0f / fb; + imag_frequency = false; + doublesided = false; +} + +int fcwt::Morlet::getSupport(const float scale) const noexcept { + return static_cast(fb * scale * 3.0f); +} + +void fcwt::Morlet::generate(const int size) noexcept { 
+ // Frequency domain, because we only need size. Default scale is always 2 + width = size; + const float toradians = 2 * PI / static_cast(size); + const float norm = std::sqrt(2 * PI ) * IPI4; + mother.resize(width); + for(int w = 0; w < width; w++) { + float tmp1 = (2.0f * (static_cast(w) * toradians) * fb - 2.0f * PI * fb); + tmp1 = -(tmp1 * tmp1) / 2; + mother[w] = norm * std::exp(tmp1); + } +} + +void fcwt::Morlet::generate(std::vector>& pwav, const int size, const float scale) noexcept { + // Time domain because we know size from scale + width = getSupport(scale); + std::vector res(width * 2 + 1, 0); + + const float norm = static_cast(size) * ifb * IPI4; + + //cout << scale << " ["; + for(int t = 0; t < width * 2 + 1; t++) { + const float tmp1 = static_cast(t - width) / scale; + const float tmp2 = std::exp(-tmp1 * tmp1 / fb2); + pwav[t].real(norm * tmp2 * std::cos(tmp1 * 2.0f * PI) / scale); + pwav[t].imag(norm * tmp2 * std::sin(tmp1 * 2.0f * PI) / scale); + //cout << real[t]*real[t]+imag[t]*imag[t] << ","; + } + + //cout << "]" << endl; +} + +void fcwt::Morlet::getWavelet(const float scale, std::vector>& pwav, const int pn) noexcept { + generate(pwav, pn, scale); +} diff --git a/src/fcwt/morlet.h b/src/fcwt/morlet.h new file mode 100644 index 0000000..8784576 --- /dev/null +++ b/src/fcwt/morlet.h @@ -0,0 +1,38 @@ +// +// Created by jonas on 29.02.24. 
+// + +#ifndef MORLET_H +#define MORLET_H + +#include +#include + +#include "wavelet.h" +#include "api.h" + +namespace fcwt { + class Morlet: public Wavelet { + public: + FCWT_LIBRARY_API explicit Morlet(float bandwidth); // frequency domain + + ~Morlet() override = default; + + // Frequency domain + void generate(int size) noexcept override; + + // Time domain + void generate(std::vector> &pwav, int size, float scale) noexcept override; + + [[nodiscard]] int getSupport(float scale) const noexcept override; + + void getWavelet(float scale, std::vector>& pwav, int pn) noexcept override; + + float fb; + + private: + float ifb, fb2; + }; +} + +#endif //MORLET_H diff --git a/src/fcwt/scales.cpp b/src/fcwt/scales.cpp new file mode 100644 index 0000000..9f077f4 --- /dev/null +++ b/src/fcwt/scales.cpp @@ -0,0 +1,93 @@ +#include "scales.h" + +#include +#include + +fcwt::Scales::Scales(const ScaleType st, const int fs, const float f0, const float f1, const int fn): fs(fs), nscales(fn) { + scales.resize(fn); + switch(st) { + case ScaleType::FCWT_LOGSCALES: + calculate_logscale_array(2.0f, fs, f0, f1, fn); // base-2 logarithmic spacing of scales + break; + case ScaleType::FCWT_LINSCALES: + calculate_linscale_array(fs, f0, f1, fn); // evenly spaced scales from fs/f1 towards fs/f0 + break; + default: + calculate_linfreq_array(fs, f0, f1, fn); + } +} + +bool fcwt::Scales::check_nyquist_satisfied(const float f, const int fs) noexcept { + if (f > static_cast(fs) / 2) { [[unlikely]] + std::cerr << "Max frequency cannot be higher than the Nyquist frequency fs/2\n"; + return false; + } + return true; +} + +void fcwt::Scales::getScales(const std::vector& pfreqs) noexcept { + scales = pfreqs; +}; + +void fcwt::Scales::getFrequencies(std::vector& pfreqs) const noexcept { + std::ranges::transform(scales, pfreqs.begin(), + [&](const float scale) {return static_cast(fs) / scale;}); +}; + +void fcwt::Scales::calculate_logscale_array(const float base, const int fs, const float f0, const float f1, + const int fn) noexcept { + + // If a signal has fs=100hz and you want to measure 
[0.1-50]Hz, you need scales 2 to 1000; + const float nf0 = f0; + const float nf1 = f1; + const float s0 = static_cast(fs) / nf1; + const float s1 = static_cast(fs) / nf0; + + if (!check_nyquist_satisfied(f1, fs)) { + // Cannot pass the nyquist frequency + return; + } + + const float power[2] = {std::log(s0) / std::log(base), std::log(s1) / std::log(base)}; + const float dpower = power[1] - power[0]; + + for(int i = 0; i < fn; i++) { + const float log_power = power[0] + (dpower / static_cast(fn - 1)) * static_cast(i); + scales[i] = std::pow(base, log_power); + } +} + +void fcwt::Scales::calculate_linfreq_array(const int fs, const float f0, const float f1, const int fn) noexcept { + + const float nf0 = f0; + const float nf1 = f1; + // If a signal has fs=100hz and you want to measure [0.1-50] Hz, you need scales 2 to 1000; + + if (!check_nyquist_satisfied(f1, fs)) { + // Cannot pass the nyquist frequency + return; + } + + const float df = nf1 - nf0; + + for(int i=0; i < fn; i++) { + scales[fn - i- 1] = static_cast(fs) / (nf0 + df / static_cast(fn) * static_cast(i)); + } +} + +void fcwt::Scales::calculate_linscale_array(const int fs, const float f0, const float f1, const int fn) noexcept { + // If a signal has fs=100hz and you want to measure [0.1-50]Hz, you need scales 2 to 1000; + const float s0 = static_cast(fs) / f1; + const float s1 = static_cast(fs) / f0; + + if (!check_nyquist_satisfied(f1, fs)) { + // Cannot pass the nyquist frequency + return; + } + + const float ds = s1 - s0; + + for (int i = 0; i < fn; i++) { + scales[i] = s0 + ds/ static_cast(fn) * static_cast(i); + } +} diff --git a/src/fcwt/scales.h b/src/fcwt/scales.h new file mode 100644 index 0000000..575f0ef --- /dev/null +++ b/src/fcwt/scales.h @@ -0,0 +1,38 @@ +#ifndef SCALES_H +#define SCALES_H + +#include "wavelet.h" + +namespace fcwt { + enum class ScaleType { + FCWT_LINSCALES, + FCWT_LOGSCALES, + FCWT_LINFREQS + }; + + class Scales { + public: + FCWT_LIBRARY_API Scales(ScaleType st, int fs, 
float f0, float f1, int fn); + + void FCWT_LIBRARY_API getScales(const std::vector& pfreqs) noexcept; + + void FCWT_LIBRARY_API getFrequencies(std::vector& pfreqs) const noexcept; + + std::vector scales; + + int fs; + + int nscales; + + private: + static bool check_nyquist_satisfied(float f, int fs) noexcept; + + void calculate_logscale_array(float base, int fs, float f0, float f1, int fn) noexcept; + + void calculate_linscale_array(int fs, float f0, float f1, int fn) noexcept; + + void calculate_linfreq_array(int fs, float f0, float f1, int fn) noexcept; + }; +} // fcwt + +#endif //SCALES_H diff --git a/src/fcwt/wavelet.h b/src/fcwt/wavelet.h new file mode 100644 index 0000000..1c1adf2 --- /dev/null +++ b/src/fcwt/wavelet.h @@ -0,0 +1,43 @@ +#ifndef WAVELET_H +#define WAVELET_H + +#include +#include +#ifdef _WIN32 + #ifdef FCWT_LIBRARY_DLL_BUILDING + #define FCWT_LIBRARY_API __declspec(dllexport) + #else + #if FCWT_LIBRARY_DLL + #define FCWT_LIBRARY_API __declspec(dllimport) + #else /* static or header-only library on Windows */ + #define FCWT_LIBRARY_API + #endif + #endif +#else /* Unix */ + #define FCWT_LIBRARY_API +#endif + +namespace fcwt { + class Wavelet { + public: + virtual ~Wavelet() = default; + + virtual void generate(std::vector> &pwav, int size, float scale) = 0; + + virtual void generate(int size) = 0; + + [[nodiscard]] virtual int getSupport(float scale) const noexcept = 0; + + virtual void getWavelet(float scale, std::vector>& pwav, int pn) = 0; + + int width = 0; + + bool imag_frequency = false; + + bool doublesided = false; + + std::vector mother; + }; +}; + +#endif //WAVELET_H diff --git a/src/main.cpp b/src/main.cpp index 02f4c40..fae4b11 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -19,9 +19,8 @@ limitations under the License. 
#include "main.h" -using namespace std; -int main(int argc, char * argv[]) { +int main(int argc, char * argv[]) { int n = 1000; //signal length const int fs = 1000; //sampling frequency @@ -39,34 +38,27 @@ int main(int argc, char * argv[]) { std::vector sig(n); //input: n complex numbers - std::vector> sigc(n); + std::vector> sigc(n); //output: n x scales x 2 (complex numbers consist of two parts) - std::vector> tfm(n*fn); + std::vector> tfm(n * fn * 2); //initialize with 1 Hz cosine wave for(auto& el : sig) { - el = cos(twopi*((float)(&el - &sig[0])/(float)fs)); + el = std::cos(twopi* (static_cast(&el - &sig[0]) / static_cast(fs))); } //initialize with 1 Hz cosine wave for(auto& el : sigc) { - el = complex(cos(twopi*((float)(&el - &sigc[0])/(float)fs)), 0.0f); + el = std::complex(cos(twopi*((float)(&el - &sigc[0])/(float)fs)), 0.0f); } - + //Start timing - auto start = chrono::high_resolution_clock::now(); - - //Create a wavelet object - Wavelet *wavelet; - - //Initialize a Morlet wavelet having sigma=1.0; - Morlet morl(1.0f); - wavelet = &morl; + const auto start = std::chrono::high_resolution_clock::now(); - //Other wavelets are also possible - //DOG dog(int order); - //Paul paul(int order); + // Initialize a Morlet wavelet having sigma=1.0; + fcwt::Morlet morl(1.0f); + fcwt::Wavelet *wavelet = &morl; //Create the continuous wavelet transform object //constructor(wavelet, nthreads, optplan) @@ -75,19 +67,18 @@ int main(int argc, char * argv[]) { //wavelet - pointer to wavelet object //nthreads - number of threads to use //optplan - use FFTW optimization plans if true - FCWT fcwt(wavelet, nthreads, true, false); + fcwt::API fcwt(wavelet, nthreads, true, false); //Generate frequencies //constructor(wavelet, dist, fs, f0, f1, fn) // //Arguments - //wavelet - pointer to wavelet object //dist - FCWT_LOGSCALES | FCWT_LINSCALES for logarithmic or linear distribution of scales across frequency range //fs - sample frequency //f0 - beginning of frequency range //f1 - end 
of frequency range //fn - number of wavelets to generate across frequency range - Scales scs(wavelet, FCWT_LINFREQS, fs, f0, f1, fn); + fcwt::Scales scs(fcwt::ScaleType::FCWT_LINFREQS, fs, f0, f1, fn); //Perform a CWT //cwt(input, length, output, scales) @@ -97,18 +88,18 @@ int main(int argc, char * argv[]) { //length - integer signal length //output - floating pointer to output array //scales - pointer to scales object - fcwt.cwt(&sigc[0], n, &tfm[0], &scs); - + fcwt.cwt(sigc.data(), n, tfm.data(), &scs); + //End timing - auto finish = chrono::high_resolution_clock::now(); + const auto finish = std::chrono::high_resolution_clock::now(); //Calculate total duration - chrono::duration elapsed = finish - start; + const std::chrono::duration elapsed = finish - start; - cout << "=== fCWT example ===" << endl; - cout << "Calculate CWT of a 100k sample sinusodial signal using a [" << f0 << "-" << f1 << "] Hz linear frequency range and " << fn << " wavelets." << endl; - cout << "====================" << endl; - cout << "fCWT finished in " << elapsed.count() << "s" << endl; + std::cout << "=== fCWT example ===\n"; + std::cout << "Calculate CWT of a 100k sample sinusodial signal using a [" << f0 << "-" << f1 << "] Hz linear frequency range and " << fn << " wavelets.\n"; + std::cout << "====================\n"; + std::cout << "fCWT finished in " << elapsed.count() << " s\n"; return 0; } diff --git a/src/main.h b/src/main.h index 4cbf079..b77dfc4 100644 --- a/src/main.h +++ b/src/main.h @@ -41,10 +41,13 @@ limitations under the License. #define AVX #endif -using namespace std; #define PI 3.14159265358979323846264338327950288419716939937510582097494459072381640628620899862803482534211706798f -#include "fcwt/fcwt.h" +#include "fcwt/api.h" +#include "fcwt/wavelet.h" +#include "fcwt/morlet.h" +#include "fcwt/scales.h" + #include "rwave-bench.h" #include "wavelib-bench.h"