Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dfode_kit/cli/commands/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def add_command_parser(subparsers):
help='Requested number of augmented rows.',
)
augment_parser.add_argument('--seed', type=int, help='Random seed for reproducible augmentation.')
augment_parser.add_argument(
'--time',
action='append',
help='Select time snapshots by ordered snapshot index expression, e.g. 0, -1, 0:12, or ::10. Repeatable.',
)
augment_parser.add_argument('--from-config', type=str, help='Load an augment plan/config JSON.')
augment_parser.add_argument('--write-config', type=str, help='Write the resolved augment plan/config to JSON.')
augment_parser.add_argument('--preview', action='store_true', help='Preview the resolved plan without executing augmentation.')
Expand Down Expand Up @@ -76,6 +81,12 @@ def _print_human_plan(plan: dict):
print(f"save: {plan['save']}")
print(f"target_size: {plan['target_size']}")
print(f"seed: {plan['seed']}")
print(f"time_selectors: {plan['time_selectors']}")
print(f"resolved_snapshot_count: {plan['resolved_snapshot_count']}")
if plan['resolved_snapshot_names']:
print('resolved_snapshot_names:')
for name in plan['resolved_snapshot_names']:
print(f' - {name}')
print('resolved:')
for key in sorted(plan['resolved']):
print(f" {key}: {plan['resolved'][key]}")
Expand Down
91 changes: 86 additions & 5 deletions dfode_kit/cli/commands/augment_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from pathlib import Path
from typing import Any

import h5py
import numpy as np

from dfode_kit.data.contracts import SCALAR_FIELDS_GROUP, ordered_group_dataset_names, require_h5_group


DEFAULT_AUGMENT_PRESET = 'random-local-combustion-v1'

Expand Down Expand Up @@ -56,6 +61,7 @@ def resolve_augment_plan(args) -> dict[str, Any]:
preset_name = args.preset or plan.get('preset', DEFAULT_AUGMENT_PRESET)
target_size = args.target_size if args.target_size is not None else plan.get('target_size')
seed = args.seed if args.seed is not None else plan.get('seed')
time_selectors = args.time if args.time is not None else plan.get('time_selectors')
else:
_validate_required_args(args, ('source', 'mech', 'preset', 'target_size'))
source = args.source
Expand All @@ -64,6 +70,7 @@ def resolve_augment_plan(args) -> dict[str, Any]:
preset_name = args.preset
target_size = args.target_size
seed = args.seed
time_selectors = args.time

if args.apply and not save:
raise ValueError('The --save path is required when using --apply.')
Expand All @@ -77,6 +84,9 @@ def resolve_augment_plan(args) -> dict[str, Any]:
if not mechanism_path.is_file():
raise ValueError(f'Mechanism file does not exist: {mechanism_path}')

ordered_names = _read_ordered_snapshot_names(source_path)
resolved_snapshot_names = _resolve_time_selectors(ordered_names, time_selectors)

plan = {
'schema_version': 1,
'command_type': 'augment',
Expand All @@ -87,6 +97,9 @@ def resolve_augment_plan(args) -> dict[str, Any]:
'save': str(Path(save).resolve()) if save else None,
'target_size': int(target_size),
'seed': int(seed) if seed is not None else None,
'time_selectors': list(time_selectors) if time_selectors else None,
'resolved_snapshot_names': resolved_snapshot_names,
'resolved_snapshot_count': len(resolved_snapshot_names),
'config_path': str(Path(args.from_config).resolve()) if args.from_config else None,
'notes': preset.notes,
'resolved': dict(preset.resolved),
Expand All @@ -95,17 +108,15 @@ def resolve_augment_plan(args) -> dict[str, Any]:


def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, Any]:
import numpy as np

from dfode_kit.data import get_TPY_from_h5, random_perturb
from dfode_kit.data import random_perturb

source_path = Path(plan['source']).resolve()
output_path = Path(plan['save']).resolve()
output_path.parent.mkdir(parents=True, exist_ok=True)

if quiet:
with redirect_stdout(io.StringIO()):
data = get_TPY_from_h5(source_path)
data = _load_selected_tpy_from_h5(source_path, plan['resolved_snapshot_names'])
augmented = random_perturb(
data,
plan['mechanism'],
Expand All @@ -117,7 +128,9 @@ def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, A
else:
print('Handling augment command')
print(f'Loading data from h5 file: {source_path}')
data = get_TPY_from_h5(source_path)
if plan['time_selectors']:
print(f"Selecting snapshots with --time: {plan['time_selectors']}")
data = _load_selected_tpy_from_h5(source_path, plan['resolved_snapshot_names'])
print('Data shape:', data.shape)
augmented = random_perturb(
data,
Expand All @@ -141,6 +154,8 @@ def apply_augment_plan(plan: dict[str, Any], quiet: bool = False) -> dict[str, A
'returned_count': int(augmented.shape[0]),
'feature_count': int(augmented.shape[1]) if augmented.ndim == 2 else None,
'seed': plan.get('seed'),
'resolved_snapshot_count': int(plan['resolved_snapshot_count']),
'resolved_snapshot_names': list(plan['resolved_snapshot_names']),
}


Expand All @@ -156,6 +171,72 @@ def load_plan_json(path: str | Path) -> dict[str, Any]:
return json.loads(input_path.read_text(encoding='utf-8'))


def _read_ordered_snapshot_names(source_path: str | Path) -> list[str]:
    """Return the snapshot dataset names stored in the scalar-fields group.

    The HDF5 file is opened read-only and closed again before returning.
    Ordering is whatever ``ordered_group_dataset_names`` produces for the
    group; this function does not sort or filter the names itself.

    Args:
        source_path: Path of the HDF5 file to inspect.

    Returns:
        The dataset names of the ``SCALAR_FIELDS_GROUP`` group.
    """
    with h5py.File(source_path, 'r') as h5_handle:
        # require_h5_group is expected to fail loudly if the group is
        # missing -- TODO confirm against dfode_kit.data.contracts.
        snapshots = require_h5_group(h5_handle, SCALAR_FIELDS_GROUP)
        names = ordered_group_dataset_names(snapshots)
    return names


def _load_selected_tpy_from_h5(source_path: str | Path, dataset_names: list[str]) -> np.ndarray:
    """Load the named snapshot datasets and stack them along the first axis.

    Args:
        source_path: Path of the HDF5 file to read (opened read-only).
        dataset_names: Names of datasets inside ``SCALAR_FIELDS_GROUP`` to
            load, in the order they should appear in the result.

    Returns:
        A single array made by concatenating every selected dataset along
        axis 0.

    Raises:
        ValueError: If ``dataset_names`` is empty, so nothing was loaded.
    """
    loaded = []
    with h5py.File(source_path, 'r') as h5_handle:
        snapshots = require_h5_group(h5_handle, SCALAR_FIELDS_GROUP)
        for dataset_name in dataset_names:
            # ``[:]`` reads the whole dataset into memory so the data stays
            # usable after the file is closed.
            loaded.append(snapshots[dataset_name][:])

    if len(loaded) == 0:
        raise ValueError(f"No datasets selected from '{SCALAR_FIELDS_GROUP}' in {source_path}")
    return np.concatenate(loaded, axis=0)


def _resolve_time_selectors(ordered_names: list[str], selectors: list[str] | None) -> list[str]:
if not ordered_names:
raise ValueError('No scalar-field snapshots are available in the source HDF5.')
if not selectors:
return list(ordered_names)

selected_indices: list[int] = []
seen = set()
for selector in selectors:
indices = _indices_from_selector(selector, len(ordered_names))
for index in indices:
if index not in seen:
seen.add(index)
selected_indices.append(index)

if not selected_indices:
raise ValueError('The provided --time selectors resolved to zero snapshots.')

return [ordered_names[index] for index in selected_indices]


def _indices_from_selector(selector: str, length: int) -> list[int]:
    """Expand one ``--time`` selector into concrete snapshot positions.

    Two forms are accepted:

    * a single integer index (``'0'``, ``'-1'``); negative values count from
      the end and the result must fall inside ``[0, length)``;
    * a slice expression with one or two colons (``'0:12'``, ``'::10'``),
      interpreted with Python slice semantics, so out-of-range bounds are
      clamped and the result may be empty.

    Args:
        selector: The selector text; surrounding whitespace is ignored.
        length: Number of available snapshots.

    Returns:
        Non-negative positions into a sequence of ``length`` items.

    Raises:
        ValueError: If the selector is empty, is not made of integers, has
            too many slice components, has a zero step, or names a single
            index outside the available range.
    """
    text = selector.strip()
    if not text:
        raise ValueError('Empty --time selector is not allowed.')

    if ':' in text:
        parts = text.split(':')
        if len(parts) > 3:
            raise ValueError(f'Invalid --time slice selector: {selector}')
        try:
            # A blank component means "omitted", exactly like Python's a[::2].
            bounds = [int(part) if part.strip() else None for part in parts]
        except ValueError:
            # Replace int()'s generic message with one that names the option.
            raise ValueError(f'Invalid --time slice selector: {selector}') from None
        start, stop, step = bounds + [None] * (3 - len(bounds))
        if step == 0:
            raise ValueError(f'Invalid --time selector with zero step: {selector}')
        # Slicing the range directly avoids materialising every index first.
        return list(range(length)[slice(start, stop, step)])

    try:
        index = int(text)
    except ValueError:
        raise ValueError(
            f'Invalid --time selector (expected an integer index or slice): {selector}'
        ) from None

    position = index + length if index < 0 else index
    if not 0 <= position < length:
        raise ValueError(f'--time index out of range for {length} snapshots: {selector}')
    return [position]


def _validate_required_args(args, names: tuple[str, ...]):
missing = [f'--{name.replace("_", "-")}' for name in names if getattr(args, name) is None]
if missing:
Expand Down
21 changes: 21 additions & 0 deletions docs/augment.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dfode-kit augment [options]
### Optional but high-value

- `--seed`
- `--time` (repeatable snapshot index/slice selector)

## Current preset

Expand All @@ -66,6 +67,7 @@ dfode-kit augment \
--mech /path/to/gri30.yaml \
--preset random-local-combustion-v1 \
--target-size 20000 \
--time 0:12 \
--preview --json
```

Expand All @@ -90,6 +92,7 @@ dfode-kit augment \
--save /path/to/aug.npy \
--preset random-local-combustion-v1 \
--target-size 20000 \
--time ::10 \
--seed 1234 \
--apply
```
Expand All @@ -103,6 +106,22 @@ dfode-kit augment \
--apply
```

## Time snapshot selection

When `--time` is omitted, augmentation uses all snapshots in the sampled HDF5 source.

When `--time` is provided, it selects snapshots from the ordered HDF5 snapshot list by index expression.

Supported forms include:

- single index: `--time 0`
- negative index: `--time -1`
- slice: `--time 0:12` (Python slice semantics: selects snapshot indices 0 through 11; the stop index is exclusive)
- stride: `--time ::10`
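- repeated selectors: `--time 0:5 --time -1` (results are combined in the order given; a snapshot matched by more than one selector is included only once)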

Selection is applied to snapshots only; all rows from each selected snapshot are included.

## Output behavior

### `--preview --json`
Expand All @@ -122,6 +141,8 @@ In `--json` mode, the command reports a structured completion record including:
- requested row count
- returned row count
- seed
- resolved snapshot count
- resolved snapshot names

## Action rule

Expand Down
1 change: 1 addition & 0 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ dfode-kit augment \
--save /path/to/augmented.npy \
--preset random-local-combustion-v1 \
--target-size 20000 \
--time 0:12 \
--apply
```

Expand Down
1 change: 1 addition & 0 deletions docs/data-workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ dfode-kit augment \
--save /path/to/data/ch4_phi1_aug.npy \
--preset random-local-combustion-v1 \
--target-size 20000 \
--time 0:12 \
--apply
```

Expand Down
63 changes: 62 additions & 1 deletion tests/test_augment_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from pathlib import Path
from types import SimpleNamespace

import h5py
import numpy as np

from dfode_kit.cli.commands import augment, augment_helpers
import dfode_kit.data as data_module


class DummyArgs(SimpleNamespace):
Expand All @@ -11,7 +15,11 @@ class DummyArgs(SimpleNamespace):

def make_args(tmp_path, **overrides):
source = tmp_path / 'sample.h5'
source.write_text('stub', encoding='utf-8')
with h5py.File(source, 'w') as h5:
scalar = h5.create_group('scalar_fields')
scalar.create_dataset('0.0', data=np.array([[1.0, 2.0], [3.0, 4.0]]))
scalar.create_dataset('1.0', data=np.array([[5.0, 6.0], [7.0, 8.0]]))
scalar.create_dataset('2.0', data=np.array([[9.0, 10.0], [11.0, 12.0]]))
mech = tmp_path / 'mech.yaml'
mech.write_text('stub', encoding='utf-8')
data = {
Expand All @@ -27,6 +35,7 @@ def make_args(tmp_path, **overrides):
'preview': True,
'apply': False,
'json': True,
'time': None,
}
data.update(overrides)
return DummyArgs(**data)
Expand All @@ -42,6 +51,7 @@ def test_resolve_augment_plan_uses_minimal_contract(tmp_path):
assert plan['target_size'] == 12
assert plan['seed'] == 123
assert plan['resolved'] == {'heat_limit': False, 'element_limit': True}
assert plan['resolved_snapshot_names'] == ['0.0', '1.0', '2.0']


def test_resolve_augment_plan_from_config_allows_save_override(tmp_path):
Expand All @@ -61,6 +71,7 @@ def test_resolve_augment_plan_from_config_allows_save_override(tmp_path):
from_config=str(config_path),
preview=True,
apply=False,
time=None,
)

loaded = augment_helpers.resolve_augment_plan(from_config_args)
Expand All @@ -70,6 +81,55 @@ def test_resolve_augment_plan_from_config_allows_save_override(tmp_path):
assert loaded['seed'] == 123


def test_resolve_augment_plan_time_selectors_support_index_and_slice(tmp_path):
    """A single index plus an open-ended slice together select every snapshot."""
    selector_args = make_args(tmp_path, time=['0', '1:'])

    resolved_plan = augment_helpers.resolve_augment_plan(selector_args)

    # The raw selectors are echoed back unchanged in the plan.
    assert resolved_plan['time_selectors'] == ['0', '1:']
    # '0' picks the first snapshot and '1:' picks the remaining two.
    assert resolved_plan['resolved_snapshot_names'] == ['0.0', '1.0', '2.0']
    assert resolved_plan['resolved_snapshot_count'] == 3


def test_resolve_augment_plan_time_selector_can_stride(tmp_path):
    """A ``::2`` stride keeps every other snapshot, starting with the first."""
    strided_args = make_args(tmp_path, time=['::2'])

    resolved_plan = augment_helpers.resolve_augment_plan(strided_args)
    selected_names = resolved_plan['resolved_snapshot_names']

    assert selected_names == ['0.0', '2.0']


def test_resolve_augment_plan_time_selector_out_of_range_fails(tmp_path):
    """An index beyond the available snapshots is rejected with a clear message."""
    out_of_range_args = make_args(tmp_path, time=['10'])

    raised = None
    try:
        augment_helpers.resolve_augment_plan(out_of_range_args)
    except ValueError as exc:
        raised = exc

    # Resolving must not succeed silently.
    if raised is None:
        raise AssertionError('expected ValueError')
    assert 'out of range' in str(raised)


def test_apply_augment_plan_uses_selected_snapshots_only(tmp_path, monkeypatch):
    """Applying a plan feeds only the ``--time``-selected snapshot rows to the perturber."""
    plan = augment_helpers.resolve_augment_plan(make_args(tmp_path, time=['1']))
    seen = {}

    def fake_random_perturb(data, mech_path, dataset, heat_limit, element_limit, seed=None):
        # Record what the augmentation step received, then pass it through
        # untouched so no real perturbation runs.
        seen['data'] = data.copy()
        return data

    monkeypatch.setattr(data_module, 'random_perturb', fake_random_perturb)

    outcome = augment_helpers.apply_augment_plan(plan, quiet=True)

    assert outcome['resolved_snapshot_count'] == 1
    received = seen['data']
    # Only the second fixture snapshot (two rows, two columns, first value
    # 5.0) should have been loaded.
    assert received.shape == (2, 2)
    assert received[0, 0] == 5.0
    assert Path(plan['save']).exists()



def test_handle_command_json_preview_and_apply(tmp_path, monkeypatch, capsys):
args = make_args(tmp_path, preview=True, apply=True, json=True)

Expand All @@ -88,6 +148,7 @@ def test_handle_command_json_preview_and_apply(tmp_path, monkeypatch, capsys):
payload = json.loads(capsys.readouterr().out)
assert payload['command_type'] == 'augment'
assert payload['plan']['target_size'] == 12
assert payload['plan']['resolved_snapshot_count'] == 3
assert payload['apply']['returned_count'] == 9


Expand Down
Loading