From 9ab7403228b8ce8edb860f75a6d4da5b2cca142f Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 15 May 2026 10:03:14 +0300 Subject: [PATCH] Upgrade burn to v0.21 --- Cargo.toml | 4 +- codecs/fourier-network/Cargo.toml | 9 +--- codecs/fourier-network/src/lib.rs | 67 ++++++++++++--------------- codecs/fourier-network/src/modules.rs | 10 ++-- 4 files changed, 39 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c148495ee..e926fed9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,7 +62,7 @@ numcodecs-asinh = { version = "0.4", path = "codecs/asinh", default-features = f numcodecs-bit-round = { version = "0.4", path = "codecs/bit-round", default-features = false } numcodecs-ebcc = { version = "0.1.1", path = "codecs/ebcc", default-features = false } numcodecs-fixed-offset-scale = { version = "0.4", path = "codecs/fixed-offset-scale", default-features = false } -numcodecs-fourier-network = { version = "0.3", path = "codecs/fourier-network", default-features = false } +numcodecs-fourier-network = { version = "0.4", path = "codecs/fourier-network", default-features = false } numcodecs-identity = { version = "0.4", path = "codecs/identity", default-features = false } numcodecs-jpeg2000 = { version = "0.3", path = "codecs/jpeg2000", default-features = false } numcodecs-lc = { version = "0.1", path = "codecs/lc", default-features = false } @@ -86,7 +86,7 @@ numcodecs-zstd = { version = "0.4", path = "codecs/zstd", default-features = fal # crates.io third-party dependencies anyhow = { version = "1.0.93", default-features = false } -burn = { version = "0.18", default-features = false } +burn = { version = "0.21", default-features = false } clap = { version = "4.6", default-features = false } convert_case = { version = "0.8", default-features = false } ebcc = { version = "0.2", default-features = false } diff --git a/codecs/fourier-network/Cargo.toml b/codecs/fourier-network/Cargo.toml index 202c97974..f70d1471c 100644 --- a/codecs/fourier-network/Cargo.toml +++ b/codecs/fourier-network/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "numcodecs-fourier-network" -version = "0.3.0" +version = "0.4.0" edition = { workspace = true } authors = { workspace = true } repository = { workspace = true } @@ -15,20 +15,15 @@ keywords = ["fourier", "network", "numcodecs", "compression", "encoding"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -burn = { workspace = true, features = ["std", "autodiff", "ndarray"] } +burn = { workspace = true, features = ["std", "autodiff", "flex"] } itertools = { workspace = true, features = ["use_alloc"] } log = { workspace = true } -# FIXME: bytemuck 1.24 fails to compile on 1.87 -bytemuck = { version = "=1.23.2", default-features = false } ndarray = { workspace = true, features = ["std"] } numcodecs = { workspace = true } num-traits = { workspace = true, features = ["std"] } schemars = { workspace = true, features = ["derive", "preserve_order"] } serde = { workspace = true, features = ["std", "derive"] } thiserror = { workspace = true } -# FIXME: burn-common -> cubecl-common brings in wasm-bindgen -# wasm-bindgen v0.2.115 has an unresolved import in wasm32-wasi -wasm-bindgen = { version = "=0.2.114", default-features = false } [dev-dependencies] serde_json = { workspace = true, features = ["std"] } diff --git a/codecs/fourier-network/src/lib.rs b/codecs/fourier-network/src/lib.rs index 10dddef9f..e0c6b8bfc 100644 --- a/codecs/fourier-network/src/lib.rs +++ b/codecs/fourier-network/src/lib.rs @@ -22,7 +22,7 @@ use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign}; use burn::{ - backend::{Autodiff, NdArray, ndarray::NdArrayDevice}, + backend::{Autodiff, Flex, flex::FlexDevice}, module::{Module, Param}, nn::loss::{MseLoss, Reduction}, optim::{AdamConfig, GradientsParams, Optimizer}, @@ -46,13 +46,6 @@ use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use thiserror::Error; -// FIXME: bytemuck 1.24 fails to compile on 1.87 -use ::bytemuck as _; - -// FIXME: burn-common -> cubecl-common brings in wasm-bindgen -// wasm-bindgen v0.2.115 has an unresolved import in wasm32-wasi -use ::wasm_bindgen as _; - #[cfg(test)] use ::serde_json as _; @@ -109,8 +102,8 @@ impl Codec for FourierNetworkCodec { fn encode(&self, data: AnyCowArray) -> Result { match data { AnyCowArray::F32(data) => Ok(AnyArray::U8( - encode::>>( - &NdArrayDevice::Cpu, + encode::>>( + &FlexDevice, data, self.fourier_features, self.fourier_scale, @@ -123,8 +116,8 @@ impl Codec for FourierNetworkCodec { .into_dyn(), )), AnyCowArray::F64(data) => Ok(AnyArray::U8( - encode::>>( - &NdArrayDevice::Cpu, + encode::>>( + &FlexDevice, data, self.fourier_features, self.fourier_scale, @@ -162,15 +155,15 @@ impl Codec for FourierNetworkCodec { }; match decoded { - AnyArrayViewMut::F32(decoded) => decode_into::>( - &NdArrayDevice::Cpu, + AnyArrayViewMut::F32(decoded) => decode_into::>( + &FlexDevice, encoded, decoded, self.fourier_features, self.num_blocks, ), - AnyArrayViewMut::F64(decoded) => decode_into::>( - &NdArrayDevice::Cpu, + AnyArrayViewMut::F64(decoded) => decode_into::>( + &FlexDevice, encoded, decoded, self.fourier_features, @@ -362,7 +355,7 @@ pub fn encode, D: Dimension, B: AutodiffBackend::random( [data.ndim(), fourier_features.get()], @@ -612,8 +605,8 @@ mod tests { fn empty() { std::mem::drop(simple_logger::init()); - let encoded = encode::>>( - &NdArrayDevice::Cpu, + let encoded = encode::>>( + &FlexDevice, Array::::zeros((0,)), NonZeroUsize::MIN, Positive(1.0), @@ -626,8 +619,8 @@ mod tests { .unwrap(); assert!(encoded.is_empty()); let mut decoded = Array::::zeros((0,)); - decode_into::>( - &NdArrayDevice::Cpu, + decode_into::>( + &FlexDevice, encoded, decoded.view_mut(), NonZeroUsize::MIN, @@ -640,8 +633,8 @@ mod tests { fn ones() { std::mem::drop(simple_logger::init()); - let encoded = encode::>>( - &NdArrayDevice::Cpu, + let encoded = encode::>>( + &FlexDevice, Array::::zeros((1, 1, 1, 1)), NonZeroUsize::MIN, Positive(1.0), @@ -653,8 +646,8 @@ mod tests { ) .unwrap(); let mut decoded = Array::::zeros((1, 1, 1, 1)); - decode_into::>( - &NdArrayDevice::Cpu, + decode_into::>( + &FlexDevice, encoded, decoded.view_mut(), NonZeroUsize::MIN, @@ -667,8 +660,8 @@ mod tests { fn r#const() { std::mem::drop(simple_logger::init()); - let encoded = encode::>>( - &NdArrayDevice::Cpu, + let encoded = encode::>>( + &FlexDevice, Array::::from_elem((2, 1, 3), 42.0), NonZeroUsize::MIN, Positive(1.0), @@ -680,8 +673,8 @@ mod tests { ) .unwrap(); let mut decoded = Array::::zeros((2, 1, 3)); - decode_into::>( - &NdArrayDevice::Cpu, + decode_into::>( + &FlexDevice, encoded, decoded.view_mut(), NonZeroUsize::MIN, @@ -694,8 +687,8 @@ mod tests { fn const_batched() { std::mem::drop(simple_logger::init()); - let encoded = encode::>>( - &NdArrayDevice::Cpu, + let encoded = encode::>>( + &FlexDevice, Array::::from_elem((2, 1, 3), 42.0), NonZeroUsize::MIN, Positive(1.0), @@ -707,8 +700,8 @@ mod tests { ) .unwrap(); let mut decoded = Array::::zeros((2, 1, 3)); - decode_into::>( - &NdArrayDevice::Cpu, + decode_into::>( + &FlexDevice, encoded, decoded.view_mut(), NonZeroUsize::MIN, @@ -738,8 +731,8 @@ mod tests { Some(NonZeroUsize::MIN.saturating_add(1000)), // mini-batched, truncated ] { let mut decoded = Array::::zeros(data.shape()); - let encoded = encode::>>( - &NdArrayDevice::Cpu, + let encoded = encode::>>( + &FlexDevice, data.view(), fourier_features, fourier_scale, @@ -751,8 +744,8 @@ mod tests { ) .unwrap(); - decode_into::>( - &NdArrayDevice::Cpu, + decode_into::>( + &FlexDevice, encoded, decoded.view_mut(), fourier_features, diff --git a/codecs/fourier-network/src/modules.rs b/codecs/fourier-network/src/modules.rs index b0f198c09..5ce47fbc9 100644 --- a/codecs/fourier-network/src/modules.rs +++ b/codecs/fourier-network/src/modules.rs @@ -11,7 +11,7 @@ use burn::{ #[derive(Debug, Module)] pub struct Block { - bn2_1: BatchNorm, + bn2_1: BatchNorm, gu2_2: Gelu, ln2_3: Linear, } @@ -35,7 +35,7 @@ impl BlockConfig { pub fn init(&self, device: &B::Device) -> Block { Block { bn2_1: BatchNormConfig::new(self.fourier_features.get()).init(device), - gu2_2: Gelu, + gu2_2: Gelu::new(), ln2_3: LinearConfig::new(self.fourier_features.get(), self.fourier_features.get()) .init(device), } @@ -46,7 +46,7 @@ impl BlockConfig { pub struct Model { ln1: Linear, bl2: Vec>, - bn3: BatchNorm, + bn3: BatchNorm, gu4: Gelu, ln5: Linear, } @@ -88,7 +88,7 @@ impl ModelConfig { .map(|_| block.init(device)) .collect(), bn3: BatchNormConfig::new(self.fourier_features.get()).init(device), - gu4: Gelu, + gu4: Gelu::new(), ln5: LinearConfig::new(self.fourier_features.get(), 1).init(device), } } @@ -126,7 +126,7 @@ impl Record for ModelExtra { } } -#[derive(serde::Serialize, serde::Deserialize)] +#[derive(Clone, serde::Serialize, serde::Deserialize)] #[serde(bound = "")] pub struct ModelExtraItem { model: < as Module>::Record as Record>::Item,