Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ numcodecs-asinh = { version = "0.4", path = "codecs/asinh", default-features = f
numcodecs-bit-round = { version = "0.4", path = "codecs/bit-round", default-features = false }
numcodecs-ebcc = { version = "0.1.1", path = "codecs/ebcc", default-features = false }
numcodecs-fixed-offset-scale = { version = "0.4", path = "codecs/fixed-offset-scale", default-features = false }
numcodecs-fourier-network = { version = "0.3", path = "codecs/fourier-network", default-features = false }
numcodecs-fourier-network = { version = "0.4", path = "codecs/fourier-network", default-features = false }
numcodecs-identity = { version = "0.4", path = "codecs/identity", default-features = false }
numcodecs-jpeg2000 = { version = "0.3", path = "codecs/jpeg2000", default-features = false }
numcodecs-lc = { version = "0.1", path = "codecs/lc", default-features = false }
Expand All @@ -86,7 +86,7 @@ numcodecs-zstd = { version = "0.4", path = "codecs/zstd", default-features = fal

# crates.io third-party dependencies
anyhow = { version = "1.0.93", default-features = false }
burn = { version = "0.18", default-features = false }
burn = { version = "0.21", default-features = false }
clap = { version = "4.6", default-features = false }
convert_case = { version = "0.8", default-features = false }
ebcc = { version = "0.2", default-features = false }
Expand Down
9 changes: 2 additions & 7 deletions codecs/fourier-network/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "numcodecs-fourier-network"
version = "0.3.0"
version = "0.4.0"
edition = { workspace = true }
authors = { workspace = true }
repository = { workspace = true }
Expand All @@ -15,20 +15,15 @@ keywords = ["fourier", "network", "numcodecs", "compression", "encoding"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
burn = { workspace = true, features = ["std", "autodiff", "ndarray"] }
burn = { workspace = true, features = ["std", "autodiff", "flex"] }
itertools = { workspace = true, features = ["use_alloc"] }
log = { workspace = true }
# FIXME: bytemuck 1.24 fails to compile on 1.87
bytemuck = { version = "=1.23.2", default-features = false }
ndarray = { workspace = true, features = ["std"] }
numcodecs = { workspace = true }
num-traits = { workspace = true, features = ["std"] }
schemars = { workspace = true, features = ["derive", "preserve_order"] }
serde = { workspace = true, features = ["std", "derive"] }
thiserror = { workspace = true }
# FIXME: burn-common -> cubecl-common brings in wasm-bindgen
# wasm-bindgen v0.2.115 has an unresolved import in wasm32-wasi
wasm-bindgen = { version = "=0.2.114", default-features = false }

[dev-dependencies]
serde_json = { workspace = true, features = ["std"] }
Expand Down
67 changes: 30 additions & 37 deletions codecs/fourier-network/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};

use burn::{
backend::{Autodiff, NdArray, ndarray::NdArrayDevice},
backend::{Autodiff, Flex, flex::FlexDevice},
module::{Module, Param},
nn::loss::{MseLoss, Reduction},
optim::{AdamConfig, GradientsParams, Optimizer},
Expand All @@ -46,13 +46,6 @@ use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;

// FIXME: bytemuck 1.24 fails to compile on 1.87
use ::bytemuck as _;

// FIXME: burn-common -> cubecl-common brings in wasm-bindgen
// wasm-bindgen v0.2.115 has an unresolved import in wasm32-wasi
use ::wasm_bindgen as _;

#[cfg(test)]
use ::serde_json as _;

Expand Down Expand Up @@ -109,8 +102,8 @@ impl Codec for FourierNetworkCodec {
fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
match data {
AnyCowArray::F32(data) => Ok(AnyArray::U8(
encode::<f32, _, _, Autodiff<NdArray<f32>>>(
&NdArrayDevice::Cpu,
encode::<f32, _, _, Autodiff<Flex<f32>>>(
&FlexDevice,
data,
self.fourier_features,
self.fourier_scale,
Expand All @@ -123,8 +116,8 @@ impl Codec for FourierNetworkCodec {
.into_dyn(),
)),
AnyCowArray::F64(data) => Ok(AnyArray::U8(
encode::<f64, _, _, Autodiff<NdArray<f64>>>(
&NdArrayDevice::Cpu,
encode::<f64, _, _, Autodiff<Flex<f64>>>(
&FlexDevice,
data,
self.fourier_features,
self.fourier_scale,
Expand Down Expand Up @@ -162,15 +155,15 @@ impl Codec for FourierNetworkCodec {
};

match decoded {
AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
&NdArrayDevice::Cpu,
AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, Flex<f32>>(
&FlexDevice,
encoded,
decoded,
self.fourier_features,
self.num_blocks,
),
AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
&NdArrayDevice::Cpu,
AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, Flex<f64>>(
&FlexDevice,
encoded,
decoded,
self.fourier_features,
Expand Down Expand Up @@ -362,7 +355,7 @@ pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<F
return Err(FourierNetworkCodecError::NonFiniteData);
}

B::seed(seed);
B::seed(device, seed);

let b_t = Tensor::<B, 2, Float>::random(
[data.ndim(), fourier_features.get()],
Expand Down Expand Up @@ -612,8 +605,8 @@ mod tests {
fn empty() {
std::mem::drop(simple_logger::init());

let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
&NdArrayDevice::Cpu,
let encoded = encode::<f32, _, _, Autodiff<Flex<f32>>>(
&FlexDevice,
Array::<f32, _>::zeros((0,)),
NonZeroUsize::MIN,
Positive(1.0),
Expand All @@ -626,8 +619,8 @@ mod tests {
.unwrap();
assert!(encoded.is_empty());
let mut decoded = Array::<f32, _>::zeros((0,));
decode_into::<f32, _, _, NdArray<f32>>(
&NdArrayDevice::Cpu,
decode_into::<f32, _, _, Flex<f32>>(
&FlexDevice,
encoded,
decoded.view_mut(),
NonZeroUsize::MIN,
Expand All @@ -640,8 +633,8 @@ mod tests {
fn ones() {
std::mem::drop(simple_logger::init());

let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
&NdArrayDevice::Cpu,
let encoded = encode::<f32, _, _, Autodiff<Flex<f32>>>(
&FlexDevice,
Array::<f32, _>::zeros((1, 1, 1, 1)),
NonZeroUsize::MIN,
Positive(1.0),
Expand All @@ -653,8 +646,8 @@ mod tests {
)
.unwrap();
let mut decoded = Array::<f32, _>::zeros((1, 1, 1, 1));
decode_into::<f32, _, _, NdArray<f32>>(
&NdArrayDevice::Cpu,
decode_into::<f32, _, _, Flex<f32>>(
&FlexDevice,
encoded,
decoded.view_mut(),
NonZeroUsize::MIN,
Expand All @@ -667,8 +660,8 @@ mod tests {
fn r#const() {
std::mem::drop(simple_logger::init());

let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
&NdArrayDevice::Cpu,
let encoded = encode::<f32, _, _, Autodiff<Flex<f32>>>(
&FlexDevice,
Array::<f32, _>::from_elem((2, 1, 3), 42.0),
NonZeroUsize::MIN,
Positive(1.0),
Expand All @@ -680,8 +673,8 @@ mod tests {
)
.unwrap();
let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
decode_into::<f32, _, _, NdArray<f32>>(
&NdArrayDevice::Cpu,
decode_into::<f32, _, _, Flex<f32>>(
&FlexDevice,
encoded,
decoded.view_mut(),
NonZeroUsize::MIN,
Expand All @@ -694,8 +687,8 @@ mod tests {
fn const_batched() {
std::mem::drop(simple_logger::init());

let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
&NdArrayDevice::Cpu,
let encoded = encode::<f32, _, _, Autodiff<Flex<f32>>>(
&FlexDevice,
Array::<f32, _>::from_elem((2, 1, 3), 42.0),
NonZeroUsize::MIN,
Positive(1.0),
Expand All @@ -707,8 +700,8 @@ mod tests {
)
.unwrap();
let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
decode_into::<f32, _, _, NdArray<f32>>(
&NdArrayDevice::Cpu,
decode_into::<f32, _, _, Flex<f32>>(
&FlexDevice,
encoded,
decoded.view_mut(),
NonZeroUsize::MIN,
Expand Down Expand Up @@ -738,8 +731,8 @@ mod tests {
Some(NonZeroUsize::MIN.saturating_add(1000)), // mini-batched, truncated
] {
let mut decoded = Array::<f64, _>::zeros(data.shape());
let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
&NdArrayDevice::Cpu,
let encoded = encode::<f64, _, _, Autodiff<Flex<f64>>>(
&FlexDevice,
data.view(),
fourier_features,
fourier_scale,
Expand All @@ -751,8 +744,8 @@ mod tests {
)
.unwrap();

decode_into::<f64, _, _, NdArray<f64>>(
&NdArrayDevice::Cpu,
decode_into::<f64, _, _, Flex<f64>>(
&FlexDevice,
encoded,
decoded.view_mut(),
fourier_features,
Expand Down
10 changes: 5 additions & 5 deletions codecs/fourier-network/src/modules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use burn::{

#[derive(Debug, Module)]
pub struct Block<B: Backend> {
bn2_1: BatchNorm<B, 0>,
bn2_1: BatchNorm<B>,
gu2_2: Gelu,
ln2_3: Linear<B>,
}
Expand All @@ -35,7 +35,7 @@ impl BlockConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Block<B> {
Block {
bn2_1: BatchNormConfig::new(self.fourier_features.get()).init(device),
gu2_2: Gelu,
gu2_2: Gelu::new(),
ln2_3: LinearConfig::new(self.fourier_features.get(), self.fourier_features.get())
.init(device),
}
Expand All @@ -46,7 +46,7 @@ impl BlockConfig {
pub struct Model<B: Backend> {
ln1: Linear<B>,
bl2: Vec<Block<B>>,
bn3: BatchNorm<B, 0>,
bn3: BatchNorm<B>,
gu4: Gelu,
ln5: Linear<B>,
}
Expand Down Expand Up @@ -88,7 +88,7 @@ impl ModelConfig {
.map(|_| block.init(device))
.collect(),
bn3: BatchNormConfig::new(self.fourier_features.get()).init(device),
gu4: Gelu,
gu4: Gelu::new(),
ln5: LinearConfig::new(self.fourier_features.get(), 1).init(device),
}
}
Expand Down Expand Up @@ -126,7 +126,7 @@ impl<B: Backend> Record<B> for ModelExtra<B> {
}
}

#[derive(serde::Serialize, serde::Deserialize)]
#[derive(Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound = "")]
pub struct ModelExtraItem<B: Backend, S: PrecisionSettings> {
model: <<Model<B> as Module<B>>::Record as Record<B>>::Item<S>,
Expand Down
Loading