diff --git a/Cargo.toml b/Cargo.toml index 2a179bfc0..d873c7bd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ resolver = "2" members = [ "crates/numcodecs", "crates/numcodecs-python", + "crates/numcodecs-registry", "crates/numcodecs-wasm-builder", "crates/numcodecs-wasm-logging", "crates/numcodecs-wasm-guest", @@ -21,6 +22,7 @@ members = [ "codecs/lc", "codecs/linear-quantize", "codecs/log", + "codecs/onion", "codecs/pco", "codecs/qpet-sperr", "codecs/random-projection", @@ -49,6 +51,7 @@ rust-version = "1.87" # workspace-internal numcodecs crates numcodecs = { version = "0.3.1", path = "crates/numcodecs", default-features = false } numcodecs-python = { version = "0.7.1", path = "crates/numcodecs-python", default-features = false } +numcodecs-registry = { version = "0.1", path = "crates/numcodecs-registry", default-features = false } numcodecs-wasm-builder = { version = "0.2", path = "crates/numcodecs-wasm-builder", default-features = false } numcodecs-wasm-guest = { version = "0.3", path = "crates/numcodecs-wasm-guest", default-features = false } numcodecs-wasm-host = { version = "0.2", path = "crates/numcodecs-wasm-host", default-features = false } @@ -68,6 +71,7 @@ numcodecs-jpeg2000 = { version = "0.3", path = "codecs/jpeg2000", default-featur numcodecs-lc = { version = "0.1", path = "codecs/lc", default-features = false } numcodecs-linear-quantize = { version = "0.5", path = "codecs/linear-quantize", default-features = false } numcodecs-log = { version = "0.5", path = "codecs/log", default-features = false } +numcodecs-onion = { version = "0.1", path = "codecs/onion", default-features = false } numcodecs-pco = { version = "0.4", path = "codecs/pco", default-features = false } numcodecs-qpet-sperr = { version = "0.2.2", path = "codecs/qpet-sperr", default-features = false } numcodecs-random-projection = { version = "0.4", path = "codecs/random-projection", default-features = false } @@ -90,6 +94,7 @@ burn = { version = "0.18", default-features = false } clap = { version = "4.6", default-features = false } convert_case = { version = "0.8", default-features = false } ebcc = { version = "0.3.0-alpha", default-features = false } +erased-serde = { version = "0.4", default-features = false } format_serde_error = { version = "0.3", default-features = false } indexmap = { version = "2.10", default-features = false } itertools = { version = "0.14", default-features = false } diff --git a/codecs/onion/Cargo.toml b/codecs/onion/Cargo.toml new file mode 100644 index 000000000..b905f1897 --- /dev/null +++ b/codecs/onion/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "numcodecs-onion" +version = "0.1.0" +edition = { workspace = true } +authors = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +rust-version = { workspace = true } + +description = "Onion identity meta-codec implementation for the numcodecs API" +readme = "README.md" +categories = ["compression", "encoding"] +keywords = ["identity", "numcodecs", "compression", "encoding", "meta"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +numcodecs = { workspace = true } +numcodecs-registry = { workspace = true } +schemars = { workspace = true, features = ["derive", "preserve_order"] } +serde = { workspace = true, features = ["std", "derive"] } +thiserror = { workspace = true } + +[lints] +workspace = true diff --git a/codecs/onion/LICENSE b/codecs/onion/LICENSE new file mode 120000 index 000000000..30cff7403 --- /dev/null +++ b/codecs/onion/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/codecs/onion/README.md b/codecs/onion/README.md new file mode 100644 index 000000000..23a7a2a7e --- /dev/null +++ b/codecs/onion/README.md @@ -0,0 +1,38 @@ +[![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![PyPi Release]][pypi] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs] [![Read the Docs]][rtdocs] + +[CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main +[workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain + +[MSRV]: https://img.shields.io/badge/MSRV-1.87.0-blue +[repo]: https://github.com/juntyr/numcodecs-rs + +[Latest Version]: https://img.shields.io/crates/v/numcodecs-onion +[crates.io]: https://crates.io/crates/numcodecs-onion + +[PyPi Release]: https://img.shields.io/pypi/v/numcodecs-wasm-onion.svg +[pypi]: https://pypi.python.org/pypi/numcodecs-wasm-onion + +[Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-onion +[docs.rs]: https://docs.rs/numcodecs-onion/ + +[Rust Doc Main]: https://img.shields.io/badge/docs-main-blue +[docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_onion + +[Read the Docs]: https://img.shields.io/readthedocs/numcodecs-wasm?label=readthedocs +[rtdocs]: https://numcodecs-wasm.readthedocs.io/en/stable/api/numcodecs_wasm_onion/ + +# numcodecs-onion + +Onion identity meta-codec implementation for the [`numcodecs`] API. + +[`numcodecs`]: https://docs.rs/numcodecs/0.2/numcodecs/ + +## License + +Licensed under the Mozilla Public License, Version 2.0 ([LICENSE](LICENSE) or https://www.mozilla.org/en-US/MPL/2.0/). + +## Funding + +The `numcodecs-onion` crate has been developed as part of [ESiWACE3](https://www.esiwace.eu), the third phase of the Centre of Excellence in Simulation of Weather and Climate in Europe. + +Funded by the European Union. This work has received funding from the European High Performance Computing Joint Undertaking (JU) under grant agreement No 101093054. diff --git a/codecs/onion/src/lib.rs b/codecs/onion/src/lib.rs new file mode 100644 index 000000000..0dbb0ce0b --- /dev/null +++ b/codecs/onion/src/lib.rs @@ -0,0 +1,86 @@ +//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs] +//! +//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main +//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain +//! +//! [MSRV]: https://img.shields.io/badge/MSRV-1.87.0-blue +//! [repo]: https://github.com/juntyr/numcodecs-rs +//! +//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-onion +//! [crates.io]: https://crates.io/crates/numcodecs-onion +//! +//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-onion +//! [docs.rs]: https://docs.rs/numcodecs-onion/ +//! +//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue +//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_onion +//! +//! Onion identity meta-codec implementation for the [`numcodecs`] API. + +use numcodecs::{ + AnyArray, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, ErasedError, StaticCodec, + StaticCodecConfig, StaticCodecVersion, +}; +use numcodecs_registry::GlobalErasedDynCodec; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Clone, Serialize, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +/// Onion identity meta-codec which wraps an existing codec and passes the +/// inputs to and outputs from it unchanged during encoding and decoding. +pub struct OnionCodec { + /// The configuration of the wrapped codec. + pub codec: GlobalErasedDynCodec, + /// The codec's encoding format version. Do not provide this parameter explicitly. + #[serde(default, rename = "_version")] + pub version: StaticCodecVersion<1, 0, 0>, +} + +impl Codec for OnionCodec { + type Error = OnionCodecError; + + fn encode(&self, data: AnyCowArray) -> Result { + self.codec + .encode(data) + .map_err(|err| OnionCodecError { error: err }) + } + + fn decode(&self, encoded: AnyCowArray) -> Result { + self.codec + .decode(encoded) + .map_err(|err| OnionCodecError { error: err }) + } + + fn decode_into( + &self, + encoded: AnyArrayView, + decoded: AnyArrayViewMut, + ) -> Result<(), Self::Error> { + self.codec + .decode_into(encoded, decoded) + .map_err(|err| OnionCodecError { error: err }) + } +} + +impl StaticCodec for OnionCodec { + const CODEC_ID: &'static str = "onion.rs"; + + type Config<'de> = Self; + + fn from_config(config: Self::Config<'_>) -> Self { + config + } + + fn get_config(&self) -> StaticCodecConfig<'_, Self> { + StaticCodecConfig::from(self) + } +} + +#[derive(Debug, Error)] +/// Errors that may occur when applying the [`OnionCodec`]. +#[error(transparent)] +pub struct OnionCodecError { + error: ErasedError, +} diff --git a/codecs/onion/tests/schema.json b/codecs/onion/tests/schema.json new file mode 100644 index 000000000..40cad458a --- /dev/null +++ b/codecs/onion/tests/schema.json @@ -0,0 +1,34 @@ +{ + "type": "object", + "additionalProperties": false, + "properties": { + "codec": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The `codec_id` of the codec, which is looked up in the global\nregistry." + } + }, + "required": [ + "id" + ], + "description": "The configuration of the wrapped codec.", + "additionalProperties": { + "type": "object" + } + }, + "_version": { + "type": "string", + "pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$", + "description": "The codec's encoding format version. Do not provide this parameter explicitly.", + "default": "1.0.0" + } + }, + "required": [ + "codec" + ], + "description": "Onion identity meta-codec which wraps an existing codec and passes the\ninputs to and outputs from it unchanged during encoding and decoding.", + "title": "OnionCodec", + "$schema": "https://json-schema.org/draft/2020-12/schema" +} \ No newline at end of file diff --git a/codecs/onion/tests/schema.rs b/codecs/onion/tests/schema.rs new file mode 100644 index 000000000..02d6c09e2 --- /dev/null +++ b/codecs/onion/tests/schema.rs @@ -0,0 +1,20 @@ +#![expect(missing_docs)] + +use ::{numcodecs_registry as _, schemars as _, serde as _, thiserror as _}; + +use numcodecs::{DynCodecType, StaticCodecType}; +use numcodecs_onion::OnionCodec; + +#[test] +fn schema() { + let schema = format!( + "{:#}", + StaticCodecType::::of() + .codec_config_schema() + .to_value() + ); + + if schema != include_str!("schema.json") { + panic!("Onion schema has changed\n===\n{schema}\n==="); + } +} diff --git a/crates/numcodecs-python/Cargo.toml b/crates/numcodecs-python/Cargo.toml index c8ed8daf5..a2d431ee1 100644 --- a/crates/numcodecs-python/Cargo.toml +++ b/crates/numcodecs-python/Cargo.toml @@ -18,6 +18,7 @@ keywords = ["numcodecs", "compression", "encoding", "python", "pyo3"] convert_case = { workspace = true } ndarray = { workspace = true } numcodecs = { workspace = true } +numcodecs-registry = { workspace = true } numpy = { workspace = true } pyo3 = { workspace = true } pyo3-error = { workspace = true } diff --git a/crates/numcodecs-python/src/registry.rs b/crates/numcodecs-python/src/registry.rs index e3f121ad3..9adbaa6c3 100644 --- a/crates/numcodecs-python/src/registry.rs +++ b/crates/numcodecs-python/src/registry.rs @@ -1,13 +1,16 @@ +use numcodecs::{DynCodec, ErasedDynCodec}; +use numcodecs_registry::Registry; use pyo3::{prelude::*, sync::PyOnceLock, types::PyDict}; +use pythonize::Pythonizer; +use serde::Deserializer; +use serde_transcode::transcode; #[expect(unused_imports)] // FIXME: use expect, only used in docs use crate::PyCodecClassMethods; -use crate::{PyCodec, PyCodecClass}; +use crate::{PyCodec, PyCodecAdapter, PyCodecClass}; /// Dynamic registry of codec classes. -pub struct PyCodecRegistry { - _private: (), -} +pub struct PyCodecRegistry; impl PyCodecRegistry { /// Instantiate a codec from a configuration dictionary. @@ -56,3 +59,38 @@ impl PyCodecRegistry { Ok(()) } } + +impl Registry for PyCodecRegistry { + type Error = PyErr; + + fn get_codec<'de, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result { + Python::attach(|py| { + let config = transcode(config, Pythonizer::new(py))?; + let config: Bound = config.extract()?; + + let codec = Self::get_codec(config.as_borrowed())?; + let codec = PyCodecAdapter::from_codec(codec)?; + + Ok(ErasedDynCodec::new(codec)) + }) + } + + fn get_codec_typed<'de, T: DynCodec, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result, Self::Error> { + Python::attach(|py| { + let config = transcode(config, Pythonizer::new(py))?; + let config: Bound = config.extract()?; + + let codec = Self::get_codec(config.as_borrowed())?; + // clone is necessary since we cannot move out of a PyCodec + let codec = PyCodecAdapter::with_downcast(py, &codec, |codec: &T| codec.clone()); + + Ok(codec) + }) + } +} diff --git a/crates/numcodecs-python/tests/schema.rs b/crates/numcodecs-python/tests/schema.rs index 5bc4a7493..4e7abd391 100644 --- a/crates/numcodecs-python/tests/schema.rs +++ b/crates/numcodecs-python/tests/schema.rs @@ -23,7 +23,7 @@ fn collect_schemas() -> Result<(), PyErr> { println!( "{codec_id}: {:#}", - codec_ty.codec_config_schema().as_value() + codec_ty.codec_config_schema().to_value() ); } diff --git a/crates/numcodecs-registry/Cargo.toml b/crates/numcodecs-registry/Cargo.toml new file mode 100644 index 000000000..c55c037bb --- /dev/null +++ b/crates/numcodecs-registry/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "numcodecs-registry" +version = "0.1.0" +edition = { workspace = true } +authors = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +rust-version = { workspace = true } + +description = "registries for numcodecs codecs" +readme = "README.md" +categories = ["compression", "encoding"] +keywords = ["numcodecs", "registry", "compression", "encoding"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +erased-serde = { workspace = true, features = ["std"] } +numcodecs = { workspace = true } +schemars = { workspace = true, features = ["derive"] } +serde = { workspace = true } +thiserror = { workspace = true } + +[lints] +workspace = true diff --git a/crates/numcodecs-registry/LICENSE b/crates/numcodecs-registry/LICENSE new file mode 120000 index 000000000..30cff7403 --- /dev/null +++ b/crates/numcodecs-registry/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/numcodecs-registry/README.md b/crates/numcodecs-registry/README.md new file mode 100644 index 000000000..c0ada0e18 --- /dev/null +++ b/crates/numcodecs-registry/README.md @@ -0,0 +1,32 @@ +[![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs] + +[CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main +[workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain + +[MSRV]: https://img.shields.io/badge/MSRV-1.87.0-blue +[repo]: https://github.com/juntyr/numcodecs-rs + +[Latest Version]: https://img.shields.io/crates/v/numcodecs-registry +[crates.io]: https://crates.io/crates/numcodecs-registry + +[Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-registry +[docs.rs]: https://docs.rs/numcodecs-registry/ + +[Rust Doc Main]: https://img.shields.io/badge/docs-main-blue +[docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_registry + +# numcodecs-registry + +Registries for compression codecs implementing the [`numcodecs`] API. + +[`numcodecs`]: https://numcodecs.readthedocs.io/en/stable/ + +## License + +Licensed under the Mozilla Public License, Version 2.0 ([LICENSE](LICENSE) or https://www.mozilla.org/en-US/MPL/2.0/). + +## Funding + +The `numcodecs-registry` crate has been developed as part of [ESiWACE3](https://www.esiwace.eu), the third phase of the Centre of Excellence in Simulation of Weather and Climate in Europe. + +Funded by the European Union. This work has received funding from the European High Performance Computing Joint Undertaking (JU) under grant agreement No 101093054. diff --git a/crates/numcodecs-registry/src/lib.rs b/crates/numcodecs-registry/src/lib.rs new file mode 100644 index 000000000..40d5b8f27 --- /dev/null +++ b/crates/numcodecs-registry/src/lib.rs @@ -0,0 +1,320 @@ +//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs] +//! +//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main +//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain +//! +//! [MSRV]: https://img.shields.io/badge/MSRV-1.87.0-blue +//! [repo]: https://github.com/juntyr/numcodecs-rs +//! +//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-registry +//! [crates.io]: https://crates.io/crates/numcodecs-registry +//! +//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-registry +//! [docs.rs]: https://docs.rs/numcodecs-registry/ +//! +//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue +//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_registry +//! +//! Registries for compression codecs implementing the [`numcodecs`] API. +//! +//! [`numcodecs`]: https://numcodecs.readthedocs.io/en/stable/ + +use std::{ + borrow::Cow, + error::Error, + ops::{Deref, DerefMut}, + sync::Arc, +}; + +use numcodecs::{DynCodec, ErasedDynCodec, ErasedError}; +use schemars::{JsonSchema, Schema, SchemaGenerator}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// Registry of codec types. +pub trait Registry: 'static + Send + Sync { + /// Error type that may be returned during + /// [`get_codec`][`Registry::get_codec`] and + /// and [`register_codec`][`Codec::Registry`]. + type Error: 'static + Send + Sync + Error; + + /// Instantiate a codec of any type from its `config`uration. + /// + /// The config *must* include the `id` field with the + /// [`DynCodecType::codec_id`]. + /// + /// # Errors + /// + /// Errors if no codec with a matching `id` has been registered, or if + /// constructing the codec fails. + fn get_codec<'de, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result; + + /// Instantiate a codec with a concrete type from its `config`uration. + /// + /// The config *must* include the `id` field with the + /// [`DynCodecType::codec_id`]. + /// + /// # Errors + /// + /// Errors if no codec with a matching `id` has been registered, if + /// constructing the codec fails, or if the constructed codec is not of the + /// concrete type. + fn get_codec_typed<'de, T: DynCodec, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result, Self::Error> { + self.get_codec(config).map(|codec| codec.downcast().ok()) + } +} + +impl Registry for Box { + type Error = R::Error; + + fn get_codec<'de, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result { + R::get_codec(self, config) + } + + fn get_codec_typed<'de, T: DynCodec, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result, Self::Error> { + R::get_codec_typed(self, config) + } +} + +impl Registry for Arc { + type Error = R::Error; + + fn get_codec<'de, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result { + R::get_codec(self, config) + } + + fn get_codec_typed<'de, T: DynCodec, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result, Self::Error> { + R::get_codec_typed(self, config) + } +} + +/// Type-erased [`Registry`]. +pub struct ErasedRegistry { + registry: Box, +} + +impl ErasedRegistry { + /// Erase the type information of the concrete `registry`. + pub fn new(registry: T) -> Self { + Self { + registry: Box::new(registry), + } + } +} + +impl Registry for ErasedRegistry { + type Error = ErasedError; + + fn get_codec<'de, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result { + self.registry + .erased_get_codec(&mut ::erase(config)) + } +} + +trait ErasedRegistryDispatch: 'static + Send + Sync { + fn erased_get_codec( + &self, + config: &mut dyn erased_serde::Deserializer, + ) -> Result; +} + +impl ErasedRegistryDispatch for T { + fn erased_get_codec( + &self, + config: &mut dyn erased_serde::Deserializer, + ) -> Result { + match self.get_codec(config) { + Ok(codec) => Ok(codec), + Err(err) => Err(ErasedError::new(err)), + } + } +} + +/// Global registry singleton. +/// +/// If the global registry is used, its backing registry must be provided +/// exactly once using [`export_global`]. +/// +/// The global registry must not be used to provide the backing of itself, +/// which would result in an infinite loop at runtime. +pub struct GlobalRegistry; + +impl GlobalRegistry { + fn get() -> &'static ErasedRegistry { + #[expect(unsafe_code)] + unsafe extern "C" { + #[expect(improper_ctypes)] + safe fn _numcodecs_registry_get_global_registry() -> &'static ErasedRegistry; + } + + _numcodecs_registry_get_global_registry() + } +} + +impl Registry for GlobalRegistry { + type Error = ErasedError; + + fn get_codec<'de, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result { + Self::get().get_codec(config) + } + + fn get_codec_typed<'de, T: DynCodec, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result, Self::Error> { + Self::get().get_codec_typed(config) + } +} + +#[macro_export] +/// `export_global!(registry: ty = expr)` exports the provided registry as the +/// global registry singleton. +/// +/// This macro must only be used at most once in every binary or shared +/// library. +macro_rules! export_global { + (registry: $ty:ty = $init:expr) => { + const _: () = { + use std::sync::LazyLock; + + use $crate::ErasedRegistry; + + static _GLOBAL_REGISTRY: LazyLock = + LazyLock::new(|| ErasedRegistry::new($init)); + + #[allow(improper_ctypes, unsafe_code)] + #[unsafe(no_mangle)] + extern "C" fn _numcodecs_registry_get_global_registry() -> &'static ErasedRegistry { + LazyLock::force(&_GLOBAL_REGISTRY) + } + }; + }; +} + +#[derive(Debug, thiserror::Error)] +#[error("codec not found")] +/// Codec was not found in the registry +pub struct CodecNotFoundError; + +/// Empty registry that contains no codecs +pub struct EmptyRegistry; + +impl Registry for EmptyRegistry { + type Error = CodecNotFoundError; + + fn get_codec<'de, D: Deserializer<'de>>( + &self, + _config: D, + ) -> Result { + Err(CodecNotFoundError) + } + + fn get_codec_typed<'de, T: DynCodec, D: Deserializer<'de>>( + &self, + _config: D, + ) -> Result, Self::Error> { + Err(CodecNotFoundError) + } +} + +#[derive(Clone)] +/// Wrapper around an [`ErasedDynCodec`] that can be used inside a meta-codec +/// configuration to (de)serialize a wrapped inner codec. +pub struct GlobalErasedDynCodec { + codec: ErasedDynCodec, +} + +impl GlobalErasedDynCodec { + #[must_use] + /// Wrap an existing `codec`. + pub const fn new(codec: ErasedDynCodec) -> Self { + Self { codec } + } + + #[must_use] + /// Extract the inner codec. + pub fn into_inner(this: Self) -> ErasedDynCodec { + this.codec + } +} + +impl Deref for GlobalErasedDynCodec { + type Target = ErasedDynCodec; + + fn deref(&self) -> &Self::Target { + &self.codec + } +} + +impl DerefMut for GlobalErasedDynCodec { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.codec + } +} + +impl Serialize for GlobalErasedDynCodec { + fn serialize(&self, serializer: S) -> Result { + self.codec.get_config(serializer) + } +} + +impl<'de> Deserialize<'de> for GlobalErasedDynCodec { + fn deserialize>(deserializer: D) -> Result { + Ok(Self { + codec: GlobalRegistry + .get_codec(deserializer) + .map_err(serde::de::Error::custom)?, + }) + } +} + +impl JsonSchema for GlobalErasedDynCodec { + fn inline_schema() -> bool { + false + } + + fn schema_name() -> Cow<'static, str> { + Cow::Borrowed("NumcodecsCodecConfig") + } + + fn schema_id() -> Cow<'static, str> { + Cow::Borrowed(concat!(module_path!(), "::", "GlobalErasedDynCodec")) + } + + fn json_schema(generator: &mut SchemaGenerator) -> Schema { + #[derive(JsonSchema)] + #[schemars(extend("additionalProperties" = {"type": "object"}))] + /// The configuration for a codec. + struct NumcodecsCodecConfig { + /// The `codec_id` of the codec, which is looked up in the global + /// registry. + #[expect(dead_code)] + id: String, + } + + NumcodecsCodecConfig::json_schema(generator) + } +} diff --git a/crates/numcodecs-wasm-guest/Cargo.toml b/crates/numcodecs-wasm-guest/Cargo.toml index 31df535e3..7582f84d0 100644 --- a/crates/numcodecs-wasm-guest/Cargo.toml +++ b/crates/numcodecs-wasm-guest/Cargo.toml @@ -19,8 +19,10 @@ wit-bindgen = { workspace = true, features = ["macros", "realloc"] } [target.'cfg(target_arch = "wasm32")'.dependencies] format_serde_error = { workspace = true, features = ["serde_json"] } ndarray = { workspace = true, features = ["std"] } +numcodecs-registry = { workspace = true } schemars = { workspace = true } serde = { workspace = true } +serde-transcode = { workspace = true } serde_json = { workspace = true, features = ["std"] } thiserror = { workspace = true } diff --git a/crates/numcodecs-wasm-guest/src/convert.rs b/crates/numcodecs-wasm-guest/src/convert.rs index 08a2440a4..a845cf46c 100644 --- a/crates/numcodecs-wasm-guest/src/convert.rs +++ b/crates/numcodecs-wasm-guest/src/convert.rs @@ -6,43 +6,47 @@ use thiserror::Error; use crate::wit; -pub fn from_wit_any_array(array: wit::AnyArray) -> Result { +pub fn from_wit_any_array( + array: wit::types::AnyArray, +) -> Result { let shape = u32_as_usize_vec(array.shape); let array = match array.data { - wit::AnyArrayData::U8(data) => AnyArray::U8(Array::from_shape_vec(shape, data)?), - wit::AnyArrayData::U16(data) => AnyArray::U16(Array::from_shape_vec(shape, data)?), - wit::AnyArrayData::U32(data) => AnyArray::U32(Array::from_shape_vec(shape, data)?), - wit::AnyArrayData::U64(data) => AnyArray::U64(Array::from_shape_vec(shape, data)?), - wit::AnyArrayData::I8(data) => AnyArray::I8(Array::from_shape_vec(shape, data)?), - wit::AnyArrayData::I16(data) => AnyArray::I16(Array::from_shape_vec(shape, data)?), - wit::AnyArrayData::I32(data) => AnyArray::I32(Array::from_shape_vec(shape, data)?), - wit::AnyArrayData::I64(data) => AnyArray::I64(Array::from_shape_vec(shape, data)?), - wit::AnyArrayData::F32(data) => AnyArray::F32(Array::from_shape_vec(shape, data)?), - wit::AnyArrayData::F64(data) => AnyArray::F64(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::U8(data) => AnyArray::U8(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::U16(data) => AnyArray::U16(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::U32(data) => AnyArray::U32(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::U64(data) => AnyArray::U64(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::I8(data) => AnyArray::I8(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::I16(data) => AnyArray::I16(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::I32(data) => AnyArray::I32(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::I64(data) => AnyArray::I64(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::F32(data) => AnyArray::F32(Array::from_shape_vec(shape, data)?), + wit::types::AnyArrayData::F64(data) => AnyArray::F64(Array::from_shape_vec(shape, data)?), }; Ok(array) } -pub fn zeros_from_wit_any_array_prototype(prototype: wit::AnyArrayPrototype) -> AnyArray { +pub fn zeros_from_wit_any_array_prototype(prototype: wit::types::AnyArrayPrototype) -> AnyArray { let shape = u32_as_usize_vec(prototype.shape); match prototype.dtype { - wit::AnyArrayDtype::U8 => AnyArray::U8(Array::zeros(shape)), - wit::AnyArrayDtype::U16 => AnyArray::U16(Array::zeros(shape)), - wit::AnyArrayDtype::U32 => AnyArray::U32(Array::zeros(shape)), - wit::AnyArrayDtype::U64 => AnyArray::U64(Array::zeros(shape)), - wit::AnyArrayDtype::I8 => AnyArray::I8(Array::zeros(shape)), - wit::AnyArrayDtype::I16 => AnyArray::I16(Array::zeros(shape)), - wit::AnyArrayDtype::I32 => AnyArray::I32(Array::zeros(shape)), - wit::AnyArrayDtype::I64 => AnyArray::I64(Array::zeros(shape)), - wit::AnyArrayDtype::F32 => AnyArray::F32(Array::zeros(shape)), - wit::AnyArrayDtype::F64 => AnyArray::F64(Array::zeros(shape)), + wit::types::AnyArrayDtype::U8 => AnyArray::U8(Array::zeros(shape)), + wit::types::AnyArrayDtype::U16 => AnyArray::U16(Array::zeros(shape)), + wit::types::AnyArrayDtype::U32 => AnyArray::U32(Array::zeros(shape)), + wit::types::AnyArrayDtype::U64 => AnyArray::U64(Array::zeros(shape)), + wit::types::AnyArrayDtype::I8 => AnyArray::I8(Array::zeros(shape)), + wit::types::AnyArrayDtype::I16 => AnyArray::I16(Array::zeros(shape)), + wit::types::AnyArrayDtype::I32 => AnyArray::I32(Array::zeros(shape)), + wit::types::AnyArrayDtype::I64 => AnyArray::I64(Array::zeros(shape)), + wit::types::AnyArrayDtype::F32 => AnyArray::F32(Array::zeros(shape)), + wit::types::AnyArrayDtype::F64 => AnyArray::F64(Array::zeros(shape)), } } -pub fn into_wit_any_array(array: AnyArray) -> Result { +pub fn into_wit_any_array( + array: AnyArray, +) -> Result { fn array_into_standard_layout_vec(array: ArrayD) -> Vec { if array.is_standard_layout() { array.into_raw_vec_and_offset().0 @@ -54,16 +58,32 @@ pub fn into_wit_any_array(array: AnyArray) -> Result wit::AnyArrayData::U8(array_into_standard_layout_vec(array)), - AnyArray::U16(array) => wit::AnyArrayData::U16(array_into_standard_layout_vec(array)), - AnyArray::U32(array) => wit::AnyArrayData::U32(array_into_standard_layout_vec(array)), - AnyArray::U64(array) => wit::AnyArrayData::U64(array_into_standard_layout_vec(array)), - AnyArray::I8(array) => wit::AnyArrayData::I8(array_into_standard_layout_vec(array)), - AnyArray::I16(array) => wit::AnyArrayData::I16(array_into_standard_layout_vec(array)), - AnyArray::I32(array) => wit::AnyArrayData::I32(array_into_standard_layout_vec(array)), - AnyArray::I64(array) => wit::AnyArrayData::I64(array_into_standard_layout_vec(array)), - AnyArray::F32(array) => wit::AnyArrayData::F32(array_into_standard_layout_vec(array)), - AnyArray::F64(array) => wit::AnyArrayData::F64(array_into_standard_layout_vec(array)), + AnyArray::U8(array) => wit::types::AnyArrayData::U8(array_into_standard_layout_vec(array)), + AnyArray::U16(array) => { + wit::types::AnyArrayData::U16(array_into_standard_layout_vec(array)) + } + AnyArray::U32(array) => { + wit::types::AnyArrayData::U32(array_into_standard_layout_vec(array)) + } + AnyArray::U64(array) => { + wit::types::AnyArrayData::U64(array_into_standard_layout_vec(array)) + } + AnyArray::I8(array) => wit::types::AnyArrayData::I8(array_into_standard_layout_vec(array)), + AnyArray::I16(array) => { + wit::types::AnyArrayData::I16(array_into_standard_layout_vec(array)) + } + AnyArray::I32(array) => { + wit::types::AnyArrayData::I32(array_into_standard_layout_vec(array)) + } + AnyArray::I64(array) => { + wit::types::AnyArrayData::I64(array_into_standard_layout_vec(array)) + } + AnyArray::F32(array) => { + wit::types::AnyArrayData::F32(array_into_standard_layout_vec(array)) + } + AnyArray::F64(array) => { + wit::types::AnyArrayData::F64(array_into_standard_layout_vec(array)) + } array => { return Err(AnyArrayConversionError::UnsupportedDtype { dtype: array.dtype(), @@ -71,7 +91,29 @@ pub fn into_wit_any_array(array: AnyArray) -> Result Result { + let dtype = match dtype { + AnyArrayDType::U8 => wit::types::AnyArrayDtype::U8, + AnyArrayDType::U16 => wit::types::AnyArrayDtype::U16, + AnyArrayDType::U32 => wit::types::AnyArrayDtype::U32, + AnyArrayDType::U64 => wit::types::AnyArrayDtype::U64, + AnyArrayDType::I8 => wit::types::AnyArrayDtype::I8, + AnyArrayDType::I16 => wit::types::AnyArrayDtype::I16, + AnyArrayDType::I32 => wit::types::AnyArrayDtype::I32, + AnyArrayDType::I64 => wit::types::AnyArrayDtype::I64, + AnyArrayDType::F32 => wit::types::AnyArrayDtype::F32, + AnyArrayDType::F64 => wit::types::AnyArrayDtype::F64, + dtype => { + return Err(AnyArrayConversionError::UnsupportedDtype { dtype }); + } + }; + + Ok(dtype) } #[derive(Debug, Error)] @@ -86,10 +128,10 @@ pub enum AnyArrayConversionError { } #[must_use] -pub fn into_wit_error(err: T) -> wit::Error { +pub fn into_wit_error(err: T) -> wit::types::Error { let mut source: Option<&dyn Error> = err.source(); - let mut error = wit::Error { + let mut error = wit::types::Error { message: format!("{err}"), chain: if source.is_some() { Vec::with_capacity(4) @@ -108,11 +150,11 @@ pub fn into_wit_error(err: T) -> wit::Error { #[expect(clippy::cast_possible_truncation)] #[must_use] -fn usize_as_u32_slice(slice: &[usize]) -> Vec { +pub(crate) fn usize_as_u32_slice(slice: &[usize]) -> Vec { slice.iter().map(|x| *x as u32).collect() } #[must_use] -fn u32_as_usize_vec(vec: Vec) -> Vec { +pub(crate) fn u32_as_usize_vec(vec: Vec) -> Vec { vec.into_iter().map(|x| x as usize).collect() } diff --git a/crates/numcodecs-wasm-guest/src/external.rs b/crates/numcodecs-wasm-guest/src/external.rs new file mode 100644 index 000000000..033c9ef6d --- /dev/null +++ b/crates/numcodecs-wasm-guest/src/external.rs @@ -0,0 +1,203 @@ +use std::sync::Arc; + +use numcodecs::{ + self, AnyArray, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec, DynCodecType, + ErasedDynCodec, +}; +use numcodecs_registry::{self, Registry, export_global}; +use schemars::Schema; +use serde::{self, Deserializer, Serializer}; +use serde_transcode::transcode; + +use crate::{convert, wit}; + +pub struct ExternalCodec { + codec: wit::registry::ErasedDynCodec, + ty: ExternalCodecType, +} + +impl Clone for ExternalCodec { + fn clone(&self) -> Self { + Self { + codec: self.codec.clone(), + ty: ExternalCodecType { + ty: self.ty.ty.clone(), + codec_id: self.ty.codec_id.clone(), + schema: self.ty.schema.clone(), + }, + } + } +} + +impl Codec for ExternalCodec { + type Error = ExternalError; + + fn encode(&self, data: AnyCowArray) -> Result { + match self.codec.encode( + &convert::into_wit_any_array(data.into_owned()).map_err(ExternalError::from_error)?, + ) { + Ok(encoded) => convert::from_wit_any_array(encoded).map_err(ExternalError::from_error), + Err(err) => Err(ExternalError::new(err)), + } + } + + fn decode(&self, encoded: AnyCowArray) -> Result { + match self.codec.decode( + &convert::into_wit_any_array(encoded.into_owned()) + .map_err(ExternalError::from_error)?, + ) { + Ok(decoded) => convert::from_wit_any_array(decoded).map_err(ExternalError::from_error), + Err(err) => Err(ExternalError::new(err)), + } + } + + fn decode_into( + &self, + encoded: AnyArrayView, + mut decoded: AnyArrayViewMut, + ) -> Result<(), Self::Error> { + match self.codec.decode_into( + &convert::into_wit_any_array(encoded.into_owned()) + .map_err(ExternalError::from_error)?, + &wit::types::AnyArrayPrototype { + dtype: convert::into_wit_any_array_dtype(decoded.dtype()) + .map_err(ExternalError::from_error)?, + shape: convert::usize_as_u32_slice(decoded.shape()), + }, + ) { + Ok(dec) => match convert::from_wit_any_array(dec) { + Ok(dec) => decoded.assign(&dec).map_err(ExternalError::from_error), + Err(err) => Err(ExternalError::from_error(err)), + }, + Err(err) => Err(ExternalError::new(err)), + } + } +} + +impl DynCodec for ExternalCodec { + type Type = ExternalCodecType; + + fn get_config(&self, serializer: S) -> Result { + let config = self + .codec + .get_config() + .map_err(ExternalError::new) + .map_err(serde::ser::Error::custom)?; + transcode(&mut serde_json::Deserializer::from_str(&config), serializer) + } + + fn ty(&self) -> Self::Type { + ExternalCodecType { + ty: self.ty.ty.clone(), + codec_id: self.ty.codec_id.clone(), + schema: self.ty.schema.clone(), + } + } +} + +pub struct ExternalCodecType { + ty: Arc, + codec_id: Arc, + schema: Arc, +} + +impl DynCodecType for ExternalCodecType { + type Codec = ExternalCodec; + + fn codec_id(&self) -> &str { + &*self.codec_id + } + + fn codec_config_schema(&self) -> Schema { + (*self.schema).clone() + } + + fn codec_from_config<'de, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result { + let mut config_bytes = Vec::new(); + transcode(config, &mut serde_json::Serializer::new(&mut config_bytes)) + .map_err(serde::de::Error::custom)?; + let config = String::from_utf8(config_bytes).map_err(serde::de::Error::custom)?; + + let codec = self + .ty + .codec_from_config(&config) + .map_err(ExternalError::new) + .map_err(serde::de::Error::custom)?; + + Ok(ExternalCodec { + codec, + ty: ExternalCodecType { + ty: self.ty.clone(), + codec_id: self.codec_id.clone(), + schema: self.schema.clone(), + }, + }) + } +} + +#[derive(Debug, thiserror::Error)] +#[error("{msg}")] +pub struct ExternalError { + msg: String, + source: Option>, +} + +impl ExternalError { + fn new(error: wit::types::Error) -> Self { + let mut root = Self { + msg: error.message, + source: None, + }; + + let mut err = &mut root; + + for msg in error.chain { + err = &mut *err.source.insert(Box::new(Self { msg, source: None })); + } + + root + } + + fn from_error(err: impl std::error::Error) -> Self { + Self::new(convert::into_wit_error(err)) + } +} + +pub struct ExternalRegistry; + +impl Registry for ExternalRegistry { + type Error = ExternalError; + + fn get_codec<'de, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result { + let mut config_bytes = Vec::new(); + transcode(config, &mut serde_json::Serializer::new(&mut config_bytes)) + .map_err(ExternalError::from_error)?; + let config = String::from_utf8(config_bytes).map_err(ExternalError::from_error)?; + + let codec = wit::registry::get_codec(&config).map_err(ExternalError::new)?; + let ty = codec.ty(); + + let codec_id = ty.codec_id(); + let schema: Schema = + serde_json::from_str(&ty.codec_config_schema()).map_err(ExternalError::from_error)?; + + let codec = ExternalCodec { + codec, + ty: ExternalCodecType { + ty: std::sync::Arc::new(ty), + codec_id: codec_id.into(), + schema: std::sync::Arc::new(schema), + }, + }; + + Ok(ErasedDynCodec::new(codec)) + } +} + +export_global! { registry: ExternalRegistry = ExternalRegistry } diff --git a/crates/numcodecs-wasm-guest/src/lib.rs b/crates/numcodecs-wasm-guest/src/lib.rs index 3462f2fa6..89f1f9fbe 100644 --- a/crates/numcodecs-wasm-guest/src/lib.rs +++ b/crates/numcodecs-wasm-guest/src/lib.rs @@ -34,13 +34,12 @@ use ::{ #[cfg(target_arch = "wasm32")] mod convert; +#[cfg(target_arch = "wasm32")] +mod external; #[cfg(target_arch = "wasm32")] -use crate::{ - bindings::exports::numcodecs::abc::codec as wit, - convert::{ - from_wit_any_array, into_wit_any_array, into_wit_error, zeros_from_wit_any_array_prototype, - }, +use crate::convert::{ + from_wit_any_array, into_wit_any_array, into_wit_error, zeros_from_wit_any_array_prototype, }; #[doc(hidden)] @@ -55,6 +54,14 @@ pub mod bindings { }); } +#[cfg(target_arch = "wasm32")] +mod wit { + pub use crate::bindings::{ + exports::numcodecs::abc::codec, + numcodecs::abc::{registry, types}, + }; +} + #[macro_export] /// Export a [`StaticCodec`] type using the WASM component model. /// @@ -96,27 +103,31 @@ macro_rules! export_codec { #[cfg(target_arch = "wasm32")] #[doc(hidden)] -impl wit::Guest for T { +impl wit::codec::Guest for T { type Codec = Self; fn codec_id() -> String { String::from(::CODEC_ID) } - fn codec_config_schema() -> wit::JsonSchema { + fn codec_config_schema() -> wit::types::JsonSchema { schema_for!(::Config<'static>) - .as_value() + .to_value() .to_string() } } #[cfg(target_arch = "wasm32")] -impl wit::GuestCodec for T { - fn from_config(config: String) -> Result { +impl wit::codec::GuestCodec for T { + fn from_config(config: String) -> Result { let err = match ::Config::deserialize( &mut serde_json::Deserializer::from_str(&config), ) { - Ok(config) => return Ok(wit::Codec::new(::from_config(config))), + Ok(config) => { + return Ok(wit::codec::Codec::new(::from_config( + config, + ))); + } Err(err) => err, }; @@ -124,7 +135,10 @@ impl wit::GuestCodec for T { Err(into_wit_error(err)) } - fn encode(&self, data: wit::AnyArray) -> Result { + fn encode( + &self, + data: wit::types::AnyArray, + ) -> Result { let data = match from_wit_any_array(data) { Ok(data) => data, Err(err) => return Err(into_wit_error(err)), @@ -139,7 +153,10 @@ impl wit::GuestCodec for T { } } - fn decode(&self, encoded: wit::AnyArray) -> Result { + fn decode( + &self, + encoded: wit::types::AnyArray, + ) -> Result { let encoded = match from_wit_any_array(encoded) { Ok(encoded) => encoded, Err(err) => return Err(into_wit_error(err)), @@ -156,9 +173,9 @@ impl wit::GuestCodec for T { fn decode_into( &self, - encoded: wit::AnyArray, - decoded: wit::AnyArrayPrototype, - ) -> Result { + encoded: wit::types::AnyArray, + decoded: wit::types::AnyArrayPrototype, + ) -> Result { let encoded = match from_wit_any_array(encoded) { Ok(encoded) => encoded, Err(err) => return Err(into_wit_error(err)), @@ -175,7 +192,7 @@ impl wit::GuestCodec for T { } } - fn get_config(&self) -> Result { + fn get_config(&self) -> Result { match serde_json::to_string(&::get_config(self)) { Ok(config) => Ok(config), Err(err) => Err(into_wit_error(err)), diff --git a/crates/numcodecs-wasm-host-reproducible/Cargo.toml b/crates/numcodecs-wasm-host-reproducible/Cargo.toml index 75af6bd19..830e0f59e 100644 --- a/crates/numcodecs-wasm-host-reproducible/Cargo.toml +++ b/crates/numcodecs-wasm-host-reproducible/Cargo.toml @@ -26,6 +26,7 @@ numcodecs-wasm-host = { workspace = true } indexmap = { workspace = true, features = ["std"] } log = { workspace = true } numcodecs = { workspace = true } +numcodecs-registry = { workspace = true } polonius-the-crab = { workspace = true } schemars = { workspace = true } semver = { workspace = true } diff --git a/crates/numcodecs-wasm-host-reproducible/src/codec.rs b/crates/numcodecs-wasm-host-reproducible/src/codec.rs index 1328c22bd..9b4747a28 100644 --- a/crates/numcodecs-wasm-host-reproducible/src/codec.rs +++ b/crates/numcodecs-wasm-host-reproducible/src/codec.rs @@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex}; use numcodecs::{ AnyArray, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec, DynCodecType, }; +use numcodecs_registry::Registry; use numcodecs_wasm_host::{CodecError, RuntimeError, WasmCodec, WasmCodecComponent}; use schemars::Schema; use serde::Serializer; @@ -323,6 +324,7 @@ where pub fn new( engine: E, wasm_component: impl Into>, + registry: impl Registry, ) -> Result where E: Send + Sync, @@ -343,6 +345,8 @@ where } })?; + let registry = Arc::new(registry); + let component_instantiater = Arc::new(move |component: &Component, codec_id: &str| { let mut store = Store::new(&engine, ()); @@ -359,6 +363,11 @@ where source: RuntimeError::from(err), } })?; + numcodecs_wasm_host::add_registry_to_linker(&mut linker, &mut store, registry.clone()) + .map_err(|err| ReproducibleWasmCodecError::Runtime { + codec_id: Arc::from(codec_id), + source: RuntimeError::from(err), + })?; let instance = linker.instantiate(&mut store, component).map_err(|err| { ReproducibleWasmCodecError::Runtime { diff --git a/crates/numcodecs-wasm-host-reproducible/src/tests.rs b/crates/numcodecs-wasm-host-reproducible/src/tests.rs index d22978864..8addebfb9 100644 --- a/crates/numcodecs-wasm-host-reproducible/src/tests.rs +++ b/crates/numcodecs-wasm-host-reproducible/src/tests.rs @@ -2,6 +2,7 @@ use ndarray::Array; use ndarray_rand::RandomExt; use ndarray_rand::rand_distr::Normal; use numcodecs::{Codec, DynCodecType}; +use numcodecs_registry::EmptyRegistry; use crate::ReproducibleWasmCodecType; @@ -53,7 +54,11 @@ fn codec_roundtrip() { let engine = wasmtime_runtime_layer::Engine::new(wasmtime::Engine::new(&config).unwrap()); - let ty = match ReproducibleWasmCodecType::new(engine, include_bytes!("../tests/round.wasm")) { + let ty = match ReproducibleWasmCodecType::new( + engine, + include_bytes!("../tests/round.wasm"), + EmptyRegistry, + ) { Ok(ty) => ty, Err(err) => panic!( "ReproducibleWasmCodecType::new:\n===\n{err}\n===\n{err:?}\n===\n{err:#}\n===\n{err:#?}\n===\n" diff --git a/crates/numcodecs-wasm-host-reproducible/src/transform/mod.rs b/crates/numcodecs-wasm-host-reproducible/src/transform/mod.rs index bcf0a663d..73e8f3c85 100644 --- a/crates/numcodecs-wasm-host-reproducible/src/transform/mod.rs +++ b/crates/numcodecs-wasm-host-reproducible/src/transform/mod.rs @@ -12,7 +12,10 @@ pub mod nan; #[expect(clippy::too_many_lines)] // FIXME pub fn transform_wasm_component(wasm_component: impl Into>) -> Result, Error> { let NumcodecsWitInterfaces { + package, codec: codec_interface, + registry: registry_interface, + types: types_interface, .. } = NumcodecsWitInterfaces::get(); @@ -24,33 +27,37 @@ pub fn transform_wasm_component(wasm_component: impl Into>) -> Result>) -> Result>(); @@ -150,7 +157,7 @@ pub fn transform_wasm_component(wasm_component: impl Into>) -> Result &'static VariantType { + pub(crate) fn any_array_data_ty() -> &'static VariantType { static ANY_ARRAY_DATA_TY: OnceLock = OnceLock::new(); #[expect(clippy::expect_used)] @@ -337,7 +337,7 @@ impl WasmCodec { }) } - fn any_array_ty() -> &'static RecordType { + pub(crate) fn any_array_ty() -> &'static RecordType { static ANY_ARRAY_TY: OnceLock = OnceLock::new(); #[expect(clippy::expect_used)] @@ -359,7 +359,7 @@ impl WasmCodec { } #[expect(clippy::needless_pass_by_value)] - fn array_into_wasm(array: AnyArrayView) -> Result { + pub(crate) fn array_into_wasm(array: AnyArrayView) -> Result { fn list_from_standard_layout<'a, T: 'static + Copy, S: Data, D: Dimension>( array: &'a ArrayBase, ) -> List @@ -451,7 +451,7 @@ impl WasmCodec { .map_err(RuntimeError::from) } - fn any_array_dtype_ty() -> &'static EnumType { + pub(crate) fn any_array_dtype_ty() -> &'static EnumType { static ANY_ARRAY_DTYPE_TY: OnceLock = OnceLock::new(); #[expect(clippy::expect_used)] @@ -468,7 +468,7 @@ impl WasmCodec { }) } - fn any_array_prototype_ty() -> &'static RecordType { + pub(crate) fn any_array_prototype_ty() -> &'static RecordType { static ANY_ARRAY_PROTOTYPE_TY: OnceLock = OnceLock::new(); #[expect(clippy::expect_used)] @@ -486,7 +486,7 @@ impl WasmCodec { }) } - fn array_prototype_into_wasm( + pub(crate) fn array_prototype_into_wasm( dtype: AnyArrayDType, shape: &[usize], ) -> Result { @@ -522,7 +522,7 @@ impl WasmCodec { .map_err(RuntimeError::from) } - fn with_array_view_from_wasm_record( + pub(crate) fn with_array_view_from_wasm_record( record: &Record, with: impl for<'a> FnOnce(AnyArrayView<'a>) -> Result, ) -> Result { @@ -602,4 +602,52 @@ impl WasmCodec { with(array) } + + pub(crate) fn array_prototype_from_wasm_record( + record: &Record, + ) -> Result { + let Some(Value::Variant(dtype)) = record.field("dtype") else { + return Err(RuntimeError::from(anyhow::Error::msg(format!( + "{record:?} is missing dtype field" + )))); + }; + if let Some(ty) = dtype.value() { + return Err(RuntimeError::from(anyhow::Error::msg(format!( + "{record:?} has an invalid dtype variant type {ty:?}" + )))); + } + + let dtype = match dtype.discriminant() { + 0 => AnyArrayDType::U8, + 1 => AnyArrayDType::U16, + 2 => AnyArrayDType::U32, + 3 => AnyArrayDType::U64, + 4 => AnyArrayDType::I8, + 5 => AnyArrayDType::I16, + 6 => AnyArrayDType::I32, + 7 => AnyArrayDType::I64, + 8 => AnyArrayDType::F32, + 9 => AnyArrayDType::F64, + discriminant => { + return Err(RuntimeError::from(anyhow::Error::msg(format!( + "{record:?} has an invalid dtype variant [{discriminant}]" + )))); + } + }; + + let Some(Value::List(shape)) = record.field("shape") else { + return Err(RuntimeError::from(anyhow::Error::msg(format!( + "process result record {record:?} is missing shape field" + )))); + }; + let shape = shape + .typed::()? + .iter() + .copied() + .map(usize::try_from) + .collect::, _>>() + .map_err(anyhow::Error::new)?; + + Ok(AnyArray::zeros(dtype, &shape)) + } } diff --git a/crates/numcodecs-wasm-host/src/lib.rs b/crates/numcodecs-wasm-host/src/lib.rs index fa57892c2..bb0620450 100644 --- a/crates/numcodecs-wasm-host/src/lib.rs +++ b/crates/numcodecs-wasm-host/src/lib.rs @@ -24,9 +24,11 @@ mod codec; mod component; mod error; +mod registry; mod wit; pub use codec::WasmCodec; pub use component::WasmCodecComponent; pub use error::{CodecError, RuntimeError}; +pub use registry::add_registry_to_linker; pub use wit::NumcodecsWitInterfaces; diff --git a/crates/numcodecs-wasm-host/src/registry.rs b/crates/numcodecs-wasm-host/src/registry.rs new file mode 100644 index 000000000..96d73893a --- /dev/null +++ b/crates/numcodecs-wasm-host/src/registry.rs @@ -0,0 +1,562 @@ +use std::{error::Error, sync::Arc}; + +use numcodecs::{Codec, DynCodec, DynCodecType, ErasedDynCodec, ErasedDynCodecType}; +use numcodecs_registry::Registry; +use wasm_component_layer::{ + AsContext, AsContextMut, Func, FuncType, Linker, List, ListType, Record, RecordType, + ResourceOwn, ResourceType, ResultType, ResultValue, TypeIdentifier, Value, ValueType, +}; + +use crate::{WasmCodec, wit::NumcodecsWitInterfaces}; + +/// Adds the `registry` to the `linker` to define the `numcodecs:abc/registry` +/// interface. +/// +/// # Errors +/// +/// Errors if adding the `registry` to the `linker` fails. +#[expect(clippy::too_many_lines)] // FIXME +pub fn add_registry_to_linker( + linker: &mut Linker, + mut ctx: impl AsContextMut, + registry: impl Registry, +) -> Result<(), anyhow::Error> { + let NumcodecsWitInterfaces { + registry: numcodecs_registry_interface, + types: numcodecs_types_interface, + .. + } = NumcodecsWitInterfaces::get(); + + let registry = Arc::new(registry); + + let numcodecs_types_error_record = RecordType::new( + Some(TypeIdentifier::new( + "error", + Some(numcodecs_types_interface.clone()), + )), + [ + ("message", ValueType::String), + ("chain", ValueType::List(ListType::new(ValueType::String))), + ], + )?; + + let numcodecs_registry_instance = + linker.define_instance(numcodecs_registry_interface.clone())?; + + let numcodecs_registry_codec_resource = ResourceType::with_destructor( + ctx.as_context_mut(), + Some(TypeIdentifier::new( + "erased-dyn-codec", + Some(numcodecs_registry_interface.clone()), + )), + |_ctx, codec: ErasedDynCodec| { + std::mem::drop(codec); + Ok(()) + }, + )?; + + numcodecs_registry_instance.define_resource( + "erased-dyn-codec", + numcodecs_registry_codec_resource.clone(), + )?; + + let numcodecs_registry_codec_type_resource = ResourceType::with_destructor( + ctx.as_context_mut(), + Some(TypeIdentifier::new( + "erased-dyn-codec-type", + Some(numcodecs_registry_interface.clone()), + )), + |_ctx, codec_ty: ErasedDynCodecType| { + std::mem::drop(codec_ty); + Ok(()) + }, + )?; + + numcodecs_registry_instance.define_resource( + "erased-dyn-codec-type", + numcodecs_registry_codec_type_resource.clone(), + )?; + + let any_array_record = WasmCodec::any_array_ty().clone(); + + let any_array_result = ResultType::new( + Some(ValueType::Record(any_array_record.clone())), + Some(ValueType::Record(numcodecs_types_error_record.clone())), + ); + + let my_any_array_result = any_array_result.clone(); + let my_numcodecs_types_error_record = numcodecs_types_error_record.clone(); + let codec_encode = Func::new( + ctx.as_context_mut(), + FuncType::new( + [ + ValueType::Borrow(numcodecs_registry_codec_resource.clone()), + ValueType::Record(any_array_record.clone()), + ], + [ValueType::Result(any_array_result.clone())], + ), + move |ctx, args, results| { + let [Value::Borrow(codec), Value::Record(data)] = args else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.encode arguments" + ); + }; + + let [result] = results else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.encode results" + ); + }; + + let encoded = WasmCodec::with_array_view_from_wasm_record(data, |data| { + let ctx = ctx.as_context(); + let codec: &ErasedDynCodec = codec.rep(&ctx)?; + + let encoded = codec.encode(data.cow()).map_err(anyhow::Error::new)?; + Ok(encoded) + }); + + let encoded = match encoded { + Ok(encoded) => Ok(WasmCodec::array_into_wasm(encoded.view())?), + Err(err) => Err(into_wit_error(err, &my_numcodecs_types_error_record)?), + }; + + let res = match encoded { + Ok(encoded) => Ok(Some(Value::Record(encoded))), + Err(err) => Err(Some(Value::Record(err))), + }; + + *result = Value::Result(ResultValue::new(my_any_array_result.clone(), res)?); + + Ok(()) + }, + ); + numcodecs_registry_instance.define_func("[method]erased-dyn-codec.encode", codec_encode)?; + + let my_any_array_result = any_array_result.clone(); + let my_numcodecs_types_error_record = numcodecs_types_error_record.clone(); + let codec_decode = Func::new( + ctx.as_context_mut(), + FuncType::new( + [ + ValueType::Borrow(numcodecs_registry_codec_resource.clone()), + ValueType::Record(any_array_record.clone()), + ], + [ValueType::Result(any_array_result.clone())], + ), + move |ctx, args, results| { + let [Value::Borrow(codec), Value::Record(encoded)] = args else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.decode arguments" + ); + }; + + let [result] = results else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.decode results" + ); + }; + + let decoded = WasmCodec::with_array_view_from_wasm_record(encoded, |encoded| { + let ctx = ctx.as_context(); + let codec: &ErasedDynCodec = codec.rep(&ctx)?; + + let decoded = codec.decode(encoded.cow()).map_err(anyhow::Error::new)?; + Ok(decoded) + }); + + let decoded = match decoded { + Ok(decoded) => Ok(WasmCodec::array_into_wasm(decoded.view())?), + Err(err) => Err(into_wit_error(err, &my_numcodecs_types_error_record)?), + }; + + let res = match decoded { + Ok(decoded) => Ok(Some(Value::Record(decoded))), + Err(err) => Err(Some(Value::Record(err))), + }; + + *result = Value::Result(ResultValue::new(my_any_array_result.clone(), res)?); + + Ok(()) + }, + ); + numcodecs_registry_instance.define_func("[method]erased-dyn-codec.decode", codec_decode)?; + + let any_array_prototype_record = WasmCodec::any_array_prototype_ty().clone(); + + let my_any_array_result = any_array_result.clone(); + let my_numcodecs_types_error_record = numcodecs_types_error_record.clone(); + let codec_decode_into = Func::new( + ctx.as_context_mut(), + FuncType::new( + [ + ValueType::Borrow(numcodecs_registry_codec_resource.clone()), + ValueType::Record(any_array_record), + ValueType::Record(any_array_prototype_record), + ], + [ValueType::Result(any_array_result)], + ), + move |ctx, args, results| { + let [ + Value::Borrow(codec), + Value::Record(encoded), + Value::Record(decoded), + ] = args + else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.decode-into arguments" + ); + }; + + let [result] = results else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.decode-into results" + ); + }; + + let mut decoded = WasmCodec::array_prototype_from_wasm_record(decoded)?; + + let res = WasmCodec::with_array_view_from_wasm_record(encoded, |encoded| { + let ctx = ctx.as_context(); + let codec: &ErasedDynCodec = codec.rep(&ctx)?; + + codec + .decode_into(encoded, decoded.view_mut()) + .map_err(anyhow::Error::new)?; + Ok(()) + }); + + let decoded = match res { + Ok(()) => Ok(WasmCodec::array_into_wasm(decoded.view())?), + Err(err) => Err(into_wit_error(err, &my_numcodecs_types_error_record)?), + }; + + let res = match decoded { + Ok(decoded) => Ok(Some(Value::Record(decoded))), + Err(err) => Err(Some(Value::Record(err))), + }; + + *result = Value::Result(ResultValue::new(my_any_array_result.clone(), res)?); + + Ok(()) + }, + ); + numcodecs_registry_instance + .define_func("[method]erased-dyn-codec.decode-into", codec_decode_into)?; + + let my_numcodecs_registry_codec_resource = numcodecs_registry_codec_resource.clone(); + let codec_clone = Func::new( + ctx.as_context_mut(), + FuncType::new( + [ValueType::Borrow(numcodecs_registry_codec_resource.clone())], + [ValueType::Own(numcodecs_registry_codec_resource.clone())], + ), + move |ctx, args, results| { + let [Value::Borrow(codec)] = args else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.clone arguments" + ); + }; + + let [result] = results else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.clone results" + ); + }; + + let codec = { + let ctx = ctx.as_context(); + let codec: &ErasedDynCodec = codec.rep(&ctx)?; + codec.clone() + }; + + *result = Value::Own(ResourceOwn::new( + ctx, + codec, + my_numcodecs_registry_codec_resource.clone(), + )?); + + Ok(()) + }, + ); + numcodecs_registry_instance.define_func("[method]erased-dyn-codec.clone", codec_clone)?; + + let string_result = ResultType::new( + Some(ValueType::String), + Some(ValueType::Record(numcodecs_types_error_record.clone())), + ); + + let my_numcodecs_types_error_record = numcodecs_types_error_record.clone(); + let codec_get_config = Func::new( + ctx.as_context_mut(), + FuncType::new( + [ValueType::Borrow(numcodecs_registry_codec_resource.clone())], + [ValueType::Result(string_result.clone())], + ), + move |ctx, args, results| { + let [Value::Borrow(codec)] = args else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.get-config arguments" + ); + }; + + let [result] = results else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.get-config results" + ); + }; + + let config = { + let ctx = ctx.as_context(); + let codec: &ErasedDynCodec = codec.rep(&ctx)?; + + let mut config_bytes = Vec::new(); + match codec.get_config(&mut serde_json::Serializer::new(&mut config_bytes)) { + Ok(()) => match String::from_utf8(config_bytes) { + Ok(config) => Ok(config), + Err(err) => Err(into_wit_error(err, &my_numcodecs_types_error_record)?), + }, + Err(err) => Err(into_wit_error(err, &my_numcodecs_types_error_record)?), + } + }; + + let res = match config { + Ok(config) => Ok(Some(Value::String(Arc::from(config)))), + Err(err) => Err(Some(Value::Record(err))), + }; + + *result = Value::Result(ResultValue::new(string_result.clone(), res)?); + + Ok(()) + }, + ); + numcodecs_registry_instance + .define_func("[method]erased-dyn-codec.get-config", codec_get_config)?; + + let my_numcodecs_registry_codec_type_resource = numcodecs_registry_codec_type_resource.clone(); + let codec_ty = Func::new( + ctx.as_context_mut(), + FuncType::new( + [ValueType::Borrow(numcodecs_registry_codec_resource.clone())], + [ValueType::Own( + numcodecs_registry_codec_type_resource.clone(), + )], + ), + move |ctx, args, results| { + let [Value::Borrow(codec)] = args else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec.ty arguments" + ); + }; + + let [result] = results else { + anyhow::bail!("invalid numcodecs:abc/registry#[method]erased-dyn-codec.ty results"); + }; + + let ty = { + let ctx = ctx.as_context(); + let codec: &ErasedDynCodec = codec.rep(&ctx)?; + codec.ty() + }; + + *result = Value::Own(ResourceOwn::new( + ctx, + ty, + my_numcodecs_registry_codec_type_resource.clone(), + )?); + + Ok(()) + }, + ); + numcodecs_registry_instance.define_func("[method]erased-dyn-codec.ty", codec_ty)?; + + let codec_type_id = Func::new( + ctx.as_context_mut(), + FuncType::new( + [ValueType::Borrow( + numcodecs_registry_codec_type_resource.clone(), + )], + [ValueType::String], + ), + move |ctx, args, results| { + let [Value::Borrow(ty)] = args else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec-type.codec-id arguments" + ); + }; + + let [result] = results else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codectype.codec-id results" + ); + }; + + let ctx = ctx.as_context(); + let ty: &ErasedDynCodecType = ty.rep(&ctx)?; + + *result = Value::String(Arc::from(ty.codec_id())); + + Ok(()) + }, + ); + numcodecs_registry_instance + .define_func("[method]erased-dyn-codec-type.codec-id", codec_type_id)?; + + let codec_type_schema = Func::new( + ctx.as_context_mut(), + FuncType::new( + [ValueType::Borrow( + numcodecs_registry_codec_type_resource.clone(), + )], + [ValueType::String], + ), + move |ctx, args, results| { + let [Value::Borrow(ty)] = args else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec-type.codec-config-schema arguments" + ); + }; + + let [result] = results else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codectype.codec-config-schema results" + ); + }; + + let ctx = ctx.as_context(); + let ty: &ErasedDynCodecType = ty.rep(&ctx)?; + + *result = Value::String(Arc::from(ty.codec_config_schema().to_value().to_string())); + + Ok(()) + }, + ); + numcodecs_registry_instance.define_func( + "[method]erased-dyn-codec-type.codec-config-schema", + codec_type_schema, + )?; + + let codec_result = ResultType::new( + Some(ValueType::Own(numcodecs_registry_codec_resource.clone())), + Some(ValueType::Record(numcodecs_types_error_record.clone())), + ); + + let my_numcodecs_registry_codec_resource = numcodecs_registry_codec_resource.clone(); + let my_numcodecs_types_error_record = numcodecs_types_error_record.clone(); + let my_codec_result = codec_result.clone(); + let codec_from_config = Func::new( + ctx.as_context_mut(), + FuncType::new( + [ + ValueType::Borrow(numcodecs_registry_codec_type_resource), + ValueType::String, + ], + [ValueType::Result(my_codec_result.clone())], + ), + move |ctx, args, results| { + let [Value::Borrow(ty), Value::String(config)] = args else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codec-type.codec-from-config arguments" + ); + }; + + let [result] = results else { + anyhow::bail!( + "invalid numcodecs:abc/registry#[method]erased-dyn-codectype.codec-from-config results" + ); + }; + + let res = { + let ctx = ctx.as_context(); + let ty: &ErasedDynCodecType = ty.rep(&ctx)?; + ty.codec_from_config(&mut serde_json::Deserializer::from_str(config)) + }; + + let res = match res { + Ok(codec) => Ok(Some(Value::Own(ResourceOwn::new( + ctx, + codec, + my_numcodecs_registry_codec_resource.clone(), + )?))), + Err(err) => Err(Some(Value::Record(into_wit_error( + err, + &my_numcodecs_types_error_record, + )?))), + }; + + *result = Value::Result(ResultValue::new(my_codec_result.clone(), res)?); + + Ok(()) + }, + ); + numcodecs_registry_instance.define_func( + "[method]erased-dyn-codec-type.codec-from-config", + codec_from_config, + )?; + + let my_numcodecs_registry_codec_resource = numcodecs_registry_codec_resource; + let my_numcodecs_types_error_record = numcodecs_types_error_record; + let my_codec_result = codec_result; + let get_codec = Func::new( + ctx, + FuncType::new( + [ValueType::String], + [ValueType::Result(my_codec_result.clone())], + ), + move |ctx, args, results| { + let [Value::String(config)] = args else { + anyhow::bail!("invalid numcodecs:abc/registry#get-codec arguments"); + }; + + let [result] = results else { + anyhow::bail!("invalid numcodecs:abc/registry#get-codec results"); + }; + + let res = match registry.get_codec(&mut serde_json::Deserializer::from_str(config)) { + Ok(codec) => Ok(Some(Value::Own(ResourceOwn::new( + ctx, + codec, + my_numcodecs_registry_codec_resource.clone(), + )?))), + Err(err) => Err(Some(Value::Record(into_wit_error( + err, + &my_numcodecs_types_error_record, + )?))), + }; + + *result = Value::Result(ResultValue::new(my_codec_result.clone(), res)?); + + Ok(()) + }, + ); + numcodecs_registry_instance.define_func("get-codec", get_codec)?; + + Ok(()) +} + +fn into_wit_error(err: T, ty: &RecordType) -> Result { + let mut source: Option<&dyn Error> = err.source(); + + let message = Value::String(Arc::from(format!("{err}"))); + let mut chain = if source.is_some() { + Vec::with_capacity(4) + } else { + Vec::new() + }; + + while let Some(err) = source.take() { + chain.push(Value::String(Arc::from(format!("{err}")))); + source = err.source(); + } + + Record::new( + ty.clone(), + [ + ("message", message), + ( + "chain", + Value::List(List::new(ListType::new(ValueType::String), chain)?), + ), + ], + ) +} diff --git a/crates/numcodecs-wasm-host/src/wit.rs b/crates/numcodecs-wasm-host/src/wit.rs index d8fdc0a18..1204ad78d 100644 --- a/crates/numcodecs-wasm-host/src/wit.rs +++ b/crates/numcodecs-wasm-host/src/wit.rs @@ -8,8 +8,14 @@ use crate::error::{CodecError, RuntimeError}; /// WebAssembly Interface Type (WIT) interfaces for `numcodecs` #[non_exhaustive] pub struct NumcodecsWitInterfaces { + /// The `numcodecs:abc` package + pub package: PackageIdentifier, /// The `numcodecs:abc/codec` interface pub codec: InterfaceIdentifier, + /// The `numcodecs:abc/registry` interface + pub registry: InterfaceIdentifier, + /// The `numcodecs:abc/types` interface + pub types: InterfaceIdentifier, } impl NumcodecsWitInterfaces { @@ -18,14 +24,17 @@ impl NumcodecsWitInterfaces { pub fn get() -> &'static Self { static NUMCODECS_WIT_INTERFACES: OnceLock = OnceLock::new(); - NUMCODECS_WIT_INTERFACES.get_or_init(|| Self { - codec: InterfaceIdentifier::new( - PackageIdentifier::new( - PackageName::new("numcodecs", "abc"), - Some(Version::new(0, 1, 1)), - ), - "codec", - ), + NUMCODECS_WIT_INTERFACES.get_or_init(|| { + let package = PackageIdentifier::new( + PackageName::new("numcodecs", "abc"), + Some(Version::new(0, 1, 1)), + ); + Self { + package: package.clone(), + codec: InterfaceIdentifier::new(package.clone(), "codec"), + registry: InterfaceIdentifier::new(package.clone(), "registry"), + types: InterfaceIdentifier::new(package, "types"), + } }) } } diff --git a/crates/numcodecs/Cargo.toml b/crates/numcodecs/Cargo.toml index daf48e7b5..70bd6b1b2 100644 --- a/crates/numcodecs/Cargo.toml +++ b/crates/numcodecs/Cargo.toml @@ -15,6 +15,7 @@ keywords = ["numcodecs", "compression", "encoding"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +erased-serde = { workspace = true, features = ["std"] } ndarray = { workspace = true } schemars = { workspace = true, features = ["derive"] } semver = { workspace = true, features = ["std", "serde"] } diff --git a/crates/numcodecs/src/codec.rs b/crates/numcodecs/src/codec.rs index ee87e241d..fc96940a3 100644 --- a/crates/numcodecs/src/codec.rs +++ b/crates/numcodecs/src/codec.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, error::Error, fmt, marker::PhantomData}; +use std::{any::Any, borrow::Cow, error::Error, fmt, marker::PhantomData}; use schemars::{JsonSchema, Schema, SchemaGenerator, generate::SchemaSettings, json_schema}; use semver::{Version, VersionReq}; @@ -169,6 +169,300 @@ impl DynCodecType for StaticCodecType { } } +/// Type-erased [`Error`] type. +pub struct ErasedError { + error: Box, +} + +impl ErasedError { + /// Erase the type information of the concrete `err`or. + pub fn new(err: T) -> Self { + Self { + error: Box::new(err), + } + } +} + +impl fmt::Debug for ErasedError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(&self.error, fmt) + } +} + +impl fmt::Display for ErasedError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.error, fmt) + } +} + +impl Error for ErasedError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.error.source() + } +} + +/// Type-erased dynamically typed compression codec. +pub struct ErasedDynCodec { + codec: Box, +} + +impl ErasedDynCodec { + /// Erase the type information of the concrete `codec`. + pub fn new(codec: T) -> Self { + Self { + codec: Box::new(codec), + } + } + + /// Try to downcast into a concretely-typed codec. + /// + /// # Errors + /// + /// Returns `self` if the type-erased codec is not of the concrete type. + pub fn downcast(self) -> Result { + if self.codec.erased_as_any().is::() { + let raw = Box::into_raw(self.codec); + #[expect(unsafe_code)] + // SAFETY: we have checked that self.codec is of type T + let codec = unsafe { Box::from_raw(raw.cast::()) }; + Ok(*codec) + } else { + Err(self) + } + } + + /// Try to downcast to a concretely-typed codec reference. + #[must_use] + pub fn downcast_ref(&self) -> Option<&T> { + self.codec.erased_as_any().downcast_ref() + } + + /// Try to downcast to a concretely-typed mutable codec reference. + #[must_use] + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.codec.erased_as_any_mut().downcast_mut() + } +} + +impl Clone for ErasedDynCodec { + fn clone(&self) -> Self { + Self { + codec: self.codec.erased_clone(), + } + } +} + +impl Codec for ErasedDynCodec { + type Error = ErasedError; + + fn encode(&self, data: AnyCowArray) -> Result { + self.codec.erased_encode(data) + } + + fn decode(&self, encoded: AnyCowArray) -> Result { + self.codec.erased_decode(encoded) + } + + fn decode_into( + &self, + encoded: AnyArrayView, + decoded: AnyArrayViewMut, + ) -> Result<(), Self::Error> { + self.codec.erased_decode_into(encoded, decoded) + } +} + +impl DynCodec for ErasedDynCodec { + type Type = ErasedDynCodecType; + + fn ty(&self) -> Self::Type { + ErasedDynCodecType { + ty: self.codec.erased_ty(), + } + } + + fn get_config(&self, serializer: S) -> Result { + erased_serde::serialize(self.codec.erased_as_serialize(), serializer) + } +} + +/// Type-erased dynamically typed compression codec type. +pub struct ErasedDynCodecType { + ty: Box, +} + +impl ErasedDynCodecType { + /// Erase the type information of the concrete codec `ty`pe. + pub fn new(ty: T) -> Self { + Self { ty: Box::new(ty) } + } + + /// Try to downcast into a concretely-typed codec type. + /// + /// # Errors + /// + /// Returns `self` if the type-erased codec type is not of the concrete + /// type. + pub fn downcast(self) -> Result { + if self.ty.erased_as_any().is::() { + let raw = Box::into_raw(self.ty); + #[expect(unsafe_code)] + // SAFETY: we have checked that self.ty is of type T + let ty = unsafe { Box::from_raw(raw.cast::()) }; + Ok(*ty) + } else { + Err(self) + } + } + + /// Try to downcast to a concretely-typed codec type reference. + #[must_use] + pub fn downcast_ref(&self) -> Option<&T> { + self.ty.erased_as_any().downcast_ref() + } + + /// Try to downcast to a concretely-typed mutable codec type reference. + #[must_use] + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.ty.erased_as_any_mut().downcast_mut() + } +} + +impl DynCodecType for ErasedDynCodecType { + type Codec = ErasedDynCodec; + + fn codec_id(&self) -> &str { + self.ty.erased_codec_id() + } + + fn codec_config_schema(&self) -> Schema { + self.ty.erased_codec_config_schema() + } + + fn codec_from_config<'de, D: Deserializer<'de>>( + &self, + config: D, + ) -> Result { + match self + .ty + .erased_codec_from_config(&mut ::erase(config)) + { + Ok(codec) => Ok(ErasedDynCodec { codec }), + Err(err) => Err(serde::de::Error::custom(err)), // TODO: improve + } + } +} + +trait ErasedDynCodecDispatch: 'static + Send + Sync { + fn erased_encode(&self, data: AnyCowArray) -> Result; + fn erased_decode(&self, encoded: AnyCowArray) -> Result; + fn erased_decode_into( + &self, + encoded: AnyArrayView, + decoded: AnyArrayViewMut, + ) -> Result<(), ErasedError>; + + fn erased_clone(&self) -> Box; + + fn erased_ty(&self) -> Box; + + fn erased_as_any(&self) -> &dyn Any; + fn erased_as_any_mut(&mut self) -> &mut dyn Any; + + fn erased_as_serialize(&self) -> &dyn erased_serde::Serialize; +} + +trait ErasedDynCodecTypeDispatch: 'static + Send + Sync { + fn erased_codec_id(&self) -> &str; + fn erased_codec_config_schema(&self) -> Schema; + fn erased_codec_from_config( + &self, + config: &mut dyn erased_serde::Deserializer, + ) -> Result, erased_serde::Error>; + + fn erased_as_any(&self) -> &dyn Any; + fn erased_as_any_mut(&mut self) -> &mut dyn Any; +} + +impl ErasedDynCodecDispatch for T { + fn erased_encode(&self, data: AnyCowArray) -> Result { + Codec::encode(self, data).map_err(ErasedError::new) + } + + fn erased_decode(&self, encoded: AnyCowArray) -> Result { + Codec::decode(self, encoded).map_err(ErasedError::new) + } + + fn erased_decode_into( + &self, + encoded: AnyArrayView, + decoded: AnyArrayViewMut, + ) -> Result<(), ErasedError> { + Codec::decode_into(self, encoded, decoded).map_err(ErasedError::new) + } + + fn erased_clone(&self) -> Box { + Box::new(Clone::clone(self)) + } + + fn erased_ty(&self) -> Box { + Box::new(DynCodec::ty(self)) + } + + fn erased_as_any(&self) -> &dyn Any { + self + } + + fn erased_as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn erased_as_serialize(&self) -> &dyn erased_serde::Serialize { + #[repr(transparent)] + struct SerializeDynCodec(T); + + impl Serialize for SerializeDynCodec { + fn serialize(&self, serializer: S) -> Result { + DynCodec::get_config(&self.0, serializer) + } + } + + #[expect(unsafe_code)] + // SAFETY: SerializeDynCodec is a transparent newtype around Self + unsafe { + &*std::ptr::from_ref(self).cast::>() + } + } +} + +impl ErasedDynCodecTypeDispatch for T { + fn erased_codec_id(&self) -> &str { + DynCodecType::codec_id(self) + } + + fn erased_codec_config_schema(&self) -> Schema { + DynCodecType::codec_config_schema(self) + } + + fn erased_codec_from_config( + &self, + config: &mut dyn erased_serde::Deserializer, + ) -> Result, erased_serde::Error> { + match DynCodecType::codec_from_config(self, config) { + Ok(codec) => Ok(Box::new(codec)), + Err(err) => Err(err), + } + } + + fn erased_as_any(&self) -> &dyn Any { + self + } + + fn erased_as_any_mut(&mut self) -> &mut dyn Any { + self + } +} + /// Utility struct to serialize a [`StaticCodec`]'s [`StaticCodec::Config`] /// together with its [`StaticCodec::CODEC_ID`] #[derive(Serialize, Deserialize)] diff --git a/crates/numcodecs/src/lib.rs b/crates/numcodecs/src/lib.rs index b4db09d2b..d6d48221a 100644 --- a/crates/numcodecs/src/lib.rs +++ b/crates/numcodecs/src/lib.rs @@ -27,6 +27,7 @@ pub use array::{ AnyArrayViewMut, AnyCowArray, AnyRawData, ArrayDType, ArrayDataMutExt, }; pub use codec::{ - Codec, DynCodec, DynCodecType, StaticCodec, StaticCodecConfig, StaticCodecType, - StaticCodecVersion, codec_from_config_with_id, serialize_codec_config_with_id, + Codec, DynCodec, DynCodecType, ErasedDynCodec, ErasedDynCodecType, ErasedError, StaticCodec, + StaticCodecConfig, StaticCodecType, StaticCodecVersion, codec_from_config_with_id, + serialize_codec_config_with_id, }; diff --git a/py/numcodecs-wasm-template/pyproject.toml b/py/numcodecs-wasm-template/pyproject.toml index f317b8209..5e44ec501 100644 --- a/py/numcodecs-wasm-template/pyproject.toml +++ b/py/numcodecs-wasm-template/pyproject.toml @@ -16,7 +16,7 @@ license = { file = "LICENSE" } requires-python = ">=3.10" dependencies = [ - "numcodecs-wasm~=0.2.4", # wasi 0.2.6 + "numcodecs-wasm @ file:///Users/junityre/numcodecs-rs/py/numcodecs-wasm/dist/numcodecs_wasm-0.2.4-cp310-abi3-macosx_11_0_arm64.whl", # wasi 0.2.6 ] [project.entry-points."numcodecs.codecs"] diff --git a/py/numcodecs-wasm/Cargo.toml b/py/numcodecs-wasm/Cargo.toml index 2f3a65124..4d8ea9f21 100644 --- a/py/numcodecs-wasm/Cargo.toml +++ b/py/numcodecs-wasm/Cargo.toml @@ -29,6 +29,7 @@ anyhow = { workspace = true } # FIXME: https://github.com/bytecodealliance/rustix/issues/1620 memfd = { version = "0.6.5", default-features = false } numcodecs-python = { workspace = true } +numcodecs-registry = { workspace = true } numcodecs-wasm-host-reproducible = { workspace = true } pyo3 = { workspace = true, features = ["macros", "abi3-py310"] } pyo3-error = { workspace = true } diff --git a/py/numcodecs-wasm/src/lib.rs b/py/numcodecs-wasm/src/lib.rs index 66010b32e..02de59600 100644 --- a/py/numcodecs-wasm/src/lib.rs +++ b/py/numcodecs-wasm/src/lib.rs @@ -32,7 +32,8 @@ fn create_codec_class<'py>( ) -> Result, PyErr> { let engine = default_engine(py)?; - let codec_ty = ReproducibleWasmCodecType::new(engine, wasm) + // TODO: we should allow restricting the codecs that the reproducible codec can 'see' + let codec_ty = ReproducibleWasmCodecType::new(engine, wasm, numcodecs_python::PyCodecRegistry) .map_err(|err| pyo3_error::PyErrChain::new(py, err))?; let codec_class = numcodecs_python::export_codec_class(py, codec_ty, module.as_borrowed())?; @@ -60,3 +61,7 @@ fn read_codec_instruction_counter<'py>( Ok(instruction_counter.0) } + +numcodecs_registry::export_global! { + registry: numcodecs_python::PyCodecRegistry = numcodecs_python::PyCodecRegistry +} diff --git a/wit/codecs.wit b/wit/codecs.wit index 863623366..0e60b92d9 100644 --- a/wit/codecs.wit +++ b/wit/codecs.wit @@ -1,6 +1,8 @@ package numcodecs:abc@0.1.1; -interface codec { +// TODO: major version bump + since annotations + +interface types { type json = string; type json-schema = json; type usize = u32; @@ -47,6 +49,13 @@ interface codec { message: string, chain: list, } +} + +interface codec { + use types.{ + any-array, any-array-prototype, any-array-data, any-array-dtype, error, + json, json-schema, usize, + }; resource codec { from-config: static func(config: json) -> result; @@ -65,3 +74,35 @@ interface codec { codec-config-schema: func() -> json-schema; } + +interface registry { + use types.{ + any-array, any-array-prototype, any-array-data, any-array-dtype, error, + json, json-schema, usize, + }; + + resource erased-dyn-codec { + encode: func(data: any-array) -> result; + + decode: func(encoded: any-array) -> result; + + @since(version = 0.1.1) + decode-into: func(encoded: any-array, decoded: any-array-prototype) -> result; + + clone: func() -> erased-dyn-codec; + + get-config: func() -> result; + + ty: func() -> erased-dyn-codec-type; + } + + resource erased-dyn-codec-type { + codec-id: func() -> string; + + codec-config-schema: func() -> json-schema; + + codec-from-config: func(config: json) -> result; + } + + get-codec: func(config: json) -> result; +} diff --git a/wit/world.wit b/wit/world.wit index 44c9b58a9..1b5e1704d 100644 --- a/wit/world.wit +++ b/wit/world.wit @@ -2,8 +2,10 @@ package numcodecs:abc@0.1.1; world imports { import codec; + export registry; } world exports { + import registry; export codec; }