diff --git a/dfode_kit/cli/commands/augment.py b/dfode_kit/cli/commands/augment.py index acbc733..e59eeaf 100644 --- a/dfode_kit/cli/commands/augment.py +++ b/dfode_kit/cli/commands/augment.py @@ -26,6 +26,11 @@ def add_command_parser(subparsers): help='Requested number of augmented rows.', ) augment_parser.add_argument('--seed', type=int, help='Random seed for reproducible augmentation.') + augment_parser.add_argument( + '--time', + action='append', + help='Select time snapshots by ordered snapshot index expression, e.g. 0, -1, 0:12, or ::10. Repeatable.', + ) augment_parser.add_argument('--from-config', type=str, help='Load an augment plan/config JSON.') augment_parser.add_argument('--write-config', type=str, help='Write the resolved augment plan/config to JSON.') augment_parser.add_argument('--preview', action='store_true', help='Preview the resolved plan without executing augmentation.') @@ -76,6 +81,12 @@ def _print_human_plan(plan: dict): print(f"save: {plan['save']}") print(f"target_size: {plan['target_size']}") print(f"seed: {plan['seed']}") + print(f"time_selectors: {plan['time_selectors']}") + print(f"resolved_snapshot_count: {plan['resolved_snapshot_count']}") + if plan['resolved_snapshot_names']: + print('resolved_snapshot_names:') + for name in plan['resolved_snapshot_names']: + print(f' - {name}') print('resolved:') for key in sorted(plan['resolved']): print(f" {key}: {plan['resolved'][key]}") diff --git a/dfode_kit/cli/commands/augment_helpers.py b/dfode_kit/cli/commands/augment_helpers.py index 1822b2f..4d368d5 100644 --- a/dfode_kit/cli/commands/augment_helpers.py +++ b/dfode_kit/cli/commands/augment_helpers.py @@ -7,6 +7,11 @@ from pathlib import Path from typing import Any +import h5py +import numpy as np + +from dfode_kit.data.contracts import SCALAR_FIELDS_GROUP, ordered_group_dataset_names, require_h5_group + DEFAULT_AUGMENT_PRESET = 'random-local-combustion-v1' @@ -56,6 +61,7 @@ def resolve_augment_plan(args) -> dict[str, Any]: preset_name = args.preset or plan.get('preset', DEFAULT_AUGMENT_PRESET) target_size = args.target_size if args.target_size is not None else plan.get('target_size') seed = args.seed if args.seed is not None else plan.get('seed') + time_selectors = args.time if args.time is not None else plan.get('time_selectors') else: _validate_required_args(args, ('source', 'mech', 'preset', 'target_size')) source = args.source @@ -64,6 +70,7 @@ def resolve_augment_plan(args) -> dict[str, Any]: preset_name = args.preset target_size = args.target_size seed = args.seed + time_selectors = args.time if args.apply and not save: raise ValueError('The --save path is required when using --apply.') @@ -77,6 +84,9 @@ def resolve_augment_plan(args) -> dict[str, Any]: if not mechanism_path.is_file(): raise ValueError(f'Mechanism file does not exist: {mechanism_path}') + ordered_names = _read_ordered_snapshot_names(source_path) + resolved_snapshot_names = _resolve_time_selectors(ordered_names, time_selectors) + plan = { 'schema_version': 1, 'command_type': 'augment', @@ -87,6 +97,9 @@ def resolve_augment_plan(args) -> dict[str, Any]: 'save': str(Path(save).resolve()) if save else None, 'target_size': int(target_size), 'seed': int(seed) if seed is not None else None, + 'time_selectors': list(time_selectors) if time_selectors else None, + 'resolved_snapshot_names': resolved_snapshot_names, + 'resolved_snapshot_count': len(resolved_snapshot_names), 'config_path': str(Path(args.from_config).resolve()) if args.from_config else None, 'notes': preset.notes, 'resolved': dict(preset.resolved), @@ -95,9 +108,7 @@ def resolve_augment_plan(args) -> dict[str, Any]: def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, Any]: - import numpy as np - - from dfode_kit.data import get_TPY_from_h5, random_perturb + from dfode_kit.data import random_perturb source_path = Path(plan['source']).resolve() output_path = Path(plan['save']).resolve() @@ -105,7 +116,7 @@ def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, A if quiet: with redirect_stdout(io.StringIO()): - data = get_TPY_from_h5(source_path) + data = _load_selected_tpy_from_h5(source_path, plan['resolved_snapshot_names']) augmented = random_perturb( data, plan['mechanism'], @@ -117,7 +128,9 @@ def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, A else: print('Handling augment command') print(f'Loading data from h5 file: {source_path}') - data = get_TPY_from_h5(source_path) + if plan['time_selectors']: + print(f"Selecting snapshots with --time: {plan['time_selectors']}") + data = _load_selected_tpy_from_h5(source_path, plan['resolved_snapshot_names']) print('Data shape:', data.shape) augmented = random_perturb( data, @@ -141,6 +154,8 @@ def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, A 'returned_count': int(augmented.shape[0]), 'feature_count': int(augmented.shape[1]) if augmented.ndim == 2 else None, 'seed': plan.get('seed'), + 'resolved_snapshot_count': int(plan['resolved_snapshot_count']), + 'resolved_snapshot_names': list(plan['resolved_snapshot_names']), } @@ -156,6 +171,72 @@ def load_plan_json(path: str | Path) -> dict[str, Any]: return json.loads(input_path.read_text(encoding='utf-8')) +def _read_ordered_snapshot_names(source_path: str | Path) -> list[str]: + with h5py.File(source_path, 'r') as hdf5_file: + scalar_group = require_h5_group(hdf5_file, SCALAR_FIELDS_GROUP) + return ordered_group_dataset_names(scalar_group) + + +def _load_selected_tpy_from_h5(source_path: str | Path, dataset_names: list[str]) -> np.ndarray: + with h5py.File(source_path, 'r') as hdf5_file: + scalar_group = require_h5_group(hdf5_file, SCALAR_FIELDS_GROUP) + arrays = [scalar_group[name][:] for name in dataset_names] + if not arrays: + raise ValueError(f"No datasets selected from '{SCALAR_FIELDS_GROUP}' in {source_path}") + return np.concatenate(arrays, axis=0) + + +def _resolve_time_selectors(ordered_names: list[str], selectors: list[str] | None) -> list[str]: + if not ordered_names: + raise ValueError('No scalar-field snapshots are available in the source HDF5.') + if not selectors: + return list(ordered_names) + + selected_indices: list[int] = [] + seen = set() + for selector in selectors: + indices = _indices_from_selector(selector, len(ordered_names)) + for index in indices: + if index not in seen: + seen.add(index) + selected_indices.append(index) + + if not selected_indices: + raise ValueError('The provided --time selectors resolved to zero snapshots.') + + return [ordered_names[index] for index in selected_indices] + + +def _indices_from_selector(selector: str, length: int) -> list[int]: + text = selector.strip() + if not text: + raise ValueError('Empty --time selector is not allowed.') + + if ':' in text: + parts = text.split(':') + if len(parts) > 3: + raise ValueError(f'Invalid --time slice selector: {selector}') + values = [] + for part in parts: + if part == '': + values.append(None) + else: + values.append(int(part)) + while len(values) < 3: + values.append(None) + start, stop, step = values + if step == 0: + raise ValueError(f'Invalid --time selector with zero step: {selector}') + return list(range(length))[slice(start, stop, step)] + + index = int(text) + if index < 0: + index += length + if index < 0 or index >= length: + raise ValueError(f'--time index out of range for {length} snapshots: {selector}') + return [index] + + def _validate_required_args(args, names: tuple[str, ...]): missing = [f'--{name.replace("_", "-")}' for name in names if getattr(args, name) is None] if missing: diff --git a/docs/augment.md b/docs/augment.md index 291f2da..c34a91b 100644 --- a/docs/augment.md +++ b/docs/augment.md @@ -49,6 +49,7 @@ dfode-kit augment [options] ### Optional but high-value - `--seed` +- `--time` (repeatable snapshot index/slice selector) ## Current preset @@ -66,6 +67,7 @@ dfode-kit augment \ --mech /path/to/gri30.yaml \ --preset random-local-combustion-v1 \ --target-size 20000 \ + --time 0:12 \ --preview --json ``` @@ -90,6 +92,7 @@ dfode-kit augment \ --save /path/to/aug.npy \ --preset random-local-combustion-v1 \ --target-size 20000 \ + --time ::10 \ --seed 1234 \ --apply ``` @@ -103,6 +106,22 @@ dfode-kit augment \ --apply ``` +## Time snapshot selection + +When `--time` is omitted, augmentation uses all snapshots in the sampled HDF5 source. + +When `--time` is provided, it selects snapshots from the ordered HDF5 snapshot list by index expression. + +Supported forms include: + +- single index: `--time 0` +- negative index: `--time -1` +- slice: `--time 0:12` +- stride: `--time ::10` +- repeated selectors: `--time 0:5 --time -1` + +Selection is applied to snapshots only; all rows from each selected snapshot are included. + ## Output behavior ### `--preview --json` @@ -122,6 +141,8 @@ In `--json` mode, the command reports a structured completion record including: - requested row count - returned row count - seed +- resolved snapshot count +- resolved snapshot names ## Action rule diff --git a/docs/cli.md b/docs/cli.md index 1f98ab3..2eb78ce 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -83,6 +83,7 @@ dfode-kit augment \ --save /path/to/augmented.npy \ --preset random-local-combustion-v1 \ --target-size 20000 \ + --time 0:12 \ --apply ``` diff --git a/docs/data-workflow.md b/docs/data-workflow.md index 2821255..0efac01 100644 --- a/docs/data-workflow.md +++ b/docs/data-workflow.md @@ -73,6 +73,7 @@ dfode-kit augment \ --save /path/to/data/ch4_phi1_aug.npy \ --preset random-local-combustion-v1 \ --target-size 20000 \ + --time 0:12 \ --apply ``` diff --git a/tests/test_augment_cli.py b/tests/test_augment_cli.py index 0387f38..7ef17e0 100644 --- a/tests/test_augment_cli.py +++ b/tests/test_augment_cli.py @@ -2,7 +2,11 @@ from pathlib import Path from types import SimpleNamespace +import h5py +import numpy as np + from dfode_kit.cli.commands import augment, augment_helpers +import dfode_kit.data as data_module class DummyArgs(SimpleNamespace): @@ -11,7 +15,11 @@ class DummyArgs(SimpleNamespace): def make_args(tmp_path, **overrides): source = tmp_path / 'sample.h5' - source.write_text('stub', encoding='utf-8') + with h5py.File(source, 'w') as h5: + scalar = h5.create_group('scalar_fields') + scalar.create_dataset('0.0', data=np.array([[1.0, 2.0], [3.0, 4.0]])) + scalar.create_dataset('1.0', data=np.array([[5.0, 6.0], [7.0, 8.0]])) + scalar.create_dataset('2.0', data=np.array([[9.0, 10.0], [11.0, 12.0]])) mech = tmp_path / 'mech.yaml' mech.write_text('stub', encoding='utf-8') data = { @@ -27,6 +35,7 @@ def make_args(tmp_path, **overrides): 'preview': True, 'apply': False, 'json': True, + 'time': None, } data.update(overrides) return DummyArgs(**data) @@ -42,6 +51,7 @@ def test_resolve_augment_plan_uses_minimal_contract(tmp_path): assert plan['target_size'] == 12 assert plan['seed'] == 123 assert plan['resolved'] == {'heat_limit': False, 'element_limit': True} + assert plan['resolved_snapshot_names'] == ['0.0', '1.0', '2.0'] def test_resolve_augment_plan_from_config_allows_save_override(tmp_path): @@ -61,6 +71,7 @@ def test_resolve_augment_plan_from_config_allows_save_override(tmp_path): from_config=str(config_path), preview=True, apply=False, + time=None, ) loaded = augment_helpers.resolve_augment_plan(from_config_args) @@ -70,6 +81,55 @@ def test_resolve_augment_plan_from_config_allows_save_override(tmp_path): assert loaded['seed'] == 123 +def test_resolve_augment_plan_time_selectors_support_index_and_slice(tmp_path): + args = make_args(tmp_path, time=['0', '1:']) + + plan = augment_helpers.resolve_augment_plan(args) + + assert plan['time_selectors'] == ['0', '1:'] + assert plan['resolved_snapshot_names'] == ['0.0', '1.0', '2.0'] + assert plan['resolved_snapshot_count'] == 3 + + +def test_resolve_augment_plan_time_selector_can_stride(tmp_path): + args = make_args(tmp_path, time=['::2']) + + plan = augment_helpers.resolve_augment_plan(args) + + assert plan['resolved_snapshot_names'] == ['0.0', '2.0'] + + +def test_resolve_augment_plan_time_selector_out_of_range_fails(tmp_path): + args = make_args(tmp_path, time=['10']) + + try: + augment_helpers.resolve_augment_plan(args) + except ValueError as exc: + assert 'out of range' in str(exc) + else: + raise AssertionError('expected ValueError') + + +def test_apply_augment_plan_uses_selected_snapshots_only(tmp_path, monkeypatch): + args = make_args(tmp_path, time=['1']) + plan = augment_helpers.resolve_augment_plan(args) + captured = {} + + def fake_random_perturb(data, mech_path, dataset, heat_limit, element_limit, seed=None): + captured['data'] = data.copy() + return data + + monkeypatch.setattr(data_module, 'random_perturb', fake_random_perturb) + + result = augment_helpers.apply_augment_plan(plan, quiet=True) + + assert result['resolved_snapshot_count'] == 1 + assert captured['data'].shape == (2, 2) + assert captured['data'][0, 0] == 5.0 + assert Path(plan['save']).exists() + + + def test_handle_command_json_preview_and_apply(tmp_path, monkeypatch, capsys): args = make_args(tmp_path, preview=True, apply=True, json=True) @@ -88,6 +148,7 @@ def test_handle_command_json_preview_and_apply(tmp_path, monkeypatch, capsys): payload = json.loads(capsys.readouterr().out) assert payload['command_type'] == 'augment' assert payload['plan']['target_size'] == 12 + assert payload['plan']['resolved_snapshot_count'] == 3 assert payload['apply']['returned_count'] == 9