diff --git a/Cargo.lock b/Cargo.lock index e83b480c9f..f276612c19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -221,6 +221,17 @@ dependencies = [ "vad-ext", ] +[[package]] +name = "ahash" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" +dependencies = [ + "getrandom 0.2.17", + "once_cell", + "version_check", +] + [[package]] name = "ahash" version = "0.8.12" @@ -452,6 +463,15 @@ name = "anyhow" version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +dependencies = [ + "backtrace", +] + +[[package]] +name = "anymap2" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" [[package]] name = "aotuv_lancer_vorbis_sys" @@ -1178,7 +1198,7 @@ dependencies = [ "futures-util", "hound", "rodio", - "rubato", + "rubato 0.16.2", "thiserror 2.0.17", "tokio", "vorbis_rs", @@ -2614,7 +2634,7 @@ version = "0.55.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0839c297f8783316fcca9d90344424e968395413f0662a5481f79c6648bbc14" dependencies = [ - "ahash", + "ahash 0.8.12", "cached_proc_macro", "cached_proc_macro_types", "hashbrown 0.14.5", @@ -2713,7 +2733,7 @@ dependencies = [ "cudarc", "gemm 0.17.1", "half", - "memmap2", + "memmap2 0.9.9", "metal 0.27.0", "num-traits", "num_cpus", @@ -2976,7 +2996,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e964125508474a83c95eb935697abbeb446ff4e9d62c71ce880e3986d1c606b" dependencies = [ "chinese-variant", - "enum-ordinalize", + "enum-ordinalize 4.3.2", "num-bigint", "num-traits", ] @@ -4508,6 +4528,29 @@ dependencies = [ "uuid", ] +[[package]] +name = "deep_filter" +version = "0.5.6" +source = "git+https://github.com/Rikorose/DeepFilterNet?tag=v0.5.6#978576aa8400552a4ce9730838c635aa30db5e61" +dependencies = [ + "anyhow", + "flate2", + "itertools 0.10.5", + "log", + "ndarray 0.15.6", + "num-complex", + "realfft", + "rubato 0.14.1", + "rust-ini 0.19.0", + "rustfft", + "tar", + "thiserror 1.0.69", + "tract-core", + "tract-hir", + "tract-onnx", + "tract-pulse", +] + [[package]] name = "deepgram" version = "0.7.0" @@ -4662,6 +4705,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "denoise" +version = "0.1.0" +dependencies = [ + "criterion", + "data", + "deep_filter", + "hound", + "ndarray 0.15.6", + "rodio", + "thiserror 2.0.17", +] + [[package]] name = "der" version = "0.6.1" @@ -4693,6 +4749,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "derive-new" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3418329ca0ad70234b9735dc4ceed10af4df60eff9c8e7b06cb5e520d92c3535" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive_arbitrary" version = "1.4.2" @@ -5308,6 +5375,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "educe" +version = "0.4.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f0042ff8246a363dbe77d2ceedb073339e85a804b9a47636c6e016a9a32c05f" +dependencies = [ + "enum-ordinalize 3.1.15", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "either" version = "1.15.0" @@ -5448,6 +5527,19 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "enum-ordinalize" +version = "3.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bf1fa3f06bbff1ea5b1a9c7b14aa992a39657db60a2759457328d7e058f49ee" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "enum-ordinalize" version = "4.3.2" @@ -6072,7 +6164,7 @@ checksum = "b0299020c3ef3f60f526a4f64ab4a3d4ce116b1acbf24cdd22da0068e5d81dc3" dependencies = [ "fontconfig-parser", "log", - "memmap2", + "memmap2 0.9.9", "slotmap", "tinyvec", "ttf-parser 0.20.0", @@ -6086,7 +6178,7 @@ checksum = "457e789b3d1202543297a350643cf459f836cade38934e7a4cf6a39e7cde2905" dependencies = [ "fontconfig-parser", "log", - "memmap2", + "memmap2 0.9.9", "slotmap", "tinyvec", "ttf-parser 0.25.1", @@ -6816,7 +6908,7 @@ version = "0.1.0" dependencies = [ "byteorder", "dirs 6.0.0", - "memmap2", + "memmap2 0.9.9", "strum 0.26.3", "thiserror 2.0.17", ] @@ -7012,7 +7104,7 @@ dependencies = [ "bstr", "gix-chunk", "gix-hash 0.17.0", - "memmap2", + "memmap2 0.9.9", "thiserror 2.0.17", ] @@ -7283,7 +7375,7 @@ dependencies = [ "hashbrown 0.14.5", "itoa", "libc", - "memmap2", + "memmap2 0.9.9", "rustix 0.38.44", "smallvec 1.15.1", "thiserror 2.0.17", @@ -7355,7 +7447,7 @@ dependencies = [ "gix-hashtable", "gix-object", "gix-path", - "memmap2", + "memmap2 0.9.9", "smallvec 1.15.1", "thiserror 2.0.17", ] @@ -7457,7 +7549,7 @@ dependencies = [ "gix-tempfile", "gix-utils 0.2.0", "gix-validate 0.9.4", - "memmap2", + "memmap2 0.9.9", "thiserror 2.0.17", "winnow 0.7.14", ] @@ -8496,19 +8588,34 @@ dependencies = [ "byteorder", ] +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash 0.7.8", +] + [[package]] name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" + [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ - "ahash", + "ahash 0.8.12", "allocator-api2", ] @@ -9953,7 +10060,7 @@ version = "0.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "161c33c3ec738cfea3288c5c53dfcdb32fd4fc2954de86ea06f71b5a1a40bfcd" dependencies = [ - "ahash", + "ahash 0.8.12", "base64 0.22.1", "bytecount", "email_address", @@ -10039,7 +10146,7 @@ source = "git+https://github.com/thewh1teagle/pyannote-rs?rev=d97bd3b#d97bd3b9d5 dependencies = [ "eyre", "knf-rs-sys", - "ndarray", + "ndarray 0.16.1", ] [[package]] @@ -10124,6 +10231,7 @@ version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "558bf9508a558512042d3095138b1f7b8fe90c5467d94f9f1da28b3731c5dbd1" dependencies = [ + "serde", "static_assertions", ] @@ -10538,7 +10646,7 @@ dependencies = [ "gl", "image 0.25.9", "khronos-egl", - "memmap2", + "memmap2 0.9.9", "rustix 1.1.3", "thiserror 2.0.17", "tracing", @@ -10582,6 +10690,63 @@ dependencies = [ "rand_chacha 0.3.1", ] +[[package]] +name = "liquid" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e9338405fdbc0bce9b01695b2a2ef6b20eca5363f385d47bce48ddf8323cc25" +dependencies = [ + "doc-comment", + "liquid-core", + "liquid-derive", + "liquid-lib", + "serde", +] + +[[package]] +name = "liquid-core" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "feb8fed70857010ed9016ed2ce5a7f34e7cc51d5d7255c9c9dc2e3243e490b42" +dependencies = [ + "anymap2", + "itertools 0.13.0", + "kstring", + "liquid-derive", + "num-traits", + "pest", + "pest_derive", + "regex", + "serde", + "time", +] + +[[package]] +name = "liquid-derive" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b51f1d220e3fa869e24cfd75915efe3164bd09bb11b3165db3f37f57bf673e3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "liquid-lib" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1794b5605e9f8864a8a4f41aa97976b42512cc81093f8c885d29fb94c6c556" +dependencies = [ + "itertools 0.13.0", + "liquid-core", + "once_cell", + "percent-encoding", + "regex", + "time", + "unicode-segmentation", +] + [[package]] name = "litemap" version = "0.7.5" @@ -10924,6 +11089,12 @@ dependencies = [ "libc", ] +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "markdown" version = "1.0.0-alpha.21" @@ -11083,6 +11254,15 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "memmap2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327" +dependencies = [ + "libc", +] + [[package]] name = "memmap2" version = "0.9.9" @@ -11183,7 +11363,7 @@ dependencies = [ "libc", "log", "mach2", - "memmap2", + "memmap2 0.9.9", "memoffset", "minidump-common", "nix 0.28.0", @@ -11492,6 +11672,20 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", + "serde", +] + [[package]] name = "ndarray" version = "0.16.1" @@ -11831,6 +12025,7 @@ checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ "bytemuck", "num-traits", + "serde", ] [[package]] @@ -12537,7 +12732,7 @@ dependencies = [ name = "onnx" version = "0.1.0" dependencies = [ - "ndarray", + "ndarray 0.16.1", "ort", "thiserror 2.0.17", ] @@ -12684,6 +12879,16 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ordered-multimap" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ed8acf08e98e744e5384c8bc63ceb0364e68a6854187221c18df61c4797690e" +dependencies = [ + "dlv-list", + "hashbrown 0.13.2", +] + [[package]] name = "ordered-multimap" version = "0.7.3" @@ -12711,7 +12916,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" dependencies = [ "libloading 0.8.9", - "ndarray", + "ndarray 0.16.1", "ort-sys", "smallvec 2.0.0-alpha.10", "tracing", @@ -14660,7 +14865,7 @@ version = "0.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40a64b3a635fad9000648b4d8a59c8710c523ab61a23d392a7d91d47683f5adc" dependencies = [ - "ahash", + "ahash 0.8.12", "fluent-uri", "once_cell", "parking_lot 0.12.5", @@ -15105,6 +15310,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rubato" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6dd52e80cfc21894deadf554a5673002938ae4625f7a283e536f9cf7c17b0d5" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "realfft", +] + [[package]] name = "rubato" version = "0.16.2" @@ -15152,6 +15369,16 @@ dependencies = [ "walkdir", ] +[[package]] +name = "rust-ini" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e2a3bcec1f113553ef1c88aae6c020a369d03d55b58de9869a0908930385091" +dependencies = [ + "cfg-if", + "ordered-multimap 0.6.0", +] + [[package]] name = "rust-ini" version = "0.21.3" @@ -15159,7 +15386,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "796e8d2b6696392a43bea58116b667fb4c29727dc5abd27d6acf338bb4f688c7" dependencies = [ "cfg-if", - "ordered-multimap", + "ordered-multimap 0.7.3", ] [[package]] @@ -15464,6 +15691,15 @@ dependencies = [ "regex", ] +[[package]] +name = "scan_fmt" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b53b0a5db882a8e2fdaae0a43f7b39e7e9082389e978398bdf223a55b581248" +dependencies = [ + "regex", +] + [[package]] name = "scc" version = "2.4.0" @@ -16385,9 +16621,9 @@ version = "0.1.0" source = "git+https://github.com/emotechlab/silero-rs?rev=26a6460#26a646003cd8532ae2dde424ccdab1b6cdf5d7b0" dependencies = [ "anyhow", - "ndarray", + "ndarray 0.16.1", "ort", - "rubato", + "rubato 0.16.2", "thiserror 2.0.17", "tracing", ] @@ -16848,6 +17084,17 @@ dependencies = [ "float-cmp 0.9.0", ] +[[package]] +name = "string-interner" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e2531d8525b29b514d25e275a43581320d587b86db302b9a7e464bac579648" +dependencies = [ + "cfg-if", + "hashbrown 0.11.2", + "serde", +] + [[package]] name = "string_cache" version = "0.8.9" @@ -17687,7 +17934,7 @@ dependencies = [ "lru 0.12.5", "lz4_flex", "measure_time", - "memmap2", + "memmap2 0.9.9", "once_cell", "oneshot", "rayon", @@ -18249,7 +18496,7 @@ checksum = "444b091f24f2f6bdb4a305b54d3961f629c11861c685aceeea9a1972f89e43d5" dependencies = [ "dunce", "plist", - "rust-ini", + "rust-ini 0.21.3", "serde", "serde_json", "tauri", @@ -19444,9 +19691,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2cd081e2be8ae77e8f2a69bfa5c34d1931de1c79a32879ba02d93c55e084a5f" dependencies = [ "log", - "ndarray", + "ndarray 0.16.1", "ort", - "rubato", + "rubato 0.16.2", "rustfft", "thiserror 2.0.17", ] @@ -19754,7 +20001,7 @@ version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a620b996116a59e184c2fa2dfd8251ea34a36d0a514758c6f966386bd2e03476" dependencies = [ - "ahash", + "ahash 0.8.12", "aho-corasick", "compact_str", "dary_heap", @@ -20347,6 +20594,160 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "tract-core" +version = "0.19.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dae91e4486af81c5a154dce2a1d7c0780d35c2de8bc42e94bdec995333f31b90" +dependencies = [ + "anyhow", + "bit-set 0.5.3", + "derive-new", + "downcast-rs 1.2.1", + "dyn-clone", + "educe", + "lazy_static", + "log", + "maplit", + "ndarray 0.15.6", + "num-integer", + "num-traits", + "rustfft", + "smallvec 1.15.1", + "tract-data", + "tract-linalg", +] + +[[package]] +name = "tract-data" +version = "0.19.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "027e05e3537cb13f5e84b7664de25ed326a1d42c08d9985694f48f6efe3483ee" +dependencies = [ + "anyhow", + "educe", + "half", + "itertools 0.10.5", + "lazy_static", + "maplit", + "ndarray 0.15.6", + "nom 7.1.3", + "num-complex", + "num-integer", + "num-traits", + "scan_fmt", + "smallvec 1.15.1", + "string-interner", +] + +[[package]] +name = "tract-hir" +version = "0.19.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f72648d914f724e188cf679f3dd74f069eea36d8670633acf8889b94391a54" +dependencies = [ + "derive-new", + "educe", + "log", + "tract-core", +] + +[[package]] +name = "tract-linalg" +version = "0.19.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49fb02b3ee7b77054a3d1fecfe4bc2f7523f587a8f1814b32089f93e2a573244" +dependencies = [ + "cc", + "derive-new", + "downcast-rs 1.2.1", + "dyn-clone", + "half", + "lazy_static", + "liquid", + "liquid-core", + "log", + "num-traits", + "paste", + "scan_fmt", + "smallvec 1.15.1", + "tract-data", + "unicode-normalization", + "walkdir", +] + +[[package]] +name = "tract-nnef" +version = "0.19.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0895153ea97091967f92121715a5c80fbcc0841c53a1b2fb9ba93a89ed644357" +dependencies = [ + "byteorder", + "flate2", + "log", + "nom 7.1.3", + "tar", + "tract-core", + "walkdir", +] + +[[package]] +name = "tract-onnx" +version = "0.19.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21f752abf4627894827cdfea0ffd9089bcd31d4b5467bb32bd0908a4d3ab22b5" +dependencies = [ + "bytes", + "derive-new", + "educe", + "log", + "memmap2 0.5.10", + "num-integer", + "prost 0.11.9", + "smallvec 1.15.1", + "tract-hir", + "tract-nnef", + "tract-onnx-opl", +] + +[[package]] +name = "tract-onnx-opl" +version = "0.19.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb740c8e25f65f6e070c9438fb4c2671d72d36470c01e49b3002e61c7c01d0cb" +dependencies = [ + "educe", + "getrandom 0.2.17", + "log", + "rand 0.8.5", + "rand_distr 0.4.3", + "rustfft", + "tract-nnef", +] + +[[package]] +name = "tract-pulse" +version = "0.19.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "502f5ecdcb6c84d3c7caa1bbb2f23dbe0980055f8bd615fe3ef9ba37cb8f22dd" +dependencies = [ + "downcast-rs 1.2.1", + "lazy_static", + "log", + "tract-pulse-opl", +] + +[[package]] +name = "tract-pulse-opl" +version = "0.19.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5fcf838d40dd2b8b6454aaef8d4ce772aeca44e87c9628591d0b2919a06ed89" +dependencies = [ + "downcast-rs 1.2.1", + "lazy_static", + "tract-nnef", +] + [[package]] name = "transcribe-aws" version = "0.1.0" @@ -21079,7 +21480,7 @@ dependencies = [ "gemm 0.18.2", "half", "libloading 0.8.9", - "memmap2", + "memmap2 0.9.9", "num", "num-traits", "num_cpus", @@ -23439,7 +23840,7 @@ checksum = "8d66ca9352cbd4eecbbc40871d8a11b4ac8107cfc528a6e14d7c19c69d0e1ac9" dependencies = [ "as-raw-xcb-connection", "libc", - "memmap2", + "memmap2 0.9.9", "xkeysym", ] @@ -23760,7 +24161,7 @@ version = "0.4.0-zed" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c0b46ed118eba34d9ba53d94ddc0b665e0e06a2cf874cfa2dd5dec278148642" dependencies = [ - "ahash", + "ahash 0.8.12", "hashbrown 0.14.5", "log", "x11rb", diff --git a/Cargo.toml b/Cargo.toml index 40d624955f..b09485d6f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ hypr-bundle = { path = "crates/bundle", package = "bundle" } hypr-data = { path = "crates/data", package = "data" } hypr-db-core = { path = "crates/db-core", package = "db-core" } hypr-db-user = { path = "crates/db-user", package = "db-user" } +hypr-denoise = { path = "crates/denoise", package = "denoise" } hypr-detect = { path = "crates/detect", package = "detect" } hypr-device-monitor = { path = "crates/device-monitor", package = "device-monitor" } hypr-docs = { path = "crates/docs", package = "docs" } diff --git a/crates/denoise/Cargo.toml b/crates/denoise/Cargo.toml new file mode 100644 index 0000000000..01639387c3 --- /dev/null +++ b/crates/denoise/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "denoise" +version = "0.1.0" +edition = "2024" + +[features] +default = ["default-model"] +default-model = ["deep_filter/default-model", "deep_filter/tract"] +default-model-ll = ["deep_filter/default-model-ll", "deep_filter/tract"] + +[dependencies] +deep_filter = { git = "https://github.com/Rikorose/DeepFilterNet", tag = "v0.5.6", default-features = false } +ndarray = "0.15" + +thiserror = { workspace = true } + +[dev-dependencies] +criterion = { workspace = true } +hound = { workspace = true } +hypr-data = { workspace = true } +rodio = { workspace = true } + +[[bench]] +name = "denoise_bench" +harness = false diff --git a/crates/denoise/benches/denoise_bench.rs b/crates/denoise/benches/denoise_bench.rs new file mode 100644 index 0000000000..ea5f0cc992 --- /dev/null +++ b/crates/denoise/benches/denoise_bench.rs @@ -0,0 +1,89 @@ +use std::hint::black_box; + +use criterion::{Criterion, criterion_group, criterion_main}; +use hound::WavReader; + +use denoise::{Denoiser, HOP_SIZE, SAMPLE_RATE}; + +fn load_test_audio() -> Vec { + let wav_path = hypr_data::english_1::AUDIO_PART6_48000HZ_PATH; + let reader = WavReader::open(wav_path).expect("Failed to open WAV file"); + let spec = reader.spec(); + + assert_eq!( + spec.sample_rate as usize, SAMPLE_RATE, + "Expected 48kHz audio" + ); + + let samples: Vec = reader + .into_samples::() + .map(|s| s.unwrap() as f32 / 32768.0) + .collect(); + + let num_frames = samples.len() / HOP_SIZE; + samples[..num_frames * HOP_SIZE].to_vec() +} + +fn bench_denoiser_initialization(c: &mut Criterion) { + c.bench_function("denoiser_initialization", |b| { + b.iter(|| black_box(Denoiser::new().unwrap())) + }); +} + +fn bench_denoiser_process_frame(c: &mut Criterion) { + let mut denoiser = Denoiser::new().unwrap(); + let audio = load_test_audio(); + let input = &audio[..HOP_SIZE]; + let mut output = vec![0.0f32; HOP_SIZE]; + + c.bench_function("denoiser_process_frame", |b| { + b.iter(|| { + black_box( + denoiser + .process_frame(black_box(input), black_box(&mut output)) + .unwrap(), + ) + }) + }); +} + +fn bench_denoiser_process(c: &mut Criterion) { + let mut denoiser = Denoiser::new().unwrap(); + let audio = load_test_audio(); + + let frame_counts = [10, 100]; + + for &num_frames in &frame_counts { + let samples = num_frames * HOP_SIZE; + if samples <= audio.len() { + let input = &audio[..samples]; + + c.bench_function(&format!("denoiser_process_{}_frames", num_frames), |b| { + b.iter(|| black_box(denoiser.process(black_box(input)).unwrap())) + }); + } + } +} + +fn bench_denoiser_throughput(c: &mut Criterion) { + let mut denoiser = Denoiser::new().unwrap(); + let audio = load_test_audio(); + + let mut group = c.benchmark_group("denoiser_throughput"); + group.throughput(criterion::Throughput::Elements(audio.len() as u64)); + + group.bench_function("samples_per_second", |b| { + b.iter(|| black_box(denoiser.process(black_box(&audio)).unwrap())) + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_denoiser_initialization, + bench_denoiser_process_frame, + bench_denoiser_process, + bench_denoiser_throughput +); +criterion_main!(benches); diff --git a/crates/denoise/src/error.rs b/crates/denoise/src/error.rs new file mode 100644 index 0000000000..06a0c5d295 --- /dev/null +++ b/crates/denoise/src/error.rs @@ -0,0 +1,16 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Failed to initialize denoiser: {0}")] + InitError(String), + + #[error("Failed to process audio: {0}")] + ProcessError(String), + + #[error("Invalid sample rate: expected {expected}, got {actual}")] + InvalidSampleRate { expected: usize, actual: usize }, + + #[error("Invalid frame size: expected {expected}, got {actual}")] + InvalidFrameSize { expected: usize, actual: usize }, +} diff --git a/crates/denoise/src/lib.rs b/crates/denoise/src/lib.rs new file mode 100644 index 0000000000..6433687223 --- /dev/null +++ b/crates/denoise/src/lib.rs @@ -0,0 +1,151 @@ +mod error; +pub use error::Error; + +#[cfg(any(feature = "default-model", feature = "default-model-ll"))] +use df::tract::{DfParams, DfTract, RuntimeParams}; + +pub const SAMPLE_RATE: usize = 48000; +pub const HOP_SIZE: usize = 480; + +pub struct Denoiser { + #[cfg(any(feature = "default-model", feature = "default-model-ll"))] + model: DfTract, +} + +impl Denoiser { + #[cfg(any(feature = "default-model", feature = "default-model-ll"))] + pub fn new() -> Result { + let df_params = DfParams::default(); + let runtime_params = RuntimeParams::default(); + + let model = DfTract::new(df_params, &runtime_params) + .map_err(|e| Error::InitError(e.to_string()))?; + + Ok(Self { model }) + } + + #[cfg(any(feature = "default-model", feature = "default-model-ll"))] + pub fn sample_rate(&self) -> usize { + self.model.sr + } + + #[cfg(any(feature = "default-model", feature = "default-model-ll"))] + pub fn hop_size(&self) -> usize { + self.model.hop_size + } + + #[cfg(any(feature = "default-model", feature = "default-model-ll"))] + pub fn reset(&mut self) -> Result<(), Error> { + self.model + .init() + .map_err(|e| Error::ProcessError(e.to_string())) + } + + #[cfg(any(feature = "default-model", feature = "default-model-ll"))] + pub fn process_frame(&mut self, input: &[f32], output: &mut [f32]) -> Result { + use ndarray::{ArrayView2, ArrayViewMut2}; + + if input.len() != self.model.hop_size { + return Err(Error::InvalidFrameSize { + expected: self.model.hop_size, + actual: input.len(), + }); + } + + if output.len() != self.model.hop_size { + return Err(Error::InvalidFrameSize { + expected: self.model.hop_size, + actual: output.len(), + }); + } + + let noisy = ArrayView2::from_shape((1, self.model.hop_size), input) + .map_err(|e| Error::ProcessError(e.to_string()))?; + let mut enh = ArrayViewMut2::from_shape((1, self.model.hop_size), output) + .map_err(|e| Error::ProcessError(e.to_string()))?; + + self.model + .process(noisy, enh.view_mut()) + .map_err(|e| Error::ProcessError(e.to_string())) + } + + #[cfg(any(feature = "default-model", feature = "default-model-ll"))] + pub fn process(&mut self, input: &[f32]) -> Result, Error> { + let hop_size = self.model.hop_size; + if input.len() % hop_size != 0 { + return Err(Error::InvalidFrameSize { + expected: hop_size, + actual: input.len(), + }); + } + let num_frames = input.len() / hop_size; + let mut output = vec![0.0f32; num_frames * hop_size]; + + for i in 0..num_frames { + let start = i * hop_size; + let end = start + hop_size; + self.process_frame(&input[start..end], &mut output[start..end])?; + } + + Ok(output) + } +} + +#[cfg(all(test, any(feature = "default-model", feature = "default-model-ll")))] +mod tests { + use super::*; + + #[test] + fn test_denoiser_creation() { + let denoiser = Denoiser::new(); + assert!(denoiser.is_ok()); + } + + #[test] + fn test_denoiser_sample_rate() { + let denoiser = Denoiser::new().unwrap(); + assert_eq!(denoiser.sample_rate(), SAMPLE_RATE); + } + + #[test] + fn test_denoiser_hop_size() { + let denoiser = Denoiser::new().unwrap(); + assert_eq!(denoiser.hop_size(), HOP_SIZE); + } + + #[test] + fn test_process_frame() { + let mut denoiser = Denoiser::new().unwrap(); + let input = vec![0.0f32; HOP_SIZE]; + let mut output = vec![0.0f32; HOP_SIZE]; + + let result = denoiser.process_frame(&input, &mut output); + assert!(result.is_ok()); + } + + #[test] + fn test_process() { + let mut denoiser = Denoiser::new().unwrap(); + let input = vec![0.0f32; HOP_SIZE * 10]; + + let result = denoiser.process(&input); + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), HOP_SIZE * 10); + } + + #[test] + fn test_process_invalid_length() { + let mut denoiser = Denoiser::new().unwrap(); + let input = vec![0.0f32; HOP_SIZE * 10 + 100]; + + let result = denoiser.process(&input); + assert!(result.is_err()); + match result { + Err(Error::InvalidFrameSize { expected, actual }) => { + assert_eq!(expected, HOP_SIZE); + assert_eq!(actual, HOP_SIZE * 10 + 100); + } + _ => panic!("Expected InvalidFrameSize error"), + } + } +}