From 379c51fc970a367a11bbb90867feca1a6619936d Mon Sep 17 00:00:00 2001 From: xiao312 Date: Tue, 31 Mar 2026 16:56:29 +0800 Subject: [PATCH] refactor: move augment and label into data package --- dfode_kit/cli/commands/augment.py | 3 +- dfode_kit/cli/commands/label.py | 2 +- dfode_kit/data/__init__.py | 4 + dfode_kit/data/augment.py | 167 ++++++++++++++++++ dfode_kit/data/integration.py | 8 +- .../label_data.py => data/label.py} | 24 +-- dfode_kit/data_operations/__init__.py | 49 ----- dfode_kit/data_operations/augment_data.py | 149 ---------------- dfode_kit/data_operations/contracts.py | 1 - dfode_kit/data_operations/h5_kit.py | 37 ---- tests/test_data_contracts.py | 7 +- tests/test_data_integration_shims.py | 19 +- tests/test_data_io_hdf5.py | 5 +- tests/test_data_module_exports.py | 10 ++ 14 files changed, 207 insertions(+), 278 deletions(-) create mode 100644 dfode_kit/data/augment.py rename dfode_kit/{data_operations/label_data.py => data/label.py} (57%) delete mode 100644 dfode_kit/data_operations/__init__.py delete mode 100644 dfode_kit/data_operations/augment_data.py delete mode 100644 dfode_kit/data_operations/contracts.py delete mode 100644 dfode_kit/data_operations/h5_kit.py create mode 100644 tests/test_data_module_exports.py diff --git a/dfode_kit/cli/commands/augment.py b/dfode_kit/cli/commands/augment.py index f2c7561b..a036abad 100644 --- a/dfode_kit/cli/commands/augment.py +++ b/dfode_kit/cli/commands/augment.py @@ -48,8 +48,7 @@ def add_command_parser(subparsers): def handle_command(args): import numpy as np - from dfode_kit.data_operations.augment_data import random_perturb - from dfode_kit.data.io_hdf5 import get_TPY_from_h5 + from dfode_kit.data import get_TPY_from_h5, random_perturb print('Handling augment command') print(f'Loading data from h5 file: {args.h5_file}') diff --git a/dfode_kit/cli/commands/label.py b/dfode_kit/cli/commands/label.py index 10102c05..d4b226fb 100644 --- a/dfode_kit/cli/commands/label.py +++ b/dfode_kit/cli/commands/label.py @@ -34,7 +34,7 @@ def add_command_parser(subparsers): def handle_command(args): import numpy as np - from dfode_kit.data_operations import label_npy as label_main + from dfode_kit.data import label_npy as label_main try: labeled_data = label_main( diff --git a/dfode_kit/data/__init__.py b/dfode_kit/data/__init__.py index 1b5e151f..e81f5c52 100644 --- a/dfode_kit/data/__init__.py +++ b/dfode_kit/data/__init__.py @@ -16,6 +16,8 @@ "nn_integrate", "integrate_h5", "calculate_error", + "random_perturb", + "label_npy", ] _ATTRIBUTE_MODULES = { @@ -33,6 +35,8 @@ "nn_integrate": ("dfode_kit.data.integration", "nn_integrate"), "integrate_h5": ("dfode_kit.data.integration", "integrate_h5"), "calculate_error": ("dfode_kit.data.integration", "calculate_error"), + "random_perturb": ("dfode_kit.data.augment", "random_perturb"), + "label_npy": ("dfode_kit.data.label", "label_npy"), } diff --git a/dfode_kit/data/augment.py b/dfode_kit/data/augment.py new file mode 100644 index 00000000..f4f4e85a --- /dev/null +++ b/dfode_kit/data/augment.py @@ -0,0 +1,167 @@ +import time + +import cantera as ct +import numpy as np + + +def single_step(npstate, chem, time_step=1e-6): + gas = ct.Solution(chem) + T_old, P_old, Y_old = npstate[0], npstate[1], npstate[2:] + gas.TPY = T_old, P_old, Y_old + res_1st = [T_old, P_old] + list(gas.Y) + reactor = ct.IdealGasConstPressureReactor(gas, name='R1') + sim = ct.ReactorNet([reactor]) + + sim.advance(time_step) + new_TPY = [gas.T, gas.P] + list(gas.Y) + res_1st += new_TPY + + return res_1st + + +def random_perturb( + array: np.ndarray, + mech_path: str, + dataset: int, + heat_limit: bool, + element_limit: bool, + eq_ratio: float = 1, + frozenTem: float = 310, + alpha: float = 0.1, + gamma: float = 0.1, + cq: float = 10, + inert_idx: int = -1, + time_step: float = 1e-6, +) -> np.ndarray: + array = array[array[:, 0] > frozenTem] + + gas = ct.Solution(mech_path) + n_species = gas.n_species + maxT = np.max(array[:, 0]) + minT = np.min(array[:, 0]) + maxP = np.max(array[:, 1]) + minP = np.min(array[:, 1]) + maxN2 = np.max(array[:, -1]) + minN2 = np.min(array[:, -1]) + + H_O_ratio_base = 2 * eq_ratio + + num = 0 + new_array = [] + while num < dataset: + if heat_limit: + from dfode_kit.training.formation import formation_calculate + + qdot_ = np.zeros_like(array[:, 0]) + formation = formation_calculate(mech_path) + label_array = label(array, mech_path) + for i in range(label_array.shape[0]): + qdot_[i] = ( + -( + formation + * ( + label_array[i, 4 + n_species : 4 + 2 * n_species] + - label_array[i, 2 : 2 + n_species] + ) + / time_step + ).sum() + ) + + for j in range(array.shape[0]): + test_tmp = np.copy(array[j]) + k = 0 + while True: + k += 1 + + test_r = np.copy(array[j]) + + test_tmp[0] = test_r[0] + (maxT - minT) * (2 * np.random.rand() - 1.0) * alpha + test_tmp[1] = test_r[1] + (maxP - minP) * (2 * np.random.rand() - 1.0) * alpha * 20 + test_tmp[-1] = test_r[-1] + (maxN2 - minN2) * (2 * np.random.rand() - 1) * alpha + for i in range(2, array.shape[1] - 1): + test_tmp[i] = np.abs(test_r[i]) ** (1 + (2 * np.random.rand() - 1) * alpha) + test_tmp[2:-1] = test_tmp[2:-1] / np.sum(test_tmp[2:-1]) * (1 - test_tmp[-1]) + + if heat_limit: + label_test_tmp = np.array(single_step(test_tmp, mech_path)) + qdot_new_ = ( + -( + formation + * ( + label_test_tmp[4 + n_species : 4 + 2 * n_species] + - label_test_tmp[2 : 2 + n_species] + ) + / time_step + ).sum() + ) + + if element_limit: + gas.TPY = test_tmp[0], test_tmp[1], test_tmp[2:] + H_mole_fraction = gas.elemental_mole_fraction("H") + O_mole_fraction = gas.elemental_mole_fraction("O") + H_O_ratio = H_mole_fraction / O_mole_fraction + + if heat_limit and element_limit: + condition = ( + (minT * (1 - gamma)) <= test_tmp[0] <= (maxT * (1 + gamma)) + and (H_O_ratio_base * (1 - gamma)) <= H_O_ratio <= (H_O_ratio_base * (1 + gamma)) + and (qdot_new_ > 1 / cq * qdot_[j] and qdot_new_ < cq * qdot_[j]) + ) + elif heat_limit and not element_limit: + condition = ( + (minT * (1 - gamma)) <= test_tmp[0] <= (maxT * (1 + gamma)) + and (qdot_new_ > 1 / cq * qdot_[j] and qdot_new_ < cq * qdot_[j]) + ) + elif not heat_limit and element_limit: + condition = ( + (minT * (1 - gamma)) <= test_tmp[0] <= (maxT * (1 + gamma)) + and (H_O_ratio_base * (1 - gamma)) <= H_O_ratio <= (H_O_ratio_base * (1 + gamma)) + ) + else: + condition = (minT * (1 - gamma)) <= test_tmp[0] <= (maxT * (1 + gamma)) + + if condition or k > 20: + break + + if k <= 20: + new_array.append(test_tmp) + + num = len(new_array) + print(num) + + new_array = np.array(new_array) + new_array = new_array[np.random.choice(new_array.shape[0], size=dataset)] + unique_array = np.unique(new_array, axis=0) + print(unique_array.shape) + return unique_array + + +def label( + array: np.ndarray, + mech_path: str, + time_step: float = 1e-06, +) -> np.ndarray: + from dfode_kit.data.integration import advance_reactor + + gas = ct.Solution(mech_path) + n_species = gas.n_species + + labeled_data = np.empty((array.shape[0], 2 * n_species + 4)) + + reactor = ct.Reactor(gas, name='Reactor1', energy='off') + reactor_net = ct.ReactorNet([reactor]) + reactor_net.rtol, reactor_net.atol = 1e-6, 1e-10 + + start_time = time.time() + + for i, state in enumerate(array): + gas = advance_reactor(gas, state, reactor, reactor_net, time_step) + labeled_data[i, : 2 + n_species] = state[: 2 + n_species] + labeled_data[i, 2 + n_species :] = np.array([gas.T, gas.P] + list(gas.Y)) + + end_time = time.time() + total_time = end_time - start_time + + print(f"Total time used: {total_time:.2f} seconds") + + return labeled_data diff --git a/dfode_kit/data/integration.py b/dfode_kit/data/integration.py index 5225bf6e..aedf323f 100644 --- a/dfode_kit/data/integration.py +++ b/dfode_kit/data/integration.py @@ -1,5 +1,4 @@ import h5py -import torch import numpy as np import cantera as ct @@ -28,8 +27,9 @@ def advance_reactor(gas, state, reactor, reactor_net, time_step): return gas -@torch.no_grad() def load_model(model_path, device, model_class, model_layers): + import torch + state_dict = torch.load(model_path, map_location='cpu') model = model_class(model_layers) @@ -41,8 +41,9 @@ def load_model(model_path, device, model_class, model_layers): return model -@torch.no_grad() def predict_Y(model, model_path, d_arr, mech, device): + import torch + gas = ct.Solution(mech) n_species = gas.n_species expected_dims = 2 + n_species @@ -79,7 +80,6 @@ def predict_Y(model, model_path, d_arr, mech, device): return next_Y -@torch.no_grad() def nn_integrate(orig_arr, model_path, device, model_class, model_layers, time_step, mech, frozen_temperature=510): model = load_model(model_path, device, model_class, model_layers) diff --git a/dfode_kit/data_operations/label_data.py b/dfode_kit/data/label.py similarity index 57% rename from dfode_kit/data_operations/label_data.py rename to dfode_kit/data/label.py index d0b11096..0cba89dd 100644 --- a/dfode_kit/data_operations/label_data.py +++ b/dfode_kit/data/label.py @@ -1,45 +1,39 @@ import time -import numpy as np + import cantera as ct +import numpy as np -from .h5_kit import advance_reactor def label_npy( - mech_path, + mech_path, time_step, source_path, ): - # Load the chemical mechanism + from dfode_kit.data.integration import advance_reactor + gas = ct.Solution(mech_path) n_species = gas.n_species - # Load the dataset containing initial states for the reactor test_data = np.load(source_path) print(f"Loaded dataset from: {source_path}") print(f"{test_data.shape=}") - # Prepare an array to store labeled data labeled_data = np.empty((test_data.shape[0], 2 * n_species + 4)) - # Initialize Cantera reactor reactor = ct.Reactor(gas, name='Reactor1', energy='off') reactor_net = ct.ReactorNet([reactor]) reactor_net.rtol, reactor_net.atol = 1e-6, 1e-10 - # Start timing the simulation start_time = time.time() - # Process each state in the dataset for i, state in enumerate(test_data): gas = advance_reactor(gas, state, reactor, reactor_net, time_step) - labeled_data[i, :2 + n_species] = state[:2 + n_species] - labeled_data[i, 2 + n_species:] = np.array([gas.T, gas.P] + list(gas.Y)) + labeled_data[i, : 2 + n_species] = state[: 2 + n_species] + labeled_data[i, 2 + n_species :] = np.array([gas.T, gas.P] + list(gas.Y)) - # End timing of the simulation end_time = time.time() total_time = end_time - start_time - # Print the total time used and the path of the saved data print(f"Total time used: {total_time:.2f} seconds") - - return labeled_data \ No newline at end of file + + return labeled_data diff --git a/dfode_kit/data_operations/__init__.py b/dfode_kit/data_operations/__init__.py deleted file mode 100644 index 12343642..00000000 --- a/dfode_kit/data_operations/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -from importlib import import_module - - -__all__ = [ - "touch_h5", - "get_TPY_from_h5", - "integrate_h5", - "load_model", - "nn_integrate", - "predict_Y", - "calculate_error", - "random_perturb", - "label_npy", - "SCALAR_FIELDS_GROUP", - "MECHANISM_ATTR", - "read_scalar_field_datasets", - "stack_scalar_field_datasets", - "require_h5_attr", - "require_h5_group", -] - -_ATTRIBUTE_MODULES = { - "touch_h5": ("dfode_kit.data.io_hdf5", "touch_h5"), - "get_TPY_from_h5": ("dfode_kit.data.io_hdf5", "get_TPY_from_h5"), - "integrate_h5": ("dfode_kit.data.integration", "integrate_h5"), - "load_model": ("dfode_kit.data.integration", "load_model"), - "nn_integrate": ("dfode_kit.data.integration", "nn_integrate"), - "predict_Y": ("dfode_kit.data.integration", "predict_Y"), - "calculate_error": ("dfode_kit.data.integration", "calculate_error"), - "random_perturb": ("dfode_kit.data_operations.augment_data", "random_perturb"), - "label_npy": ("dfode_kit.data_operations.label_data", "label_npy"), - "SCALAR_FIELDS_GROUP": ("dfode_kit.data.contracts", "SCALAR_FIELDS_GROUP"), - "MECHANISM_ATTR": ("dfode_kit.data.contracts", "MECHANISM_ATTR"), - "read_scalar_field_datasets": ("dfode_kit.data.contracts", "read_scalar_field_datasets"), - "stack_scalar_field_datasets": ("dfode_kit.data.contracts", "stack_scalar_field_datasets"), - "require_h5_attr": ("dfode_kit.data.contracts", "require_h5_attr"), - "require_h5_group": ("dfode_kit.data.contracts", "require_h5_group"), -} - - -def __getattr__(name): - if name not in _ATTRIBUTE_MODULES: - raise AttributeError(f"module 'dfode_kit.data_operations' has no attribute '{name}'") - - module_name, attribute_name = _ATTRIBUTE_MODULES[name] - module = import_module(module_name) - value = getattr(module, attribute_name) - globals()[name] = value - return value diff --git a/dfode_kit/data_operations/augment_data.py b/dfode_kit/data_operations/augment_data.py deleted file mode 100644 index 74296406..00000000 --- a/dfode_kit/data_operations/augment_data.py +++ /dev/null @@ -1,149 +0,0 @@ -import numpy as np -import cantera as ct -import time -from dfode_kit.data.integration import advance_reactor -from dfode_kit.training.formation import formation_calculate - -def single_step(npstate, chem, time_step=1e-6): - gas = ct.Solution(chem) - T_old, P_old, Y_old = npstate[0], npstate[1], npstate[2:] - gas.TPY = T_old, P_old, Y_old - res_1st = [T_old, P_old] + list(gas.Y) - r = ct.IdealGasConstPressureReactor(gas, name='R1') - sim = ct.ReactorNet([r]) - - - sim.advance(time_step) - new_TPY = [gas.T, gas.P] + list(gas.Y) - res_1st += new_TPY - - return res_1st - -def random_perturb( - array: np.ndarray, - mech_path: str, - dataset: int, - heat_limit: bool, - element_limit: bool, - eq_ratio: float = 1, - frozenTem: float = 310, - alpha: float = 0.1, - gamma: float = 0.1, - cq: float = 10, - inert_idx: int = -1, - time_step: float = 1e-6, -) -> np.ndarray: - - array = array[array[:, 0] > frozenTem] - - gas = ct.Solution(mech_path) - n_species = gas.n_species - maxT = np.max(array[:,0]) - minT = np.min(array[:,0]) - maxP = np.max(array[:,1]) - minP = np.min(array[:,1]) - maxN2 = np.max(array[:,-1]) - minN2 = np.min(array[:,-1]) - - H_O_ratio_base = 2 * eq_ratio - - num = 0 - new_array = [] - while num < dataset: - if heat_limit: - qdot_ = np.zeros_like(array[:, 0]) - formation = formation_calculate(mech_path) - label_array = label(array, mech_path) - for i in range(label_array.shape[0]): - qdot_[i] = (-(formation*(label_array[i, 4+n_species:4+2*n_species]-label_array[i, 2:2+n_species])/time_step).sum()) - - for j in range(array.shape[0]): - test_tmp = np.copy(array[j]) - k = 0 - while True: - k += 1 - - test_r = np.copy(array[j]) - - test_tmp[0] = test_r[0] + (maxT - minT)*(2*np.random.rand() - 1.0)*alpha - test_tmp[1] = test_r[1] + (maxP - minP)*(2*np.random.rand() - 1.0)*alpha*20 - test_tmp[-1] = test_r[-1] + (maxN2 - minN2)*(2*np.random.rand() - 1)*alpha - for i in range(2, array.shape[1] -1): - test_tmp[i] = np.abs(test_r[i])**(1 + (2*np.random.rand() -1)*alpha) - test_tmp[2: -1] = test_tmp[2:-1]/np.sum(test_tmp[2:-1])*(1 - test_tmp[-1]) - - - if heat_limit: - label_test_tmp = single_step(test_tmp, mech_path) - label_test_tmp = np.array(label_test_tmp) - # print(formation.shape) - # print(label_test_tmp.shape) - qdot_new_ = (-(formation*(label_test_tmp[4+n_species:4+2*n_species]-label_test_tmp[2:2+n_species])/time_step).sum()) - - if element_limit: - gas.TPY = test_tmp[0], test_tmp[1], test_tmp[2:] - H_mole_fraction = gas.elemental_mole_fraction("H") - O_mole_fraction = gas.elemental_mole_fraction("O") - H_O_ratio = H_mole_fraction / O_mole_fraction - - - if heat_limit and element_limit: - condition = (minT * (1 - gamma)) <= test_tmp[0] <= (maxT * (1 + gamma)) and (H_O_ratio_base * (1 - gamma)) <= H_O_ratio <= (H_O_ratio_base * (1 + gamma)) and (qdot_new_ > 1/cq*qdot_[j] and qdot_new_ < cq*qdot_[j]) - elif heat_limit and not element_limit: - condition = (minT * (1 - gamma)) <= test_tmp[0] <= (maxT * (1 + gamma)) and (qdot_new_ > 1/cq*qdot_[j] and qdot_new_ < cq*qdot_[j]) - elif not heat_limit and element_limit: - condition = (minT * (1 - gamma)) <= test_tmp[0] <= (maxT * (1 + gamma)) and (H_O_ratio_base * (1 - gamma)) <= H_O_ratio <= (H_O_ratio_base * (1 + gamma)) - else: - condition = (minT * (1 - gamma)) <= test_tmp[0] <= (maxT * (1 + gamma)) - - # print('k', k) - if condition or k > 20: - break - - - if k <= 20: - # print('j', j) - new_array.append(test_tmp) - - num = len(new_array) - print(num) - - new_array = np.array(new_array) - new_array = new_array[np.random.choice(new_array.shape[0], size=dataset)] - unique_array = np.unique(new_array, axis=0) - print(unique_array.shape) - return unique_array - -def label( - array: np.ndarray, - mech_path: str, - time_step: float = 1e-06, -) -> np.ndarray: - - gas = ct.Solution(mech_path) - n_species = gas.n_species - - labeled_data = np.empty((array.shape[0], 2 * n_species + 4)) - - # Initialize Cantera reactor - reactor = ct.Reactor(gas, name='Reactor1', energy='off') - reactor_net = ct.ReactorNet([reactor]) - reactor_net.rtol, reactor_net.atol = 1e-6, 1e-10 - - # Start timing the simulation - start_time = time.time() - - # Process each state in the dataset - for i, state in enumerate(array): - gas = advance_reactor(gas, state, reactor, reactor_net, time_step) - labeled_data[i, :2 + n_species] = state[:2 + n_species] - labeled_data[i, 2 + n_species:] = np.array([gas.T, gas.P] + list(gas.Y)) - - # End timing of the simulation - end_time = time.time() - total_time = end_time - start_time - - # Print the total time used and the path of the saved data - print(f"Total time used: {total_time:.2f} seconds") - - return labeled_data diff --git a/dfode_kit/data_operations/contracts.py b/dfode_kit/data_operations/contracts.py deleted file mode 100644 index 5a860e3e..00000000 --- a/dfode_kit/data_operations/contracts.py +++ /dev/null @@ -1 +0,0 @@ -from dfode_kit.data.contracts import * # noqa: F401,F403 diff --git a/dfode_kit/data_operations/h5_kit.py b/dfode_kit/data_operations/h5_kit.py deleted file mode 100644 index dca8954e..00000000 --- a/dfode_kit/data_operations/h5_kit.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Compatibility shim for the canonical data I/O and integration modules.""" - -from importlib import import_module - - -__all__ = [ - "touch_h5", - "get_TPY_from_h5", - "advance_reactor", - "load_model", - "predict_Y", - "nn_integrate", - "integrate_h5", - "calculate_error", -] - -_ATTRIBUTE_MODULES = { - "touch_h5": ("dfode_kit.data.io_hdf5", "touch_h5"), - "get_TPY_from_h5": ("dfode_kit.data.io_hdf5", "get_TPY_from_h5"), - "advance_reactor": ("dfode_kit.data.integration", "advance_reactor"), - "load_model": ("dfode_kit.data.integration", "load_model"), - "predict_Y": ("dfode_kit.data.integration", "predict_Y"), - "nn_integrate": ("dfode_kit.data.integration", "nn_integrate"), - "integrate_h5": ("dfode_kit.data.integration", "integrate_h5"), - "calculate_error": ("dfode_kit.data.integration", "calculate_error"), -} - - -def __getattr__(name): - if name not in _ATTRIBUTE_MODULES: - raise AttributeError(f"module 'dfode_kit.data_operations.h5_kit' has no attribute '{name}'") - - module_name, attribute_name = _ATTRIBUTE_MODULES[name] - module = import_module(module_name) - value = getattr(module, attribute_name) - globals()[name] = value - return value diff --git a/tests/test_data_contracts.py b/tests/test_data_contracts.py index 28d3e4df..c777545f 100644 --- a/tests/test_data_contracts.py +++ b/tests/test_data_contracts.py @@ -27,12 +27,11 @@ def test_importing_data_contracts_does_not_require_cantera_or_torch(): assert contracts_module.SCALAR_FIELDS_GROUP == "scalar_fields" -def test_legacy_data_operations_contracts_path_re_exports_new_contracts_module(): - legacy_contracts = importlib.import_module("dfode_kit.data_operations.contracts") +def test_data_contracts_module_exports_expected_helpers(): new_contracts = importlib.import_module("dfode_kit.data.contracts") - assert legacy_contracts.SCALAR_FIELDS_GROUP == new_contracts.SCALAR_FIELDS_GROUP - assert legacy_contracts.stack_scalar_field_datasets is new_contracts.stack_scalar_field_datasets + assert new_contracts.SCALAR_FIELDS_GROUP == "scalar_fields" + assert callable(new_contracts.stack_scalar_field_datasets) def test_stack_scalar_field_datasets_uses_deterministic_numeric_order(tmp_path): diff --git a/tests/test_data_integration_shims.py b/tests/test_data_integration_shims.py index 52d5c17d..d4d7c6e3 100644 --- a/tests/test_data_integration_shims.py +++ b/tests/test_data_integration_shims.py @@ -1,17 +1,12 @@ import importlib -def test_h5_kit_shim_reexports_io_helpers(): - legacy_h5_kit = importlib.import_module("dfode_kit.data_operations.h5_kit") +def test_data_integration_module_exports_expected_helpers(): + integration = importlib.import_module("dfode_kit.data.integration") canonical_io = importlib.import_module("dfode_kit.data.io_hdf5") - assert legacy_h5_kit.touch_h5 is canonical_io.touch_h5 - assert legacy_h5_kit.get_TPY_from_h5 is canonical_io.get_TPY_from_h5 - - -def test_data_operations_package_reexports_canonical_io_helpers(): - legacy_data_ops = importlib.import_module("dfode_kit.data_operations") - canonical_io = importlib.import_module("dfode_kit.data.io_hdf5") - - assert legacy_data_ops.touch_h5 is canonical_io.touch_h5 - assert legacy_data_ops.get_TPY_from_h5 is canonical_io.get_TPY_from_h5 + assert callable(integration.advance_reactor) + assert callable(integration.integrate_h5) + assert callable(integration.calculate_error) + assert callable(canonical_io.touch_h5) + assert callable(canonical_io.get_TPY_from_h5) diff --git a/tests/test_data_io_hdf5.py b/tests/test_data_io_hdf5.py index dfdd8339..a55f1bd4 100644 --- a/tests/test_data_io_hdf5.py +++ b/tests/test_data_io_hdf5.py @@ -41,12 +41,9 @@ def test_get_tpy_from_h5_uses_contract_ordering_and_stacks_datasets(tmp_path, ca assert f"Number of datasets in {SCALAR_FIELDS_GROUP} group: 3" in capsys.readouterr().out -def test_legacy_package_exports_point_to_extracted_io_helpers(): - legacy_data_operations = importlib.import_module("dfode_kit.data_operations") +def test_root_package_exports_point_to_canonical_io_helpers(): root_package = importlib.import_module("dfode_kit") io_module = importlib.import_module("dfode_kit.data.io_hdf5") - assert legacy_data_operations.touch_h5 is io_module.touch_h5 - assert legacy_data_operations.get_TPY_from_h5 is io_module.get_TPY_from_h5 assert root_package.touch_h5 is io_module.touch_h5 assert root_package.get_TPY_from_h5 is io_module.get_TPY_from_h5 diff --git a/tests/test_data_module_exports.py b/tests/test_data_module_exports.py new file mode 100644 index 00000000..6945ab56 --- /dev/null +++ b/tests/test_data_module_exports.py @@ -0,0 +1,10 @@ +import importlib + + +def test_data_package_exports_augment_and_label_helpers(): + data_pkg = importlib.import_module("dfode_kit.data") + augment_module = importlib.import_module("dfode_kit.data.augment") + label_module = importlib.import_module("dfode_kit.data.label") + + assert data_pkg.random_perturb is augment_module.random_perturb + assert data_pkg.label_npy is label_module.label_npy