Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/numcodecs-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ keywords = ["numcodecs", "compression", "encoding", "python", "pyo3"]
convert_case = { workspace = true }
ndarray = { workspace = true }
numcodecs = { workspace = true }
numcodecs-registry = { workspace = true }
numpy = { workspace = true }
pyo3 = { workspace = true }
pyo3-error = { workspace = true }
Expand Down
46 changes: 42 additions & 4 deletions crates/numcodecs-python/src/registry.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use numcodecs::{DynCodec, ErasedDynCodec};
use numcodecs_registry::Registry;
use pyo3::{prelude::*, sync::PyOnceLock, types::PyDict};
use pythonize::Pythonizer;
use serde::Deserializer;
use serde_transcode::transcode;

#[expect(unused_imports)] // FIXME: use expect, only used in docs
use crate::PyCodecClassMethods;
use crate::{PyCodec, PyCodecClass};
use crate::{PyCodec, PyCodecAdapter, PyCodecClass};

/// Dynamic registry of codec classes.
pub struct PyCodecRegistry {
_private: (),
}
pub struct PyCodecRegistry;

impl PyCodecRegistry {
/// Instantiate a codec from a configuration dictionary.
Expand Down Expand Up @@ -56,3 +59,38 @@ impl PyCodecRegistry {
Ok(())
}
}

impl Registry for PyCodecRegistry {
type Error = PyErr;

fn get_codec<'de, D: Deserializer<'de>>(
&self,
config: D,
) -> Result<ErasedDynCodec, Self::Error> {
Python::attach(|py| {
let config = transcode(config, Pythonizer::new(py))?;
let config: Bound<PyDict> = config.extract()?;

let codec = Self::get_codec(config.as_borrowed())?;
let codec = PyCodecAdapter::from_codec(codec)?;

Ok(ErasedDynCodec::new(codec))
})
}

fn get_codec_typed<'de, T: DynCodec, D: Deserializer<'de>>(
&self,
config: D,
) -> Result<Option<T>, Self::Error> {
Python::attach(|py| {
let config = transcode(config, Pythonizer::new(py))?;
let config: Bound<PyDict> = config.extract()?;

let codec = Self::get_codec(config.as_borrowed())?;
// clone is necessary since we cannot move out of a PyCodec
let codec = PyCodecAdapter::with_downcast(py, &codec, |codec: &T| codec.clone());

Ok(codec)
})
}
}
1 change: 1 addition & 0 deletions crates/numcodecs-wasm-host-reproducible/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ numcodecs-wasm-host = { workspace = true }
indexmap = { workspace = true, features = ["std"] }
log = { workspace = true }
numcodecs = { workspace = true }
numcodecs-registry = { workspace = true }
polonius-the-crab = { workspace = true }
schemars = { workspace = true }
semver = { workspace = true }
Expand Down
9 changes: 9 additions & 0 deletions crates/numcodecs-wasm-host-reproducible/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex};
use numcodecs::{
AnyArray, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec, DynCodecType,
};
use numcodecs_registry::Registry;
use numcodecs_wasm_host::{CodecError, RuntimeError, WasmCodec, WasmCodecComponent};
use schemars::Schema;
use serde::Serializer;
Expand Down Expand Up @@ -323,6 +324,7 @@ where
pub fn new(
engine: E,
wasm_component: impl Into<Vec<u8>>,
registry: impl Registry,
) -> Result<Self, ReproducibleWasmCodecError>
where
E: Send + Sync,
Expand All @@ -343,6 +345,8 @@ where
}
})?;

let registry = Arc::new(registry);

let component_instantiater = Arc::new(move |component: &Component, codec_id: &str| {
let mut store = Store::new(&engine, ());

Expand All @@ -359,6 +363,11 @@ where
source: RuntimeError::from(err),
}
})?;
numcodecs_wasm_host::add_registry_to_linker(&mut linker, &mut store, registry.clone())
.map_err(|err| ReproducibleWasmCodecError::Runtime {
codec_id: Arc::from(codec_id),
source: RuntimeError::from(err),
})?;

let instance = linker.instantiate(&mut store, component).map_err(|err| {
ReproducibleWasmCodecError::Runtime {
Expand Down
7 changes: 6 additions & 1 deletion crates/numcodecs-wasm-host-reproducible/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use ndarray::Array;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Normal;
use numcodecs::{Codec, DynCodecType};
use numcodecs_registry::EmptyRegistry;

use crate::ReproducibleWasmCodecType;

Expand Down Expand Up @@ -53,7 +54,11 @@ fn codec_roundtrip() {

let engine = wasmtime_runtime_layer::Engine::new(wasmtime::Engine::new(&config).unwrap());

let ty = match ReproducibleWasmCodecType::new(engine, include_bytes!("../tests/round.wasm")) {
let ty = match ReproducibleWasmCodecType::new(
engine,
include_bytes!("../tests/round.wasm"),
EmptyRegistry,
) {
Ok(ty) => ty,
Err(err) => panic!(
"ReproducibleWasmCodecType::new:\n===\n{err}\n===\n{err:?}\n===\n{err:#}\n===\n{err:#?}\n===\n"
Expand Down
45 changes: 31 additions & 14 deletions crates/numcodecs-wasm-host-reproducible/src/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::OnceLock;
use anyhow::{Context, Error, anyhow};
use instcnt::PerfWitInterfaces;
use numcodecs_wasm_host::NumcodecsWitInterfaces;
use wac_graph::AliasError;

use crate::{logging::WasiLoggingInterface, stdio::WasiSandboxedStdioInterface};

Expand All @@ -12,7 +13,11 @@ pub mod nan;
#[expect(clippy::too_many_lines)] // FIXME
pub fn transform_wasm_component(wasm_component: impl Into<Vec<u8>>) -> Result<Vec<u8>, Error> {
let NumcodecsWitInterfaces {
codec: codec_interface,
package: numcodecs_package,
codec: numcodecs_codec_interface,
codec_v0_1_1: numcodecs_v0_1_1_codec_interface,
registry: numcodecs_registry_interface,
types: numcodecs_types_interface,
..
} = NumcodecsWitInterfaces::get();

Expand All @@ -24,41 +29,46 @@ pub fn transform_wasm_component(wasm_component: impl Into<Vec<u8>>) -> Result<Ve
} = get_prepared_composition_graph()?;
let mut wac = wac.clone();

// parse and instantiate the root package, which exports numcodecs:abc/codec
let numcodecs_codec_package = wac_graph::types::Package::from_bytes(
&format!("{}", codec_interface.package().name()),
codec_interface.package().version(),
// parse and instantiate the root numcodecs:abc package, which
// - exports the numcodecs:abc/codec interface
// - imports the numcodecs:abc/registry interface
let numcodecs_package = wac_graph::types::Package::from_bytes(
&format!("{}", numcodecs_package.name()),
numcodecs_package.version(),
wasm_component,
wac.types_mut(),
)?;

let numcodecs_codec_world = &wac.types()[numcodecs_codec_package.ty()];
let numcodecs_codec_imports = extract_component_ports(&numcodecs_codec_world.imports)?;
let numcodecs_world = &wac.types()[numcodecs_package.ty()];
let numcodecs_imports = extract_component_ports(&numcodecs_world.imports)?;

let numcodecs_codec_package = wac.register_package(numcodecs_codec_package)?;
let numcodecs_codec_instance = wac.instantiate(numcodecs_codec_package);
let numcodecs_package = wac.register_package(numcodecs_package)?;
let numcodecs_instance = wac.instantiate(numcodecs_package);

// list the imports that the linker will provide
let linker_provided_imports = [
&WasiSandboxedStdioInterface::get().stdio,
&WasiLoggingInterface::get().logging,
numcodecs_registry_interface,
// numcodecs:abc/types is a types-only interface
numcodecs_types_interface,
];

// initialise the unresolved imports to the imports of the root package
let mut unresolved_imports = vecmap::VecMap::new();
for import in &numcodecs_codec_imports {
for import in &numcodecs_imports {
unresolved_imports
.entry(import.clone())
.or_insert_with(Vec::new)
.push(numcodecs_codec_instance);
.push(numcodecs_instance);
}

// track all non-root instances, which may fulfil imports
let mut package_instances = vecmap::VecMap::new();

// initialise the queue of required, still to instantiate packages
// to the imports of the root package
let mut required_packages_queue = numcodecs_codec_imports
let mut required_packages_queue = numcodecs_imports
.iter()
.map(|import| import.package().clone())
.collect::<std::collections::VecDeque<_>>();
Expand Down Expand Up @@ -148,9 +158,16 @@ pub fn transform_wasm_component(wasm_component: impl Into<Vec<u8>>) -> Result<Ve
}

// export the numcodecs:abc/codec interface
let numcodecs_codecs_str = &format!("{codec_interface}");
let numcodecs_codecs_str = &format!("{numcodecs_codec_interface}");
let numcodecs_codecs_export =
wac.alias_instance_export(numcodecs_codec_instance, numcodecs_codecs_str)?;
match wac.alias_instance_export(numcodecs_instance, numcodecs_codecs_str) {
Ok(numcodecs_codecs_export) => numcodecs_codecs_export,
Err(AliasError::InstanceMissingExport { .. }) => wac.alias_instance_export(
numcodecs_instance,
&format!("{numcodecs_v0_1_1_codec_interface}"),
)?,
Err(err) => Err(err)?,
};
wac.export(numcodecs_codecs_export, numcodecs_codecs_str)?;

// encode the WAC composition graph into a WASM component and validate it
Expand Down
1 change: 1 addition & 0 deletions crates/numcodecs-wasm-host/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ keywords = ["numcodecs", "compression", "encoding", "wasm-component", "wasm-bind
anyhow = { workspace = true }
ndarray = { workspace = true, features = ["std"] }
numcodecs = { workspace = true }
numcodecs-registry = { workspace = true }
schemars = { workspace = true }
semver = { workspace = true }
serde = { workspace = true }
Expand Down
62 changes: 55 additions & 7 deletions crates/numcodecs-wasm-host/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ impl WasmCodec {
}
}

fn any_array_data_ty() -> &'static VariantType {
pub(crate) fn any_array_data_ty() -> &'static VariantType {
static ANY_ARRAY_DATA_TY: OnceLock<VariantType> = OnceLock::new();

#[expect(clippy::expect_used)]
Expand All @@ -337,7 +337,7 @@ impl WasmCodec {
})
}

fn any_array_ty() -> &'static RecordType {
pub(crate) fn any_array_ty() -> &'static RecordType {
static ANY_ARRAY_TY: OnceLock<RecordType> = OnceLock::new();

#[expect(clippy::expect_used)]
Expand All @@ -359,7 +359,7 @@ impl WasmCodec {
}

#[expect(clippy::needless_pass_by_value)]
fn array_into_wasm(array: AnyArrayView) -> Result<Record, RuntimeError> {
pub(crate) fn array_into_wasm(array: AnyArrayView) -> Result<Record, RuntimeError> {
fn list_from_standard_layout<'a, T: 'static + Copy, S: Data<Elem = T>, D: Dimension>(
array: &'a ArrayBase<S, D>,
) -> List
Expand Down Expand Up @@ -451,7 +451,7 @@ impl WasmCodec {
.map_err(RuntimeError::from)
}

fn any_array_dtype_ty() -> &'static EnumType {
pub(crate) fn any_array_dtype_ty() -> &'static EnumType {
static ANY_ARRAY_DTYPE_TY: OnceLock<EnumType> = OnceLock::new();

#[expect(clippy::expect_used)]
Expand All @@ -468,7 +468,7 @@ impl WasmCodec {
})
}

fn any_array_prototype_ty() -> &'static RecordType {
pub(crate) fn any_array_prototype_ty() -> &'static RecordType {
static ANY_ARRAY_PROTOTYPE_TY: OnceLock<RecordType> = OnceLock::new();

#[expect(clippy::expect_used)]
Expand All @@ -486,7 +486,7 @@ impl WasmCodec {
})
}

fn array_prototype_into_wasm(
pub(crate) fn array_prototype_into_wasm(
dtype: AnyArrayDType,
shape: &[usize],
) -> Result<Record, RuntimeError> {
Expand Down Expand Up @@ -522,7 +522,7 @@ impl WasmCodec {
.map_err(RuntimeError::from)
}

fn with_array_view_from_wasm_record<O>(
pub(crate) fn with_array_view_from_wasm_record<O>(
record: &Record,
with: impl for<'a> FnOnce(AnyArrayView<'a>) -> Result<O, RuntimeError>,
) -> Result<O, RuntimeError> {
Expand Down Expand Up @@ -602,4 +602,52 @@ impl WasmCodec {

with(array)
}

pub(crate) fn array_prototype_from_wasm_record(
record: &Record,
) -> Result<AnyArray, RuntimeError> {
let Some(Value::Variant(dtype)) = record.field("dtype") else {
return Err(RuntimeError::from(anyhow::Error::msg(format!(
"{record:?} is missing dtype field"
))));
};
if let Some(ty) = dtype.value() {
return Err(RuntimeError::from(anyhow::Error::msg(format!(
"{record:?} has an invalid dtype variant type {ty:?}"
))));
}

let dtype = match dtype.discriminant() {
0 => AnyArrayDType::U8,
1 => AnyArrayDType::U16,
2 => AnyArrayDType::U32,
3 => AnyArrayDType::U64,
4 => AnyArrayDType::I8,
5 => AnyArrayDType::I16,
6 => AnyArrayDType::I32,
7 => AnyArrayDType::I64,
8 => AnyArrayDType::F32,
9 => AnyArrayDType::F64,
discriminant => {
return Err(RuntimeError::from(anyhow::Error::msg(format!(
"{record:?} has an invalid dtype variant [{discriminant}]"
))));
}
};

let Some(Value::List(shape)) = record.field("shape") else {
return Err(RuntimeError::from(anyhow::Error::msg(format!(
"process result record {record:?} is missing shape field"
))));
};
let shape = shape
.typed::<u32>()?
.iter()
.copied()
.map(usize::try_from)
.collect::<Result<Vec<_>, _>>()
.map_err(anyhow::Error::new)?;

Ok(AnyArray::zeros(dtype, &shape))
}
}
Loading
Loading