From 001dcc518e3b682014da1d20de937060155c224e Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Tue, 16 Jun 2026 11:42:38 -0500 Subject: [PATCH 1/8] FIX: CI by adding duckdb --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 01e0dc3..f7c5cd1 100644 --- a/environment.yml +++ b/environment.yml @@ -27,6 +27,7 @@ dependencies: - pyproj - contextily - pyarrow + - duckdb - pydata-sphinx-theme - pydantic - pip: From 1e682a6c99c4ac56693eee78200bfff3ddb084a4 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Fri, 26 Jun 2026 00:46:21 -0500 Subject: [PATCH 2/8] ENH: tracking module; schema --- src/adapt/configuration/schemas/internal.py | 7 + src/adapt/configuration/schemas/param.py | 39 + .../schemas/radar_catalog_schema.sql | 12 +- src/adapt/consumers/live/dashboard.py | 2 +- src/adapt/execution/nodes/tracking.py | 7 + src/adapt/modules/tracking/README.md | 115 ++- src/adapt/modules/tracking/config.py | 7 + src/adapt/modules/tracking/events.py | 205 +++++ src/adapt/modules/tracking/graph.py | 173 ++++ src/adapt/modules/tracking/identity.py | 67 ++ .../modules/tracking/matching/__init__.py | 0 .../modules/tracking/matching/hungarian.py | 91 ++ .../modules/tracking/matching/overlap.py | 66 ++ src/adapt/modules/tracking/models.py | 67 ++ src/adapt/modules/tracking/module.py | 862 ++++++++---------- src/adapt/modules/tracking/motion.py | 80 ++ src/adapt/modules/tracking/projection.py | 30 + src/adapt/persistence/track_store.py | 31 +- tests/persistence/test_track_store.py | 11 +- 19 files changed, 1357 insertions(+), 515 deletions(-) create mode 100644 src/adapt/modules/tracking/events.py create mode 100644 src/adapt/modules/tracking/graph.py create mode 100644 src/adapt/modules/tracking/identity.py create mode 100644 src/adapt/modules/tracking/matching/__init__.py create mode 100644 src/adapt/modules/tracking/matching/hungarian.py create mode 100644 src/adapt/modules/tracking/matching/overlap.py create mode 100644 src/adapt/modules/tracking/models.py create mode 100644 src/adapt/modules/tracking/motion.py create mode 100644 src/adapt/modules/tracking/projection.py diff --git a/src/adapt/configuration/schemas/internal.py b/src/adapt/configuration/schemas/internal.py index 98c9a48..c902830 100644 --- a/src/adapt/configuration/schemas/internal.py +++ b/src/adapt/configuration/schemas/internal.py @@ -140,6 +140,13 @@ class InternalCellUidConfig(AdaptBaseModel): core_reflectivity_threshold: float = Field(default=40.0, ge=0.0) max_gap_minutes: float = Field(default=10.0, gt=0.0) expected_speed_ms: float = Field(default=30.0, gt=0.0) + max_tracking_gap_minutes: float = Field(default=20.0, gt=0.0) + projection_horizon_minutes: float = Field(default=20.0, gt=0.0) + projection_interval_minutes: float = Field(default=1.0, gt=0.0) + max_speed_ms: float = Field(default=40.0, gt=0.0) + max_speed_multiplier: float = Field(default=3.0, gt=0.0) + overlap_match_threshold: float = Field(default=0.3, ge=0.0, le=1.0) + heading_change_penalty_weight: float = Field(default=0.0, ge=0.0) cell_uid: InternalCellUidConfig diff --git a/src/adapt/configuration/schemas/param.py b/src/adapt/configuration/schemas/param.py index f4ef1a7..0dc4c96 100644 --- a/src/adapt/configuration/schemas/param.py +++ b/src/adapt/configuration/schemas/param.py @@ -227,6 +227,45 @@ class CellUidConfig(AdaptBaseModel): gt=0.0, description="Maximum expected cell propagation speed (m/s); scales D_pos with dt", ) + max_tracking_gap_minutes: float = Field( + 20.0, + gt=0.0, + description="Hard limit: scan gaps above this terminate all tracks and restart " + "(no matching attempted across the gap)", + ) + projection_horizon_minutes: float = Field( + 20.0, + gt=0.0, + description="How far ahead (minutes) registration-based projected hulls are consumed", + ) + projection_interval_minutes: float = Field( + 1.0, + gt=0.0, + description="Spacing (minutes) between registration projected hulls", + ) + max_speed_ms: float = Field( + 40.0, + gt=0.0, + description="Hard physical cap (m/s); candidate pairs above this are rejected pre-matching", + ) + max_speed_multiplier: float = Field( + 3.0, + gt=0.0, + description="Hard acceleration cap: reject if candidate speed exceeds this times the " + "track's previous speed", + ) + overlap_match_threshold: float = Field( + 0.3, + ge=0.0, + le=1.0, + description="Min projected-hull overlap for a deterministic unique-overlap direct match " + "(uniqueness dominates the threshold)", + ) + heading_change_penalty_weight: float = Field( + 0.0, + ge=0.0, + description="Optional cost penalty per radian of heading change (0 = diagnostic only)", + ) cell_uid: CellUidConfig = Field(default_factory=CellUidConfig) # type: ignore[arg-type] diff --git a/src/adapt/configuration/schemas/radar_catalog_schema.sql b/src/adapt/configuration/schemas/radar_catalog_schema.sql index bd13c70..82301cb 100644 --- a/src/adapt/configuration/schemas/radar_catalog_schema.sql +++ b/src/adapt/configuration/schemas/radar_catalog_schema.sql @@ -176,7 +176,17 @@ CREATE TABLE IF NOT EXISTS cell_events ( target_cell_label INTEGER, cost REAL, is_dominant INTEGER NOT NULL DEFAULT 0, - event_group_id TEXT NOT NULL + event_group_id TEXT NOT NULL, + -- Per-accepted-match diagnostics (NULL for INITIATION / TERMINATION) + candidate_overlap REAL, + candidate_iou REAL, + candidate_centroid_distance_m REAL, + candidate_speed_ms REAL, + candidate_heading_change_deg REAL, + candidate_area_ratio REAL, + candidate_reflectivity_difference REAL, + candidate_final_cost REAL, + match_method TEXT ); CREATE INDEX IF NOT EXISTS idx_ce_source ON cell_events(run_id, source_cell_uid); diff --git a/src/adapt/consumers/live/dashboard.py b/src/adapt/consumers/live/dashboard.py index 6fc5c36..000bd60 100644 --- a/src/adapt/consumers/live/dashboard.py +++ b/src/adapt/consumers/live/dashboard.py @@ -2089,7 +2089,7 @@ def _draw_scan(self, ds, fig, ax=None): marker="*", markersize=8, linestyle="None", - label="qurrent centroid", + label="Current Centroid", ), ] diff --git a/src/adapt/execution/nodes/tracking.py b/src/adapt/execution/nodes/tracking.py index 5247875..40fc4ea 100644 --- a/src/adapt/execution/nodes/tracking.py +++ b/src/adapt/execution/nodes/tracking.py @@ -91,6 +91,13 @@ def build_config(cls, cfg) -> TrackingConfig: labels_var=cfg.global_.var_names.cell_labels, max_gap_minutes=cfg.tracker.max_gap_minutes, expected_speed_ms=cfg.tracker.expected_speed_ms, + max_tracking_gap_minutes=cfg.tracker.max_tracking_gap_minutes, + projection_horizon_minutes=cfg.tracker.projection_horizon_minutes, + projection_interval_minutes=cfg.tracker.projection_interval_minutes, + max_speed_ms=cfg.tracker.max_speed_ms, + max_speed_multiplier=cfg.tracker.max_speed_multiplier, + overlap_match_threshold=cfg.tracker.overlap_match_threshold, + heading_change_penalty_weight=cfg.tracker.heading_change_penalty_weight, ) def __init__(self) -> None: diff --git a/src/adapt/modules/tracking/README.md b/src/adapt/modules/tracking/README.md index 922564a..4825468 100644 --- a/src/adapt/modules/tracking/README.md +++ b/src/adapt/modules/tracking/README.md @@ -18,21 +18,45 @@ The tracking module performs tracking-only association of segmented radar cells ## Architecture +The scientific layer is decomposed into focused, single-responsibility files under +`adapt/modules/tracking/`. `RadarCellTracker` (in `module.py`) is orchestration only; +it delegates to: + +| File | Responsibility | +|------|----------------| +| `module.py` | `RadarCellTracker` — per-scan flow, state, delegation | +| `graph.py` | `TrackingGraph` (the only `networkx` home) | +| `projection.py` | `select_registration_labels` — minute-resolution registration hull | +| `matching/overlap.py` | `OverlapMatcher` — deterministic unique-overlap matching | +| `matching/hungarian.py` | `MatchingEngine` — cost matrix + Hungarian (the only `scipy` home) | +| `motion.py` | `MotionValidator` + heading-change helpers | +| `models.py` | `MatchMethod` / `TrackingError` enums, `MatchDiagnostics`, `TrackMotionState` | +| `identity.py` | stable `cell_uid` generation | +| `events.py` | lineage event-row builders + diagnostics assembly | +| `config.py` | frozen `TrackingConfig` | + +The node layer (`adapt/execution/nodes/tracking.py`) keeps the `BaseModule` wrapper, +`registry.register`, `build_config`, contracts, and persistence specs — no engine +imports ever live under `modules/tracking/`. + +### Matching hierarchy + +Each frame pair is resolved in this order (registration-driven, optimisation last): + ``` -┌──────────────────────────┐ -│ RadarCellTracker │ Scientific Implementation -│ - Tracking graph │ -│ - Cost function │ -│ - Hungarian assignment │ -│ - Event emission │ -└──────────────────────────┘ - ↓ -┌──────────────────────────┐ -│ TrackingModule │ Pipeline Integration -│ - BaseModule wrapper │ -│ - Context management │ -│ - Contract validation │ -└──────────────────────────┘ +scan-gap classification (physical time; hard reset on excess gap / non-monotonic time) + ↓ +registration projected hulls (minute nearest the real gap) + ↓ +hard physical-motion rejection (speed / acceleration caps — before matching) + ↓ +deterministic unique-overlap matching (skips Hungarian) + ↓ +Hungarian assignment (residual ambiguity only; soft heading-consistency penalty) + ↓ +split / merge detection + ↓ +initiation / termination ``` ## Usage @@ -138,25 +162,38 @@ Normalized adjacency pairs in track identity space: ### Cost Function -The matching cost combines multiple terms: +The Hungarian matching cost (`matching/hungarian.py`) combines four terms: ``` -cost = 0.4 * D_pos + 0.3 * (1 - IoU) + 0.15 * |log(A2/A1)| + 0.1 * |Z2 - Z1| + 0.05 * core_penalty +cost = 0.4 * D_pos + 0.3 * (1 - IoU) + 0.15 * |log(A2/A1)| + 0.1 * |Z2 - Z1| / 50 ``` -Where: -- `D_pos`: Normalized centroid distance -- `IoU`: Intersection-over-union of masks -- `A2/A1`: Area ratio -- `Z2 - Z1`: Reflectivity difference +Where `D_pos` is the centroid distance normalised by `expected_speed_ms * dt`, +`IoU` is the projected-hull/current-cell overlap, `A2/A1` is the area ratio, and +`Z2 - Z1` is the mean-reflectivity difference. When `heading_change_penalty_weight` +> 0, a soft `weight * heading_change` (radians) term is added for tracks with an +established velocity (crossing-track prevention). ### Assignment -Uses the Hungarian algorithm (`scipy.optimize.linear_sum_assignment`) to find optimal cell-to-cell assignments while minimizing total cost. +Hungarian assignment (`scipy.optimize.linear_sum_assignment`) is applied **only to +residual ambiguity** — pairs left after deterministic unique-overlap matching and +hard physical-motion rejection. See the matching hierarchy above. ### Search Region -Candidates are filtered by non-zero overlap with the projected previous-cell labels in the current scan coordinates (`cell_projections[0]`). +Candidates are filtered by non-zero overlap with the registration projected hull — +the minute-resolution `registration_minutes` frame nearest the real scan gap, +falling back to `cell_projections[0]` when minute frames are absent +(`projection.select_registration_labels`). + +### Diagnostics + +Every accepted match records a `MatchDiagnostics` row persisted to `cell_events`: +`candidate_overlap`, `candidate_iou`, `candidate_centroid_distance_m`, +`candidate_speed_ms`, `candidate_heading_change_deg`, `candidate_area_ratio`, +`candidate_reflectivity_difference`, `candidate_final_cost`, and `match_method` +(`OVERLAP` / `HUNGARIAN` / `SPLIT` / `MERGE`). ## Testing @@ -192,13 +229,41 @@ Typical performance for 50 cells per scan: - **Memory usage**: ~10 MB for 100 scans - **Graph size**: Linear with total cell-observations +## Tracking-Assisted Segmentation Correction (design only — not implemented) + +Projected hulls carry motion-coherence information that can flag segmentation +errors. This is a **designed extension point**, intentionally left unimplemented +(YAGNI) until a concrete use case exists. No hooks or dead code are present today. + +**Motivating cases** + +- *Over-split.* Segmentation fragments one storm into several cells, but a single + continuing parent's projected hull covers all fragments coherently → the + fragments should be re-merged into one tracked object. +- *Over-merge.* Segmentation fuses two storms into one cell, but two distinct + parents project into separable sub-regions → the cell should be re-split. + +**Where it would hook in.** A correction stage would sit between *registration +projected hulls* and *deterministic unique-overlap matching* in the hierarchy +above — i.e. it adjusts the current-frame label field *before* matching, so all +downstream stages operate on the corrected segmentation. It would consume the +same `select_registration_labels` hull plus the per-parent overlap structure +already computed by `OverlapMatcher`. + +**Proposed shape when built.** A registered, swappable +`SegmentationCorrector` strategy (Open/Closed, like detection/tracking backends) +implemented under `modules/`, communicating via a `contracts/` interface, returning +a corrected label array + provenance describing each merge/split it applied. It +must be deterministic and produce no side effects. The correction provenance would +travel as additional diagnostic rows so every change is traceable. + ## Future Enhancements Potential improvements identified during development: 1. **Advanced Split/Merge Logic**: Implement temporary merge identity restoration -2. **Multi-Step Prediction**: Use multiple projection steps for better matching -3. **Track Smoothing**: Apply Kalman filtering to motion vectors +2. **Track Smoothing**: Apply Kalman filtering to motion vectors +3. **Motion-coherent segmentation correction**: see the design section above 4. **Parallel Processing**: Process multiple files concurrently 5. **Persistence**: Save/load tracking graph for resumable processing diff --git a/src/adapt/modules/tracking/config.py b/src/adapt/modules/tracking/config.py index 3fbe6be..5c0bc1d 100644 --- a/src/adapt/modules/tracking/config.py +++ b/src/adapt/modules/tracking/config.py @@ -27,3 +27,10 @@ class TrackingConfig(BaseModel): labels_var: str max_gap_minutes: float expected_speed_ms: float + max_tracking_gap_minutes: float + projection_horizon_minutes: float + projection_interval_minutes: float + max_speed_ms: float + max_speed_multiplier: float + overlap_match_threshold: float + heading_change_penalty_weight: float diff --git a/src/adapt/modules/tracking/events.py b/src/adapt/modules/tracking/events.py new file mode 100644 index 0000000..5035994 --- /dev/null +++ b/src/adapt/modules/tracking/events.py @@ -0,0 +1,205 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Lineage event-row builders. + +Pure functions that turn graph nodes into the explicit event dicts persisted as +``cell_events`` (CONTINUE / SPLIT / MERGE / INITIATION / TERMINATION). Each takes +the graph and the track_index→(cell_uid, signature) identity map; no tracker state. +""" + +import pandas as pd + +from adapt.modules.tracking.graph import TrackingGraph +from adapt.modules.tracking.models import MatchDiagnostics, MatchMethod +from adapt.utils.time import normalize_time_scalar + +__all__ = [ + "DIAGNOSTIC_COLUMNS", + "EVENT_COLUMNS", + "build_cell_events_dataframe", + "event_continue", + "event_initiation", + "event_merge", + "event_split", + "event_termination", +] + +# Per-accepted-match explainability columns (null for INITIATION / TERMINATION). +DIAGNOSTIC_COLUMNS = [ + "candidate_overlap", + "candidate_iou", + "candidate_centroid_distance_m", + "candidate_speed_ms", + "candidate_heading_change_deg", + "candidate_area_ratio", + "candidate_reflectivity_difference", + "candidate_final_cost", + "match_method", +] + +EVENT_COLUMNS = [ + "time", + "event_type", + "source_cell_uid", + "target_cell_uid", + "source_cell_label", + "target_cell_label", + "cost", + "is_dominant", + "event_group_id", + *DIAGNOSTIC_COLUMNS, +] + +Identity = dict[int, tuple[str, str]] + + +def _diagnostic_fields(diag: MatchDiagnostics | None) -> dict: + """Map a MatchDiagnostics into the ``candidate_*`` / ``match_method`` columns.""" + if diag is None: + return dict.fromkeys(DIAGNOSTIC_COLUMNS) + return { + "candidate_overlap": diag.overlap, + "candidate_iou": diag.iou, + "candidate_centroid_distance_m": diag.centroid_distance_m, + "candidate_speed_ms": diag.speed_ms, + "candidate_heading_change_deg": diag.heading_change_deg, + "candidate_area_ratio": diag.area_ratio, + "candidate_reflectivity_difference": diag.reflectivity_difference, + "candidate_final_cost": diag.final_cost, + "match_method": diag.match_method, + } + + +def _uid(graph: TrackingGraph, identity: Identity, node_id: int) -> str: + track_index = int(graph.get_node_attr(node_id, "track_index")) + if track_index not in identity: + raise ValueError(f"Missing cell identity for track_index={track_index}") + return identity[track_index][0] + + +def _time_key(time_val) -> str: + """Stable ISO8601 time key for event grouping.""" + tv = normalize_time_scalar(time_val) + return pd.Timestamp(tv).isoformat() + + +def event_continue( + graph: TrackingGraph, + identity: Identity, + time, + prev_node_id: int, + curr_node_id: int, + cost: float, + diagnostics: MatchDiagnostics | None = None, +) -> dict: + target_cell_uid = _uid(graph, identity, curr_node_id) + return { + "time": time, + "event_type": "CONTINUE", + "source_cell_uid": _uid(graph, identity, prev_node_id), + "target_cell_uid": target_cell_uid, + "source_cell_label": int(graph.get_node_attr(prev_node_id, "cell_id")), + "target_cell_label": int(graph.get_node_attr(curr_node_id, "cell_id")), + "cost": float(cost), + "is_dominant": True, + "event_group_id": f"{_time_key(time)}:CONTINUE:{target_cell_uid}", + **_diagnostic_fields(diagnostics), + } + + +def event_split( + graph: TrackingGraph, + identity: Identity, + time, + parent_node_id: int, + child_node_id: int, +) -> dict: + parent_uid = _uid(graph, identity, parent_node_id) + return { + "time": time, + "event_type": "SPLIT", + "source_cell_uid": parent_uid, + "target_cell_uid": _uid(graph, identity, child_node_id), + "source_cell_label": int(graph.get_node_attr(parent_node_id, "cell_id")), + "target_cell_label": int(graph.get_node_attr(child_node_id, "cell_id")), + "cost": None, + "is_dominant": False, + "event_group_id": f"{_time_key(time)}:SPLIT:{parent_uid}", + **_diagnostic_fields(MatchDiagnostics(match_method=MatchMethod.SPLIT)), + } + + +def event_merge( + graph: TrackingGraph, + identity: Identity, + time, + source_node_id: int, + target_node_id: int, +) -> dict: + target_uid = _uid(graph, identity, target_node_id) + return { + "time": time, + "event_type": "MERGE", + "source_cell_uid": _uid(graph, identity, source_node_id), + "target_cell_uid": target_uid, + "source_cell_label": int(graph.get_node_attr(source_node_id, "cell_id")), + "target_cell_label": int(graph.get_node_attr(target_node_id, "cell_id")), + "cost": None, + "is_dominant": False, + "event_group_id": f"{_time_key(time)}:MERGE:{target_uid}", + **_diagnostic_fields(MatchDiagnostics(match_method=MatchMethod.MERGE)), + } + + +def event_initiation(graph: TrackingGraph, identity: Identity, time, node_id: int) -> dict: + target_uid = _uid(graph, identity, node_id) + return { + "time": time, + "event_type": "INITIATION", + "source_cell_uid": None, + "target_cell_uid": target_uid, + "source_cell_label": None, + "target_cell_label": int(graph.get_node_attr(node_id, "cell_id")), + "cost": None, + "is_dominant": False, + "event_group_id": f"{_time_key(time)}:INITIATION:{target_uid}", + } + + +def event_termination( + graph: TrackingGraph, + identity: Identity, + time, + source_node_id: int, + target_node_id: int | None, +) -> dict: + source_uid = _uid(graph, identity, source_node_id) + target_uid = _uid(graph, identity, target_node_id) if target_node_id is not None else None + return { + "time": time, + "event_type": "TERMINATION", + "source_cell_uid": source_uid, + "target_cell_uid": target_uid, + "source_cell_label": int(graph.get_node_attr(source_node_id, "cell_id")), + "target_cell_label": ( + int(graph.get_node_attr(target_node_id, "cell_id")) + if target_node_id is not None + else None + ), + "cost": None, + "is_dominant": False, + "event_group_id": f"{_time_key(time)}:TERMINATION:{source_uid}", + } + + +def build_cell_events_dataframe(events: list[dict]) -> pd.DataFrame: + if not events: + return pd.DataFrame(columns=EVENT_COLUMNS) + df = pd.DataFrame(events) + for col in EVENT_COLUMNS: + if col not in df.columns: + df[col] = None + df = df[EVENT_COLUMNS] + df["time"] = df["time"].apply(lambda t: pd.Timestamp(normalize_time_scalar(t))) + return df diff --git a/src/adapt/modules/tracking/graph.py b/src/adapt/modules/tracking/graph.py new file mode 100644 index 0000000..7861015 --- /dev/null +++ b/src/adapt/modules/tracking/graph.py @@ -0,0 +1,173 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Directed graph storing cell tracking history and lineage. + +Nodes are cell observations at a time; edges are temporal relationships +(CONTINUE, SPLIT, MERGE). This is the only home for ``networkx`` in the package. +""" + +import networkx as nx + +__all__ = ["TrackingGraph"] + + +class TrackingGraph: + """Directed graph storing cell tracking history and lineage. + + Nodes represent cell observations at specific times. + Edges represent temporal relationships (CONTINUE, SPLIT, MERGE). + + Node attributes: + - node_id: unique identifier (int) + - time: observation timestamp + - cell_id: cell label from segmentation + - track_index: tracking index this cell belongs to (starts at 1; 0 = background sentinel) + - area: cell area in km² + - centroid_x, centroid_y: cell center coordinates + - mean_reflectivity: average dBZ + - max_reflectivity: peak dBZ + - core_area: area with Z > threshold dBZ + + Edge attributes: + - edge_type: "CONTINUE", "SPLIT", "MERGE" + - cost: assignment cost (for diagnostics) + """ + + def __init__(self): + """Initialize empty tracking graph.""" + self.graph = nx.DiGraph() + self._node_counter = 0 + self._track_counter = 0 # Will yield 1, 2, 3, ... (0 is background sentinel) + + def add_observation( + self, + time, + cell_id: int, + track_index: int, + area: float, + centroid_x: float, + centroid_y: float, + mean_reflectivity: float, + max_reflectivity: float, + core_area: float, + cell_uid: str, + track_signature: str, + ) -> int: + node_id = self._node_counter + self._node_counter += 1 + + self.graph.add_node( + node_id, + time=time, + cell_id=cell_id, + track_index=track_index, + area=area, + centroid_x=centroid_x, + centroid_y=centroid_y, + mean_reflectivity=mean_reflectivity, + max_reflectivity=max_reflectivity, + core_area=core_area, + cell_uid=cell_uid, + track_signature=track_signature, + ) + return node_id + + def add_edge(self, from_node: int, to_node: int, edge_type: str, cost: float = 0.0): + """Add a temporal relationship edge. + + Parameters + ---------- + from_node : int + Source node ID (earlier time) + to_node : int + Target node ID (later time) + edge_type : str + Edge type: "CONTINUE", "SPLIT", or "MERGE" + cost : float, optional + Assignment cost for diagnostics (default: 0.0) + """ + self.graph.add_edge(from_node, to_node, edge_type=edge_type, cost=cost) + + def get_new_track_index(self) -> int: + """Allocate a new unique track index (starts at 1; 0 is background sentinel).""" + self._track_counter += 1 + return self._track_counter + + def get_node_attr(self, node_id: int, attr: str): + """Get a node attribute value. + + Parameters + ---------- + node_id : int + Node identifier + attr : str + Attribute name + + Returns + ------- + Any + Attribute value, or None if not present + """ + return self.graph.nodes[node_id].get(attr) + + def get_nodes_at_time(self, time) -> list[int]: + """Get all node IDs for a given timestamp. + + Parameters + ---------- + time : datetime-like + Timestamp to query + + Returns + ------- + List[int] + List of node IDs at this time + """ + return [n for n, d in self.graph.nodes(data=True) if d.get("time") == time] + + def get_track_nodes(self, track_index: int) -> list[int]: + """Get all nodes belonging to a track, sorted by time.""" + nodes = [ + (n, d["time"]) + for n, d in self.graph.nodes(data=True) + if d.get("track_index") == track_index + ] + nodes.sort(key=lambda x: x[1]) + return [n for n, _ in nodes] + + def get_predecessors(self, node_id: int) -> list[tuple[int, str]]: + """Get predecessor nodes with their edge types. + + Parameters + ---------- + node_id : int + Node identifier + + Returns + ------- + List[Tuple[int, str]] + List of (predecessor_node_id, edge_type) tuples + """ + return [ + (pred, self.graph.edges[pred, node_id]["edge_type"]) + for pred in self.graph.predecessors(node_id) + ] + + def get_successors(self, node_id: int) -> list[tuple[int, str]]: + """Get successor nodes with their edge types. + + Parameters + ---------- + node_id : int + Node identifier + + Returns + ------- + List[Tuple[int, str]] + List of (successor_node_id, edge_type) tuples + """ + return [ + (succ, self.graph.edges[node_id, succ]["edge_type"]) + for succ in self.graph.successors(node_id) + ] diff --git a/src/adapt/modules/tracking/identity.py b/src/adapt/modules/tracking/identity.py new file mode 100644 index 0000000..61c2139 --- /dev/null +++ b/src/adapt/modules/tracking/identity.py @@ -0,0 +1,67 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Stable cell-uid generation. + +A cell's birth state (quantized time, location, intensity, area) is hashed into a +short base36 token. Quantization makes the token robust to small input variation; +the hash makes it stable and reproducible. Pure functions only — no I/O, no state. +""" + +import hashlib +import string + +BASE36_UPPER = string.digits + string.ascii_uppercase + + +def _quantize(value: float, step: float) -> int: + # this is for creating stable hashes that are robust to small variations in the input values + if step <= 0: + raise ValueError("step must be positive") + return int(round(value / step)) + + +def _encode_base36(value: int) -> str: + if value < 0: + raise ValueError("value must be non-negative") + if value == 0: + return "0" + chars: list[str] = [] + while value: + value, remainder = divmod(value, 36) + chars.append(BASE36_UPPER[remainder]) + return "".join(reversed(chars)) + + +def _encode_base36_fixed(value: int, width: int) -> str: + token = _encode_base36(value) + return token.rjust(width, "0") + + +def _track_signature_from_birth( + scan_start_time_epoch_s: float, + centroid_lat_deg: float, + centroid_lon_deg: float, + max_dbz: float, + max_zdr: float, + area40_km2: float, + *, + time_step_s: int, + latlon_step_deg: float, + area_step_km2: float, + signature_version: str = "v1", +) -> str: + tq = _quantize(scan_start_time_epoch_s, time_step_s) + latq = _quantize(centroid_lat_deg, latlon_step_deg) + lonq = _quantize(centroid_lon_deg, latlon_step_deg) + dbzq = int(round(max_dbz)) + zdrq = int(round(max_zdr * 10.0)) + a40q = _quantize(area40_km2, area_step_km2) + return f"{signature_version}|{tq}|{latq}|{lonq}|{dbzq}|{zdrq}|{a40q}" + + +def _cell_uid_from_signature(signature: str, width: int) -> str: + digest = hashlib.blake2b(signature.encode("utf-8"), digest_size=8).digest() + value64 = int.from_bytes(digest, byteorder="big", signed=False) + modulus = 36**width + return _encode_base36_fixed(value64 % modulus, width=width) diff --git a/src/adapt/modules/tracking/matching/__init__.py b/src/adapt/modules/tracking/matching/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/adapt/modules/tracking/matching/hungarian.py b/src/adapt/modules/tracking/matching/hungarian.py new file mode 100644 index 0000000..cff6156 --- /dev/null +++ b/src/adapt/modules/tracking/matching/hungarian.py @@ -0,0 +1,91 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Hungarian (optimal-assignment) matcher for residual ambiguity. + +Builds a (n_prev × n_curr) cost matrix from projected hulls and solves it with +``scipy.optimize.linear_sum_assignment``. This is the only home for ``scipy`` in +the package. Deterministic overlap resolution lives in ``overlap.py``; this +matcher is the fallback for genuinely contested cells. +""" + +import numpy as np + +__all__ = ["MatchingEngine"] + + +class MatchingEngine: + """Cost matrix builder using projected masks (cell_projections[0] is already the hull).""" + + def __init__(self, config): + self.core_threshold = config.core_reflectivity_threshold + self.expected_speed_ms = config.expected_speed_ms + + def compute_cost_matrix( + self, + prev_node_ids: list[int], + graph, + proj_labels: np.ndarray, + curr_cells: list[dict], + dummy_cost: float, + dt_s: float, + ) -> np.ndarray: + """Build (n_prev × n_curr) cost matrix. + + Uses cell_projections[0] directly as the projected hull — no recomputation. + Pairs with no spatial overlap receive dummy_cost. + D_pos is normalised by expected_speed_ms * dt_s so displacement cost scales + correctly with scan interval. + """ + n_prev = len(prev_node_ids) + n_curr = len(curr_cells) + cost_matrix = np.full((n_prev, n_curr), dummy_cost, dtype=float) + + for prev_idx, prev_node in enumerate(prev_node_ids): + prev_cell_id = graph.get_node_attr(prev_node, "cell_id") + proj_mask = proj_labels == prev_cell_id + if not np.any(proj_mask): + continue # cell left the frame or is dormant (no projection) + for curr_idx, curr_cell in enumerate(curr_cells): + if np.any(proj_mask & curr_cell["mask"]): + cost_matrix[prev_idx, curr_idx] = self._compute_cost( + prev_node, graph, proj_mask, curr_cell, dt_s + ) + + return cost_matrix + + def _compute_cost( + self, + prev_node: int, + graph, + proj_mask: np.ndarray, + curr_cell: dict, + dt_s: float, + ) -> float: + """4-term cost: 0.4*Dpos + 0.3*(1-IoU) + 0.15*|log(A2/A1)| + 0.1*|Z2-Z1|/50 + + D_pos is normalised by max_displacement = expected_speed_ms * dt_s (metres), + then capped at 1.0 so it stays in [0, 1] regardless of cadence. + """ + prev_cx = graph.get_node_attr(prev_node, "centroid_x") + prev_cy = graph.get_node_attr(prev_node, "centroid_y") + prev_area = graph.get_node_attr(prev_node, "area") + prev_refl = graph.get_node_attr(prev_node, "mean_reflectivity") + + curr_mask = curr_cell["mask"] + dist = np.sqrt( + (curr_cell["centroid_x"] - prev_cx) ** 2 + (curr_cell["centroid_y"] - prev_cy) ** 2 + ) + max_displacement = self.expected_speed_ms * dt_s # metres + D_pos = min(float(dist) / max_displacement, 1.0) + + union = np.sum(proj_mask | curr_mask) + IoU = float(np.sum(proj_mask & curr_mask)) / union if union > 0 else 0.0 + + curr_area = curr_cell["area"] + area_diff = ( + float(np.abs(np.log(curr_area / prev_area))) if prev_area > 0 and curr_area > 0 else 1.0 + ) + refl_diff = float(np.abs(curr_cell["mean_reflectivity"] - prev_refl)) / 50.0 + + return 0.4 * D_pos + 0.3 * (1.0 - IoU) + 0.15 * area_diff + 0.1 * refl_diff diff --git a/src/adapt/modules/tracking/matching/overlap.py b/src/adapt/modules/tracking/matching/overlap.py new file mode 100644 index 0000000..38d7e0c --- /dev/null +++ b/src/adapt/modules/tracking/matching/overlap.py @@ -0,0 +1,66 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Deterministic overlap-first matcher. + +Resolves the unambiguous cases before any optimisation: when a projected parent +hull and a current child overlap *mutually uniquely* above a moderate threshold, +they are a direct match and never enter Hungarian assignment. Uniqueness — no +competing candidate on either side — matters more than the exact overlap value. +Pure geometry over boolean masks; no graph, no state. +""" + +import numpy as np + +__all__ = ["OverlapMatcher", "overlap_fraction"] + + +def overlap_fraction(hull: np.ndarray, mask: np.ndarray) -> float: + """Fraction of a projected hull covered by a current cell mask.""" + denom = float(np.sum(hull)) + if denom == 0.0: + return 0.0 + return float(np.sum(hull & mask)) / denom + + +class OverlapMatcher: + """Find mutually-unique parent↔child overlaps above a threshold.""" + + def __init__(self, overlap_threshold: float): + self.overlap_threshold = overlap_threshold + + def unique_matches( + self, + prev_hulls: list[np.ndarray], + curr_masks: list[np.ndarray], + allowed: np.ndarray, + ) -> list[tuple[int, int]]: + """Return ``(prev_idx, curr_idx)`` pairs that overlap mutually uniquely. + + ``allowed[i, j]`` gates a pair (e.g. it overlaps at all and is physically + plausible). A pair is returned only when prev ``i`` links to exactly one + curr and that curr links back to exactly one prev — both above threshold. + """ + n_prev = len(prev_hulls) + n_curr = len(curr_masks) + links_per_prev: list[list[int]] = [[] for _ in range(n_prev)] + links_per_curr: list[list[int]] = [[] for _ in range(n_curr)] + + for i, hull in enumerate(prev_hulls): + if not hull.any(): + continue + for j, mask in enumerate(curr_masks): + if not allowed[i, j]: + continue + if overlap_fraction(hull, mask) >= self.overlap_threshold: + links_per_prev[i].append(j) + links_per_curr[j].append(i) + + matches: list[tuple[int, int]] = [] + for i, js in enumerate(links_per_prev): + if len(js) != 1: + continue + j = js[0] + if len(links_per_curr[j]) == 1: # j links back only to i + matches.append((i, j)) + return matches diff --git a/src/adapt/modules/tracking/models.py b/src/adapt/modules/tracking/models.py new file mode 100644 index 0000000..7cdf47d --- /dev/null +++ b/src/adapt/modules/tracking/models.py @@ -0,0 +1,67 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Value types for the tracker: match methods, error/diagnostic codes, motion state. + +Plain enums and frozen dataclasses — no logic, no I/O. Shared by the matching, +motion, and event layers so a decision and its explanation travel together. +""" + +from dataclasses import dataclass +from enum import StrEnum + +__all__ = [ + "MatchDiagnostics", + "MatchMethod", + "TrackMotionState", + "TrackingError", +] + + +class MatchMethod(StrEnum): + """How an accepted match was decided.""" + + OVERLAP = "OVERLAP" + HUNGARIAN = "HUNGARIAN" + SPLIT = "SPLIT" + MERGE = "MERGE" + + +class TrackingError(StrEnum): + """Structured tracking diagnostic codes (logged; not pipeline-fatal).""" + + # Scan-cadence / gap problems + NON_MONOTONIC_TIME = "NON_MONOTONIC_TIME" + TRACK_GAP_EXCEEDED = "TRACK_GAP_EXCEEDED" + IRREGULAR_SCAN_CADENCE = "IRREGULAR_SCAN_CADENCE" + # Physical-constraint rejections + VELOCITY_EXCEEDED = "VELOCITY_EXCEEDED" + ACCELERATION_EXCEEDED = "ACCELERATION_EXCEEDED" + + +@dataclass(frozen=True) +class MatchDiagnostics: + """Per-accepted-match explainability record (persisted with the event row).""" + + overlap: float | None = None + iou: float | None = None + centroid_distance_m: float | None = None + speed_ms: float | None = None + heading_change_deg: float | None = None + area_ratio: float | None = None + reflectivity_difference: float | None = None + final_cost: float | None = None + match_method: str | None = None + + +@dataclass(frozen=True) +class TrackMotionState: + """Per-track velocity carried forward for acceleration and heading checks. + + ``speed`` is in m/s; ``heading`` is in radians measured as ``atan2(vy, vx)``. + ``has_velocity`` is False until a track has been observed across two scans. + """ + + speed: float = 0.0 + heading: float = 0.0 + has_velocity: bool = False diff --git a/src/adapt/modules/tracking/module.py b/src/adapt/modules/tracking/module.py index d2330b3..d3b5970 100644 --- a/src/adapt/modules/tracking/module.py +++ b/src/adapt/modules/tracking/module.py @@ -19,8 +19,10 @@ 1. **tracked_cells**: Per-observation rows for the current scan 2. **cell_events**: Explicit lineage/event rows for the current scan -Tracking state is stored in a directed graph structure with nodes representing cell observations -and edges representing temporal relationships. +Tracking state is stored in a directed graph structure (``graph.TrackingGraph``) with nodes +representing cell observations and edges representing temporal relationships. Cost-matrix +matching lives in ``matching.hungarian.MatchingEngine``; uid generation in ``identity``; event +rows in ``events``. ``RadarCellTracker`` here is orchestration only. What is different from TINT: - No centroid-only matching (uses full mask overlap + motion prediction) @@ -33,329 +35,57 @@ Journal of Applied Meteorology and Climatology, 60(4), 513-526. """ -import hashlib import logging -import string +import math -import networkx as nx import numpy as np import pandas as pd import xarray as xr from scipy.optimize import linear_sum_assignment +from adapt.modules.tracking.events import ( + build_cell_events_dataframe, + event_continue, + event_initiation, + event_merge, + event_split, + event_termination, +) +from adapt.modules.tracking.graph import TrackingGraph +from adapt.modules.tracking.identity import ( + _cell_uid_from_signature, + _track_signature_from_birth, +) +from adapt.modules.tracking.matching.hungarian import MatchingEngine +from adapt.modules.tracking.matching.overlap import OverlapMatcher, overlap_fraction +from adapt.modules.tracking.models import ( + MatchDiagnostics, + MatchMethod, + TrackingError, + TrackMotionState, +) +from adapt.modules.tracking.motion import ( + MotionValidator, + heading_change_degrees, + heading_change_radians, +) +from adapt.modules.tracking.projection import select_registration_labels from adapt.utils.time import normalize_time_scalar -__all__ = ["RadarCellTracker"] +# Beyond this ratio between consecutive scan intervals the cadence is flagged +# irregular (diagnostic only — no track reset). +_CADENCE_IRREGULAR_RATIO = 2.0 -logger = logging.getLogger(__name__) - -BASE36_UPPER = string.digits + string.ascii_uppercase - - -def _quantize(value: float, step: float) -> int: - # this is for creating stable hashes that are robust to small variations in the input values - if step <= 0: - raise ValueError("step must be positive") - return int(round(value / step)) - - -def _encode_base36(value: int) -> str: - if value < 0: - raise ValueError("value must be non-negative") - if value == 0: - return "0" - chars: list[str] = [] - while value: - value, remainder = divmod(value, 36) - chars.append(BASE36_UPPER[remainder]) - return "".join(reversed(chars)) - - -def _encode_base36_fixed(value: int, width: int) -> str: - token = _encode_base36(value) - return token.rjust(width, "0") - - -def _track_signature_from_birth( - scan_start_time_epoch_s: float, - centroid_lat_deg: float, - centroid_lon_deg: float, - max_dbz: float, - max_zdr: float, - area40_km2: float, - *, - time_step_s: int, - latlon_step_deg: float, - area_step_km2: float, - signature_version: str = "v1", -) -> str: - tq = _quantize(scan_start_time_epoch_s, time_step_s) - latq = _quantize(centroid_lat_deg, latlon_step_deg) - lonq = _quantize(centroid_lon_deg, latlon_step_deg) - dbzq = int(round(max_dbz)) - zdrq = int(round(max_zdr * 10.0)) - a40q = _quantize(area40_km2, area_step_km2) - return f"{signature_version}|{tq}|{latq}|{lonq}|{dbzq}|{zdrq}|{a40q}" - - -def _cell_uid_from_signature(signature: str, width: int) -> str: - digest = hashlib.blake2b(signature.encode("utf-8"), digest_size=8).digest() - value64 = int.from_bytes(digest, byteorder="big", signed=False) - modulus = 36**width - return _encode_base36_fixed(value64 % modulus, width=width) - - -# ============================================================================= -# Tracking Graph Structure -# ============================================================================= - - -class TrackingGraph: - """Directed graph storing cell tracking history and lineage. - - Nodes represent cell observations at specific times. - Edges represent temporal relationships (CONTINUE, SPLIT, MERGE). - - Node attributes: - - node_id: unique identifier (int) - - time: observation timestamp - - cell_id: cell label from segmentation - - track_index: tracking index this cell belongs to (starts at 1; 0 = background sentinel) - - area: cell area in km² - - centroid_x, centroid_y: cell center coordinates - - mean_reflectivity: average dBZ - - max_reflectivity: peak dBZ - - core_area: area with Z > threshold dBZ - - Edge attributes: - - edge_type: "CONTINUE", "SPLIT", "MERGE" - - cost: assignment cost (for diagnostics) - """ - - def __init__(self): - """Initialize empty tracking graph.""" - self.graph = nx.DiGraph() - self._node_counter = 0 - self._track_counter = 0 # Will yield 1, 2, 3, ... (0 is background sentinel) - - def add_observation( - self, - time, - cell_id: int, - track_index: int, - area: float, - centroid_x: float, - centroid_y: float, - mean_reflectivity: float, - max_reflectivity: float, - core_area: float, - cell_uid: str, - track_signature: str, - ) -> int: - node_id = self._node_counter - self._node_counter += 1 - - self.graph.add_node( - node_id, - time=time, - cell_id=cell_id, - track_index=track_index, - area=area, - centroid_x=centroid_x, - centroid_y=centroid_y, - mean_reflectivity=mean_reflectivity, - max_reflectivity=max_reflectivity, - core_area=core_area, - cell_uid=cell_uid, - track_signature=track_signature, - ) - return node_id - - def add_edge(self, from_node: int, to_node: int, edge_type: str, cost: float = 0.0): - """Add a temporal relationship edge. - - Parameters - ---------- - from_node : int - Source node ID (earlier time) - to_node : int - Target node ID (later time) - edge_type : str - Edge type: "CONTINUE", "SPLIT", or "MERGE" - cost : float, optional - Assignment cost for diagnostics (default: 0.0) - """ - self.graph.add_edge(from_node, to_node, edge_type=edge_type, cost=cost) - - def get_new_track_index(self) -> int: - """Allocate a new unique track index (starts at 1; 0 is background sentinel).""" - self._track_counter += 1 - return self._track_counter - - def get_node_attr(self, node_id: int, attr: str): - """Get a node attribute value. - - Parameters - ---------- - node_id : int - Node identifier - attr : str - Attribute name - - Returns - ------- - Any - Attribute value, or None if not present - """ - return self.graph.nodes[node_id].get(attr) - - def get_nodes_at_time(self, time) -> list[int]: - """Get all node IDs for a given timestamp. - - Parameters - ---------- - time : datetime-like - Timestamp to query - - Returns - ------- - List[int] - List of node IDs at this time - """ - return [n for n, d in self.graph.nodes(data=True) if d.get("time") == time] - - def get_track_nodes(self, track_index: int) -> list[int]: - """Get all nodes belonging to a track, sorted by time.""" - nodes = [ - (n, d["time"]) - for n, d in self.graph.nodes(data=True) - if d.get("track_index") == track_index - ] - nodes.sort(key=lambda x: x[1]) - return [n for n, _ in nodes] - - def get_predecessors(self, node_id: int) -> list[tuple[int, str]]: - """Get predecessor nodes with their edge types. - - Parameters - ---------- - node_id : int - Node identifier - - Returns - ------- - List[Tuple[int, str]] - List of (predecessor_node_id, edge_type) tuples - """ - return [ - (pred, self.graph.edges[pred, node_id]["edge_type"]) - for pred in self.graph.predecessors(node_id) - ] - - def get_successors(self, node_id: int) -> list[tuple[int, str]]: - """Get successor nodes with their edge types. - - Parameters - ---------- - node_id : int - Node identifier - - Returns - ------- - List[Tuple[int, str]] - List of (successor_node_id, edge_type) tuples - """ - return [ - (succ, self.graph.edges[node_id, succ]["edge_type"]) - for succ in self.graph.successors(node_id) - ] - - -# ============================================================================= -# Matching Engine -# ============================================================================= - - -class MatchingEngine: - """Cost matrix builder using projected masks (cell_projections[0] is already the hull).""" - - def __init__(self, config): - self.core_threshold = config.core_reflectivity_threshold - self.expected_speed_ms = config.expected_speed_ms - - def compute_cost_matrix( - self, - prev_node_ids: list[int], - graph: "TrackingGraph", - proj_labels: np.ndarray, - curr_cells: list[dict], - dummy_cost: float, - dt_s: float, - ) -> np.ndarray: - """Build (n_prev × n_curr) cost matrix. - - Uses cell_projections[0] directly as the projected hull — no recomputation. - Pairs with no spatial overlap receive dummy_cost. - D_pos is normalised by expected_speed_ms * dt_s so displacement cost scales - correctly with scan interval. - """ - n_prev = len(prev_node_ids) - n_curr = len(curr_cells) - cost_matrix = np.full((n_prev, n_curr), dummy_cost, dtype=float) - - for prev_idx, prev_node in enumerate(prev_node_ids): - prev_cell_id = graph.get_node_attr(prev_node, "cell_id") - proj_mask = proj_labels == prev_cell_id - if not np.any(proj_mask): - continue # cell left the frame or is dormant (no projection) - for curr_idx, curr_cell in enumerate(curr_cells): - if np.any(proj_mask & curr_cell["mask"]): - cost_matrix[prev_idx, curr_idx] = self._compute_cost( - prev_node, graph, proj_mask, curr_cell, dt_s - ) - - return cost_matrix - - def _compute_cost( - self, - prev_node: int, - graph: "TrackingGraph", - proj_mask: np.ndarray, - curr_cell: dict, - dt_s: float, - ) -> float: - """4-term cost: 0.4*Dpos + 0.3*(1-IoU) + 0.15*|log(A2/A1)| + 0.1*|Z2-Z1|/50 - - D_pos is normalised by max_displacement = expected_speed_ms * dt_s (metres), - then capped at 1.0 so it stays in [0, 1] regardless of cadence. - """ - prev_cx = graph.get_node_attr(prev_node, "centroid_x") - prev_cy = graph.get_node_attr(prev_node, "centroid_y") - prev_area = graph.get_node_attr(prev_node, "area") - prev_refl = graph.get_node_attr(prev_node, "mean_reflectivity") - - curr_mask = curr_cell["mask"] - dist = np.sqrt( - (curr_cell["centroid_x"] - prev_cx) ** 2 + (curr_cell["centroid_y"] - prev_cy) ** 2 - ) - max_displacement = self.expected_speed_ms * dt_s # metres - D_pos = min(float(dist) / max_displacement, 1.0) - - union = np.sum(proj_mask | curr_mask) - IoU = float(np.sum(proj_mask & curr_mask)) / union if union > 0 else 0.0 - - curr_area = curr_cell["area"] - area_diff = ( - float(np.abs(np.log(curr_area / prev_area))) if prev_area > 0 and curr_area > 0 else 1.0 - ) - refl_diff = float(np.abs(curr_cell["mean_reflectivity"] - prev_refl)) / 50.0 - - return 0.4 * D_pos + 0.3 * (1.0 - IoU) + 0.15 * area_diff + 0.1 * refl_diff +# Re-exported for the public import surface (tests + tooling import these from here). +__all__ = [ + "MatchingEngine", + "RadarCellTracker", + "TrackingGraph", + "_cell_uid_from_signature", + "_track_signature_from_birth", +] - -# ============================================================================= -# Core Tracking Algorithm -# ============================================================================= +logger = logging.getLogger(__name__) class RadarCellTracker: @@ -386,12 +116,19 @@ def __init__(self, config): self.uid_width = config.uid_width self.max_gap_s: float = config.max_gap_minutes * 60.0 + self.max_tracking_gap_minutes: float = config.max_tracking_gap_minutes + self.max_tracking_gap_s: float = config.max_tracking_gap_minutes * 60.0 self.graph = TrackingGraph() self.matcher = MatchingEngine(config) + self.overlap_matcher = OverlapMatcher(config.overlap_match_threshold) + self.motion = MotionValidator(config.max_speed_ms, config.max_speed_multiplier) + self.heading_penalty_weight = config.heading_change_penalty_weight self._previous_scan: tuple | None = None # (time, ds, node_ids) self._cell_identity: dict[int, tuple[str, str]] = {} self._dormant_nodes: dict[int, float] = {} # node_id → last_seen_epoch_s + self._track_motion: dict[int, TrackMotionState] = {} # track_index → kinematics + self._prev_dt_s: float | None = None # last good scan interval (cadence check) logger.info( "RadarCellTracker initialized: match=%.2f keep=%.2f unmatch=%.2f overlap=%.2f" @@ -427,29 +164,33 @@ def track( node_ids = self._initialize_tracks(current_time, cells_current) self._previous_scan = (current_time, ds_projected, node_ids) for node_id in node_ids: - events.append(self._event_initiation(current_time, node_id)) + events.append( + event_initiation(self.graph, self._cell_identity, current_time, node_id) + ) else: prev_time, ds_prev, prev_node_ids = self._previous_scan prev_time_epoch = self._to_epoch_seconds(prev_time) curr_time_epoch = self._to_epoch_seconds(current_time) dt_s = curr_time_epoch - prev_time_epoch - if dt_s <= 0: - raise ValueError(f"Non-monotonic scan times: prev={prev_time}, curr={current_time}") - events = self._track_frame_pair( - prev_time, - ds_prev, - prev_node_ids, - current_time, - ds_projected, - cells_current, - dt_s, - ) + if self._gap_forces_reset(dt_s, prev_time, current_time): + events = self._reset_tracks(prev_node_ids, current_time, cells_current) + else: + self._prev_dt_s = dt_s + events = self._track_frame_pair( + prev_time, + ds_prev, + prev_node_ids, + current_time, + ds_projected, + cells_current, + dt_s, + ) current_node_ids = self.graph.get_nodes_at_time(current_time) self._previous_scan = (current_time, ds_projected, current_node_ids) current_node_ids = self.graph.get_nodes_at_time(current_time) tracked_cells_df = self._build_tracked_cells_current(current_time, current_node_ids) - cell_events_df = self._build_cell_events_dataframe(events) + cell_events_df = build_cell_events_dataframe(events) return tracked_cells_df, cell_events_df def get_cell_identity(self, track_index: int) -> tuple[str, str]: @@ -457,6 +198,65 @@ def get_cell_identity(self, track_index: int) -> tuple[str, str]: raise ValueError(f"Missing cell identity for track_index={track_index}") return self._cell_identity[track_index] + # ------------------------------------------------------------------ + # Scan-gap classification (physical time) + # ------------------------------------------------------------------ + + def _gap_forces_reset(self, dt_s: float, prev_time, curr_time) -> bool: + """Classify the inter-scan interval; log a structured code. + + Returns True when tracks must be terminated and restarted (non-monotonic + time or a gap above the hard limit) — never raises, never matches across + the gap. An irregular-but-monotonic cadence is a diagnostic warning only. + """ + if dt_s <= 0: + logger.error( + "tracking_error code=%s prev=%s curr=%s dt_s=%.1f", + TrackingError.NON_MONOTONIC_TIME.value, + prev_time, + curr_time, + dt_s, + ) + return True + if dt_s > self.max_tracking_gap_s: + logger.error( + "tracking_error code=%s dt_minutes=%.1f limit_minutes=%.1f", + TrackingError.TRACK_GAP_EXCEEDED.value, + dt_s / 60.0, + self.max_tracking_gap_minutes, + ) + return True + if self._prev_dt_s is not None and ( + dt_s > _CADENCE_IRREGULAR_RATIO * self._prev_dt_s + or dt_s * _CADENCE_IRREGULAR_RATIO < self._prev_dt_s + ): + logger.warning( + "tracking_error code=%s dt_s=%.1f prev_dt_s=%.1f", + TrackingError.IRREGULAR_SCAN_CADENCE.value, + dt_s, + self._prev_dt_s, + ) + return False + + def _reset_tracks( + self, prev_node_ids: list[int], curr_time, cells_current: list[dict] + ) -> list[dict]: + """Terminate every active and dormant track, then start fresh tracks.""" + events: list[dict] = [] + for node_id in prev_node_ids: + events.append( + event_termination(self.graph, self._cell_identity, curr_time, node_id, None) + ) + for node_id in list(self._dormant_nodes): + events.append( + event_termination(self.graph, self._cell_identity, curr_time, node_id, None) + ) + self._dormant_nodes.clear() + self._initialize_tracks(curr_time, cells_current) + for node_id in self.graph.get_nodes_at_time(curr_time): + events.append(event_initiation(self.graph, self._cell_identity, curr_time, node_id)) + return events + # ------------------------------------------------------------------ # Cell extraction # ------------------------------------------------------------------ @@ -471,12 +271,6 @@ def _to_epoch_seconds(time_val) -> float: ts = ts.tz_localize("UTC") return float(ts.timestamp()) - @staticmethod - def _time_key(time_val) -> str: - """Stable ISO8601 time key for event grouping.""" - tv = normalize_time_scalar(time_val) - return pd.Timestamp(tv).isoformat() - def _extract_cells_from_analyzer( self, ds: xr.Dataset, cell_stats_df: pd.DataFrame ) -> list[dict]: @@ -593,6 +387,174 @@ def _add_cell_node( # Frame-pair matching # ------------------------------------------------------------------ + def _reject_unphysical( + self, + all_prev_ids: list[int], + curr_cells: list[dict], + raw: np.ndarray, + dummy_cost: float, + dt_s: float, + ) -> None: + """Set candidate pairs that violate hard kinematic limits to ``dummy_cost``. + + Mutates ``raw`` in place so rejected pairs cannot be assigned. Only pairs + that actually overlap (cost below ``dummy_cost``) are checked. + """ + for i, prev_node in enumerate(all_prev_ids): + prev_cx = self.graph.get_node_attr(prev_node, "centroid_x") + prev_cy = self.graph.get_node_attr(prev_node, "centroid_y") + track_index = int(self.graph.get_node_attr(prev_node, "track_index") or 0) + prev_state = self._track_motion.get(track_index) + previous_speed = prev_state.speed if prev_state and prev_state.has_velocity else None + for j, curr_cell in enumerate(curr_cells): + if raw[i, j] >= dummy_cost: + continue # no overlap — not a candidate + decision = self.motion.check( + prev_cx, + prev_cy, + curr_cell["centroid_x"], + curr_cell["centroid_y"], + dt_s, + previous_speed, + ) + if not decision.ok: + raw[i, j] = dummy_cost + logger.debug( + "tracking_error code=%s track=%d speed=%.1fm/s", + decision.code.value, + track_index, + decision.speed_ms, + ) + + def _apply_motion_penalty( + self, + all_prev_ids: list[int], + curr_cells: list[dict], + raw: np.ndarray, + dummy_cost: float, + ) -> None: + """Bias Hungarian away from heading-inconsistent matches (crossing tracks). + + Adds ``heading_change_penalty_weight × heading_change`` (radians) to each + candidate pair whose track has an established velocity. Soft penalty — never + a rejection — and a no-op when the weight is 0. + """ + if self.heading_penalty_weight <= 0.0: + return + for i, prev_node in enumerate(all_prev_ids): + track_index = int(self.graph.get_node_attr(prev_node, "track_index") or 0) + state = self._track_motion.get(track_index) + if state is None or not state.has_velocity: + continue + prev_cx = float(self.graph.get_node_attr(prev_node, "centroid_x")) + prev_cy = float(self.graph.get_node_attr(prev_node, "centroid_y")) + for j, curr_cell in enumerate(curr_cells): + if raw[i, j] >= dummy_cost: + continue # not a candidate + cand_heading = math.atan2( + curr_cell["centroid_y"] - prev_cy, curr_cell["centroid_x"] - prev_cx + ) + raw[i, j] += self.heading_penalty_weight * heading_change_radians( + state.heading, cand_heading + ) + + def _update_motion( + self, prev_node: int, curr_cell: dict, track_index: int, dt_s: float + ) -> None: + """Record a continuing track's kinematics for prediction and accel checks.""" + prev_cx = float(self.graph.get_node_attr(prev_node, "centroid_x")) + prev_cy = float(self.graph.get_node_attr(prev_node, "centroid_y")) + curr_cx = float(curr_cell["centroid_x"]) + curr_cy = float(curr_cell["centroid_y"]) + vx = (curr_cx - prev_cx) / dt_s + vy = (curr_cy - prev_cy) / dt_s + self._track_motion[track_index] = TrackMotionState( + speed=math.hypot(vx, vy), + heading=math.atan2(vy, vx), + has_velocity=True, + ) + + def _record_continue( + self, + i: int, + c: int, + all_prev_ids: list[int], + curr_cells: list[dict], + curr_time, + cost: float, + dt_s: float, + proj_labels: np.ndarray, + method: MatchMethod, + matched_prev: dict[int, int], + matched_curr: dict[int, int], + ) -> dict: + """Create the CONTINUE node/edge for prev row ``i`` ↔ curr col ``c``. + + Shared by the overlap-first and Hungarian paths. Records diagnostics, + updates the matched maps, the track's motion state, and the dormant set; + returns the event row. + """ + prev_node = all_prev_ids[i] + track_index = int(self.graph.get_node_attr(prev_node, "track_index") or 0) + # Diagnostics use the track's prior motion — compute before _update_motion. + diagnostics = self._match_diagnostics( + prev_node, curr_cells[c], proj_labels, cost, dt_s, method + ) + curr_node = self._add_cell_node(curr_time, curr_cells[c], track_index) + self.graph.add_edge(prev_node, curr_node, edge_type="CONTINUE", cost=cost) + matched_prev[i] = curr_node + matched_curr[c] = curr_node + self._update_motion(prev_node, curr_cells[c], track_index, dt_s) + if prev_node in self._dormant_nodes: # dormant node re-acquired + del self._dormant_nodes[prev_node] + return event_continue( + self.graph, self._cell_identity, curr_time, prev_node, curr_node, cost, diagnostics + ) + + def _match_diagnostics( + self, + prev_node: int, + curr_cell: dict, + proj_labels: np.ndarray, + cost: float, + dt_s: float, + method: MatchMethod, + ) -> MatchDiagnostics: + """Assemble the per-match explainability record for an accepted CONTINUE.""" + prev_cx = float(self.graph.get_node_attr(prev_node, "centroid_x")) + prev_cy = float(self.graph.get_node_attr(prev_node, "centroid_y")) + prev_area = float(self.graph.get_node_attr(prev_node, "area")) + prev_refl = float(self.graph.get_node_attr(prev_node, "mean_reflectivity")) + track_index = int(self.graph.get_node_attr(prev_node, "track_index") or 0) + + proj_mask = proj_labels == self.graph.get_node_attr(prev_node, "cell_id") + curr_mask = curr_cell["mask"] + union = float(np.sum(proj_mask | curr_mask)) + iou = float(np.sum(proj_mask & curr_mask)) / union if union > 0 else 0.0 + + curr_cx = float(curr_cell["centroid_x"]) + curr_cy = float(curr_cell["centroid_y"]) + dist = math.hypot(curr_cx - prev_cx, curr_cy - prev_cy) + + heading_change_deg = None + prev_state = self._track_motion.get(track_index) + if prev_state is not None and prev_state.has_velocity: + cand_heading = math.atan2(curr_cy - prev_cy, curr_cx - prev_cx) + heading_change_deg = heading_change_degrees(prev_state.heading, cand_heading) + + curr_area = float(curr_cell["area"]) + return MatchDiagnostics( + overlap=overlap_fraction(proj_mask, curr_mask), + iou=iou, + centroid_distance_m=dist, + speed_ms=dist / dt_s, + heading_change_deg=heading_change_deg, + area_ratio=(curr_area / prev_area if prev_area > 0 else None), + reflectivity_difference=abs(float(curr_cell["mean_reflectivity"]) - prev_refl), + final_cost=cost, + match_method=method, + ) + def _track_frame_pair( self, prev_time, @@ -619,7 +581,9 @@ def _track_frame_pair( # Expire dormant nodes that have already exceeded the gap limit for node_id, last_seen in list(self._dormant_nodes.items()): if curr_time_epoch - last_seen > self.max_gap_s: - events.append(self._event_termination(curr_time, node_id, target_node_id=None)) + events.append( + event_termination(self.graph, self._cell_identity, curr_time, node_id, None) + ) del self._dormant_nodes[node_id] # Move active prev nodes into dormant; preserve last_seen as prev scan time for node_id in prev_node_ids: @@ -628,15 +592,19 @@ def _track_frame_pair( # Current cells become new tracks self._initialize_tracks(curr_time, curr_cells) for node_id in self.graph.get_nodes_at_time(curr_time): - events.append(self._event_initiation(curr_time, node_id)) + events.append(event_initiation(self.graph, self._cell_identity, curr_time, node_id)) return events - proj_labels = ds_curr["cell_projections"].values[0] # registration frame: prev → curr + # Registration hull at the minute nearest the real gap (falls back to the + # whole-step cell_projections[0] when minute frames are absent). + proj_labels = select_registration_labels(ds_curr, dt_s) # ── Expire dormant nodes beyond the gap limit ───────────────────── for node_id, last_seen in list(self._dormant_nodes.items()): if curr_time_epoch - last_seen > self.max_gap_s: - events.append(self._event_termination(curr_time, node_id, target_node_id=None)) + events.append( + event_termination(self.graph, self._cell_identity, curr_time, node_id, None) + ) del self._dormant_nodes[node_id] # ── Include surviving dormant nodes in the matching pool ────────── @@ -648,12 +616,14 @@ def _track_frame_pair( if n_all == 0: self._initialize_tracks(curr_time, curr_cells) for node_id in self.graph.get_nodes_at_time(curr_time): - events.append(self._event_initiation(curr_time, node_id)) + events.append(event_initiation(self.graph, self._cell_identity, curr_time, node_id)) return events if n_curr == 0: # All cells dissipated — terminate every node including dormant for d_node in all_prev_ids: - events.append(self._event_termination(curr_time, d_node, target_node_id=None)) + events.append( + event_termination(self.graph, self._cell_identity, curr_time, d_node, None) + ) self._dormant_nodes.clear() return events @@ -669,6 +639,44 @@ def _track_frame_pair( dt_s, ) + # ── Step 1b: hard physical-motion rejection (before matching) ───── + self._reject_unphysical(all_prev_ids, curr_cells, raw, dummy_cost, dt_s) + + # ── Step 1b′: soft heading-consistency penalty (crossing prevention) ─ + self._apply_motion_penalty(all_prev_ids, curr_cells, raw, dummy_cost) + + matched_prev: dict[int, int] = {} # row_idx → new curr node_id + matched_curr: dict[int, int] = {} # curr_idx → new curr node_id + n_continue = 0 + + # ── Step 1c: deterministic overlap-first matching ───────────────── + # Mutually-unique parent↔child overlaps are matched directly and pulled + # out of the Hungarian pool — uniqueness dominates the overlap threshold. + prev_hulls = [ + proj_labels == self.graph.get_node_attr(node_id, "cell_id") for node_id in all_prev_ids + ] + curr_masks = [cell["mask"] for cell in curr_cells] + allowed = raw < dummy_cost + for i, c in self.overlap_matcher.unique_matches(prev_hulls, curr_masks, allowed): + events.append( + self._record_continue( + i, + c, + all_prev_ids, + curr_cells, + curr_time, + float(raw[i, c]), + dt_s, + proj_labels, + MatchMethod.OVERLAP, + matched_prev, + matched_curr, + ) + ) + n_continue += 1 + raw[i, :] = dummy_cost # remove this parent and child from the pool + raw[:, c] = dummy_cost + # ── Step 2: pre-clamp ───────────────────────────────────────────── raw[raw < self.match_cost] = 0.0 raw[raw > self.unmatch_cost] = dummy_cost @@ -678,33 +686,32 @@ def _track_frame_pair( square = np.full((n, n), dummy_cost, dtype=float) square[:n_all, :n_curr] = raw - # ── Step 4: Hungarian ───────────────────────────────────────────── + # ── Step 4: Hungarian (residual ambiguity only) ────────────────── row_ind, col_ind = linear_sum_assignment(square) # ── Step 5: post-filter → CONTINUE / dissipated / born ─────────── - matched_prev: dict[int, int] = {} # row_idx → new curr node_id - matched_curr: dict[int, int] = {} # curr_idx → new curr node_id - n_continue = 0 - for r, c in zip(row_ind, col_ind, strict=False): if r >= n_all or c >= n_curr: continue # dummy slot + if r in matched_prev or c in matched_curr: + continue # already resolved by overlap-first if square[r, c] <= self.keep_cost: - prev_node = all_prev_ids[r] - track_index = self.graph.get_node_attr(prev_node, "track_index") - curr_node = self._add_cell_node(curr_time, curr_cells[c], int(track_index or 0)) - self.graph.add_edge( - prev_node, curr_node, edge_type="CONTINUE", cost=float(square[r, c]) - ) - matched_prev[r] = curr_node - matched_curr[c] = curr_node - n_continue += 1 events.append( - self._event_continue(curr_time, prev_node, curr_node, float(square[r, c])) + self._record_continue( + r, + c, + all_prev_ids, + curr_cells, + curr_time, + float(square[r, c]), + dt_s, + proj_labels, + MatchMethod.HUNGARIAN, + matched_prev, + matched_curr, + ) ) - # Dormant node re-acquired: remove from dormant set - if prev_node in self._dormant_nodes: - del self._dormant_nodes[prev_node] + n_continue += 1 # Dissipated = unmatched nodes; split by active vs dormant all_dissipated = [all_prev_ids[i] for i in range(n_all) if i not in matched_prev] @@ -741,7 +748,9 @@ def _track_frame_pair( ) self.graph.add_edge(best_parent, child_node, edge_type="SPLIT", cost=0.0) split_born.add(b_idx) - events.append(self._event_split(curr_time, best_parent, child_node)) + events.append( + event_split(self.graph, self._cell_identity, curr_time, best_parent, child_node) + ) logger.debug( "SPLIT: track %d → new track %d (overlap=%.2f)", parent_track_index, @@ -770,7 +779,9 @@ def _track_frame_pair( self.graph.add_edge(d_node, best_target, edge_type="MERGE", cost=0.0) n_merge += 1 merged_nodes[d_node] = best_target - events.append(self._event_merge(curr_time, d_node, best_target)) + events.append( + event_merge(self.graph, self._cell_identity, curr_time, d_node, best_target) + ) logger.debug( "MERGE: track %d → track %d (overlap=%.2f)", self.graph.get_node_attr(d_node, "track_index"), @@ -789,15 +800,19 @@ def _track_frame_pair( curr_time, curr_cells[b_idx], new_index, cell_uid, track_signature ) n_births += 1 - events.append(self._event_initiation(curr_time, node_id)) + events.append(event_initiation(self.graph, self._cell_identity, curr_time, node_id)) for d_node in dissipated_active: if d_node in merged_nodes: events.append( - self._event_termination(curr_time, d_node, target_node_id=merged_nodes[d_node]) + event_termination( + self.graph, self._cell_identity, curr_time, d_node, merged_nodes[d_node] + ) ) else: - events.append(self._event_termination(curr_time, d_node, target_node_id=None)) + events.append( + event_termination(self.graph, self._cell_identity, curr_time, d_node, None) + ) n_split = len(split_born) n_dissipated = len(dissipated_active) - n_merge @@ -846,126 +861,3 @@ def _build_tracked_cells_current(self, time, node_ids: list[int]) -> pd.DataFram df["time"] = pd.to_datetime(df["time"]) df = df.sort_values(["cell_uid", "cell_label"]).reset_index(drop=True) return df - - @staticmethod - def _build_cell_events_dataframe(events: list[dict]) -> pd.DataFrame: - cols = [ - "time", - "event_type", - "source_cell_uid", - "target_cell_uid", - "source_cell_label", - "target_cell_label", - "cost", - "is_dominant", - "event_group_id", - ] - if not events: - return pd.DataFrame(columns=cols) - df = pd.DataFrame(events) - for col in cols: - if col not in df.columns: - df[col] = None - df = df[cols] - df["time"] = df["time"].apply(lambda t: pd.Timestamp(normalize_time_scalar(t))) - return df - - # ------------------------------------------------------------------ - # Event builders - # ------------------------------------------------------------------ - - def _event_continue(self, time, prev_node_id: int, curr_node_id: int, cost: float) -> dict: - source_cell_uid = self.get_cell_identity( - int(self.graph.get_node_attr(prev_node_id, "track_index")) - )[0] - target_cell_uid = self.get_cell_identity( - int(self.graph.get_node_attr(curr_node_id, "track_index")) - )[0] - return { - "time": time, - "event_type": "CONTINUE", - "source_cell_uid": source_cell_uid, - "target_cell_uid": target_cell_uid, - "source_cell_label": int(self.graph.get_node_attr(prev_node_id, "cell_id")), - "target_cell_label": int(self.graph.get_node_attr(curr_node_id, "cell_id")), - "cost": float(cost), - "is_dominant": True, - "event_group_id": f"{self._time_key(time)}:CONTINUE:{target_cell_uid}", - } - - def _event_split(self, time, parent_node_id: int, child_node_id: int) -> dict: - parent_uid = self.get_cell_identity( - int(self.graph.get_node_attr(parent_node_id, "track_index")) - )[0] - child_uid = self.get_cell_identity( - int(self.graph.get_node_attr(child_node_id, "track_index")) - )[0] - return { - "time": time, - "event_type": "SPLIT", - "source_cell_uid": parent_uid, - "target_cell_uid": child_uid, - "source_cell_label": int(self.graph.get_node_attr(parent_node_id, "cell_id")), - "target_cell_label": int(self.graph.get_node_attr(child_node_id, "cell_id")), - "cost": None, - "is_dominant": False, - "event_group_id": f"{self._time_key(time)}:SPLIT:{parent_uid}", - } - - def _event_merge(self, time, source_node_id: int, target_node_id: int) -> dict: - source_path = int(self.graph.get_node_attr(source_node_id, "track_index")) - target_path = int(self.graph.get_node_attr(target_node_id, "track_index")) - target_uid = self.get_cell_identity(target_path)[0] - return { - "time": time, - "event_type": "MERGE", - "source_cell_uid": self.get_cell_identity(source_path)[0], - "target_cell_uid": target_uid, - "source_cell_label": int(self.graph.get_node_attr(source_node_id, "cell_id")), - "target_cell_label": int(self.graph.get_node_attr(target_node_id, "cell_id")), - "cost": None, - "is_dominant": False, - "event_group_id": f"{self._time_key(time)}:MERGE:{target_uid}", - } - - def _event_initiation(self, time, node_id: int) -> dict: - target_uid = self.get_cell_identity(int(self.graph.get_node_attr(node_id, "track_index")))[ - 0 - ] - return { - "time": time, - "event_type": "INITIATION", - "source_cell_uid": None, - "target_cell_uid": target_uid, - "source_cell_label": None, - "target_cell_label": int(self.graph.get_node_attr(node_id, "cell_id")), - "cost": None, - "is_dominant": False, - "event_group_id": f"{self._time_key(time)}:INITIATION:{target_uid}", - } - - def _event_termination(self, time, source_node_id: int, target_node_id: int | None) -> dict: - source_path = int(self.graph.get_node_attr(source_node_id, "track_index")) - target_path = ( - int(self.graph.get_node_attr(target_node_id, "track_index")) - if target_node_id is not None - else None - ) - source_uid = self.get_cell_identity(source_path)[0] - return { - "time": time, - "event_type": "TERMINATION", - "source_cell_uid": source_uid, - "target_cell_uid": ( - self.get_cell_identity(target_path)[0] if target_path is not None else None - ), - "source_cell_label": int(self.graph.get_node_attr(source_node_id, "cell_id")), - "target_cell_label": ( - int(self.graph.get_node_attr(target_node_id, "cell_id")) - if target_node_id is not None - else None - ), - "cost": None, - "is_dominant": False, - "event_group_id": f"{self._time_key(time)}:TERMINATION:{source_uid}", - } diff --git a/src/adapt/modules/tracking/motion.py b/src/adapt/modules/tracking/motion.py new file mode 100644 index 0000000..40a9f97 --- /dev/null +++ b/src/adapt/modules/tracking/motion.py @@ -0,0 +1,80 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Physical motion checks for candidate matching. + +``MotionValidator`` applies hard kinematic limits that reject impossible +candidate pairs *before* matching; the heading helpers quantify direction change +for the soft crossing-prevention penalty. Pure geometry — no graph, no I/O. +""" + +import math +from dataclasses import dataclass + +from adapt.modules.tracking.models import TrackingError + +__all__ = [ + "MotionDecision", + "MotionValidator", + "heading_change_degrees", + "heading_change_radians", +] + + +def heading_change_radians(prev_heading: float, curr_heading: float) -> float: + """Smallest absolute heading change (radians, in [0, π]) between two directions.""" + delta = curr_heading - prev_heading + return abs(math.atan2(math.sin(delta), math.cos(delta))) + + +def heading_change_degrees(prev_heading: float, curr_heading: float) -> float: + """Smallest absolute heading change (degrees) between two directions (radians).""" + return math.degrees(heading_change_radians(prev_heading, curr_heading)) + + +@dataclass(frozen=True) +class MotionDecision: + """Outcome of a physical-motion check for one candidate pair.""" + + ok: bool + speed_ms: float + code: TrackingError | None = None + + +class MotionValidator: + """Reject candidate pairs that violate hard kinematic limits. + + A pair is rejected when its implied speed exceeds ``max_speed_ms`` (absolute + cap) or ``max_speed_multiplier × previous_speed`` (acceleration cap). Rejected + pairs never reach overlap or Hungarian matching. + """ + + def __init__(self, max_speed_ms: float, max_speed_multiplier: float): + self.max_speed_ms = max_speed_ms + self.max_speed_multiplier = max_speed_multiplier + + @staticmethod + def speed_ms(prev_x: float, prev_y: float, curr_x: float, curr_y: float, dt_s: float) -> float: + if dt_s <= 0: + raise ValueError("dt_s must be positive for a speed estimate") + return math.hypot(curr_x - prev_x, curr_y - prev_y) / dt_s + + def check( + self, + prev_x: float, + prev_y: float, + curr_x: float, + curr_y: float, + dt_s: float, + previous_speed: float | None = None, + ) -> MotionDecision: + speed = self.speed_ms(prev_x, prev_y, curr_x, curr_y, dt_s) + if speed > self.max_speed_ms: + return MotionDecision(False, speed, TrackingError.VELOCITY_EXCEEDED) + if ( + previous_speed is not None + and previous_speed > 0.0 + and speed > self.max_speed_multiplier * previous_speed + ): + return MotionDecision(False, speed, TrackingError.ACCELERATION_EXCEEDED) + return MotionDecision(True, speed, None) diff --git a/src/adapt/modules/tracking/projection.py b/src/adapt/modules/tracking/projection.py new file mode 100644 index 0000000..6a09677 --- /dev/null +++ b/src/adapt/modules/tracking/projection.py @@ -0,0 +1,30 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Registration-based projected hulls. + +The projection module already advects the previous scan's cell labels forward at +minute resolution (``registration_minutes`` + ``interpolation_fraction``). The +tracker consumes the minute frame closest to the actual scan gap so matching uses +a hull registered to the real elapsed time, not a fixed single-step projection. +Falls back to ``cell_projections[0]`` (the whole-step registration) when the +minute frames are absent. Pure xarray access — no state. +""" + +import numpy as np +import xarray as xr + +__all__ = ["select_registration_labels"] + + +def select_registration_labels(ds: xr.Dataset, dt_s: float) -> np.ndarray: + """Previous-scan cell labels registered to the current frame. + + Picks the ``registration_minutes`` frame whose minute is closest to ``dt_s``; + falls back to ``cell_projections[0]`` if no minute frames are present. + """ + if "registration_minutes" in ds.data_vars and ds["registration_minutes"].sizes.get("minute", 0): + minutes = np.asarray(ds["minute"].values, dtype=float) + nearest = int(np.argmin(np.abs(minutes - dt_s / 60.0))) + return np.asarray(ds["registration_minutes"].values[nearest]) + return np.asarray(ds["cell_projections"].values[0]) diff --git a/src/adapt/persistence/track_store.py b/src/adapt/persistence/track_store.py index 651107f..f89dfc8 100644 --- a/src/adapt/persistence/track_store.py +++ b/src/adapt/persistence/track_store.py @@ -526,8 +526,21 @@ def _insert_cell_events( "is_dominant", "event_group_id", ] - placeholders = ", ".join("?" * len(cols)) - sql = f"INSERT INTO cell_events ({', '.join(cols)}) VALUES ({placeholders})" + # Per-match diagnostics carried through verbatim when the tracker emits them. + diagnostic_cols = [ + "candidate_overlap", + "candidate_iou", + "candidate_centroid_distance_m", + "candidate_speed_ms", + "candidate_heading_change_deg", + "candidate_area_ratio", + "candidate_reflectivity_difference", + "candidate_final_cost", + "match_method", + ] + all_cols = cols + diagnostic_cols + placeholders = ", ".join("?" * len(all_cols)) + sql = f"INSERT INTO cell_events ({', '.join(all_cols)}) VALUES ({placeholders})" def _src_time(etype: str) -> str | None: return None if etype == "INITIATION" else source_iso @@ -535,11 +548,16 @@ def _src_time(etype: str) -> str | None: def _tgt_time(etype: str) -> str | None: return None if etype == "TERMINATION" else target_iso + def _num(ev: pd.Series, col: str) -> float | None: + val = ev.get(col) + return float(val) if pd.notna(val) else None + rows = [] for _, ev in cell_events_df.iterrows(): etype = str(ev["event_type"]) source_uid = _source_uid(ev) target_uid = _target_uid(ev) + method = ev.get("match_method") rows.append( ( run_id, @@ -561,6 +579,15 @@ def _tgt_time(etype: str) -> str | None: float(ev["cost"]) if pd.notna(ev.get("cost")) else None, int(bool(ev.get("is_dominant", False))), str(ev["event_group_id"]), + _num(ev, "candidate_overlap"), + _num(ev, "candidate_iou"), + _num(ev, "candidate_centroid_distance_m"), + _num(ev, "candidate_speed_ms"), + _num(ev, "candidate_heading_change_deg"), + _num(ev, "candidate_area_ratio"), + _num(ev, "candidate_reflectivity_difference"), + _num(ev, "candidate_final_cost"), + str(method) if pd.notna(method) else None, ) ) conn.executemany(sql, rows) diff --git a/tests/persistence/test_track_store.py b/tests/persistence/test_track_store.py index b4b135a..64ea046 100644 --- a/tests/persistence/test_track_store.py +++ b/tests/persistence/test_track_store.py @@ -65,7 +65,16 @@ target_cell_label INTEGER, cost REAL, is_dominant INTEGER NOT NULL DEFAULT 0, - event_group_id TEXT NOT NULL + event_group_id TEXT NOT NULL, + candidate_overlap REAL, + candidate_iou REAL, + candidate_centroid_distance_m REAL, + candidate_speed_ms REAL, + candidate_heading_change_deg REAL, + candidate_area_ratio REAL, + candidate_reflectivity_difference REAL, + candidate_final_cost REAL, + match_method TEXT ); CREATE TABLE IF NOT EXISTS cell_tracks ( From 0c555ea0842393541466b44c019a4131eb3a369c Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Fri, 26 Jun 2026 00:56:47 -0500 Subject: [PATCH 3/8] CLEAN: some Pylint checks --- src/adapt/api/client.py | 2 +- src/adapt/cli.py | 2 -- src/adapt/configuration/schemas/initialization.py | 4 +--- src/adapt/consumers/live/dashboard.py | 1 + src/adapt/modules/projection/module.py | 1 - 5 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/adapt/api/client.py b/src/adapt/api/client.py index d3344d5..bb0a31c 100644 --- a/src/adapt/api/client.py +++ b/src/adapt/api/client.py @@ -494,7 +494,7 @@ def _bundle_from_scan_record( run_id=str(scan_record.get("run_id", "")), n_cells=int(scan_record.get("num_cells") or 0), max_reflectivity=float(scan_record.get("max_reflectivity") or 0.0), - has_tracks=bool(scan_record.get("has_tracks") or False), + has_tracks=bool(scan_record.get("has_tracks")), ) seg = self._load_item_file(radar, scan_record.get("segmentation2d_item_id")) cells = self._load_item_file(radar, scan_record.get("analysis2d_item_id")) diff --git a/src/adapt/cli.py b/src/adapt/cli.py index 7382419..e718783 100644 --- a/src/adapt/cli.py +++ b/src/adapt/cli.py @@ -340,8 +340,6 @@ def _build_dashboard_parser(sub: argparse.ArgumentParser) -> None: def _dashboard_cmd(args: argparse.Namespace) -> None: """Launch the Adapt GUI dashboard.""" - import os - try: os.getcwd() except FileNotFoundError: diff --git a/src/adapt/configuration/schemas/initialization.py b/src/adapt/configuration/schemas/initialization.py index 89bd9d7..7a19ae4 100644 --- a/src/adapt/configuration/schemas/initialization.py +++ b/src/adapt/configuration/schemas/initialization.py @@ -63,8 +63,6 @@ def write_default_config(path: Path, extensions: list[str] | None = None) -> Non own params under ``module_params``. Public — called by both ``init_runtime_config`` (auto-bootstrap) and ``adapt config``. """ - from datetime import datetime as _dt - from adapt.configuration.schemas import yaml_writer from adapt.configuration.schemas.assemble import ( assemble_default_config, @@ -76,7 +74,7 @@ def write_default_config(path: Path, extensions: list[str] | None = None) -> Non # config.yaml` works without --base-dir; the user can edit or override it. data = {"base_dir": str(path.parent.resolve()), **data} descriptions = assemble_descriptions(extensions) - header = _CONFIG_HEADER.format(timestamp=_dt.now().strftime("%Y-%m-%d %H:%M:%S")) + header = _CONFIG_HEADER.format(timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) path.parent.mkdir(parents=True, exist_ok=True) path.write_text(yaml_writer.dump(data, descriptions, header=header)) diff --git a/src/adapt/consumers/live/dashboard.py b/src/adapt/consumers/live/dashboard.py index 000bd60..639203a 100644 --- a/src/adapt/consumers/live/dashboard.py +++ b/src/adapt/consumers/live/dashboard.py @@ -863,6 +863,7 @@ def _create_config_from_wizard(self, path: str, wizard_win, info_var) -> None: capture_output=True, text=True, timeout=15, + check=False, ) if result.returncode != 0: messagebox.showerror( diff --git a/src/adapt/modules/projection/module.py b/src/adapt/modules/projection/module.py index acdedf2..754e2ee 100644 --- a/src/adapt/modules/projection/module.py +++ b/src/adapt/modules/projection/module.py @@ -606,7 +606,6 @@ def _fill_concave_hull(self, label_mask, alpha=0.1): # Create output mask filled = np.zeros_like(label_mask, dtype=np.uint8) - H, W = label_mask.shape # Filter triangles by circumradius (alpha shape) for simplex in tri.simplices: From ee4565b43b759375134d691285c1a0f8a19ede3e Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Sun, 28 Jun 2026 00:55:09 -0500 Subject: [PATCH 4/8] FIX:retries,utf-8 for all, dedup SQLite stores --- pyproject.toml | 2 +- .../configuration/schemas/initialization.py | 8 +- src/adapt/configuration/schemas/internal.py | 2 + src/adapt/configuration/schemas/param.py | 6 + src/adapt/configuration/schemas/user.py | 4 + src/adapt/execution/nodes/ingest.py | 1 + src/adapt/execution/pipeline_builder.py | 2 +- src/adapt/modules/acquisition/module.py | 42 +- src/adapt/modules/ingest/config.py | 1 + src/adapt/modules/ingest/module.py | 73 ++-- src/adapt/persistence/catalog.py | 124 +----- src/adapt/persistence/registry.py | 126 +----- src/adapt/persistence/repository.py | 2 +- src/adapt/runtime/postprocessor.py | 2 +- .../acquisition/test_downloader_failures.py | 37 +- .../modules/ingest/test_loader_netcdf_save.py | 48 +++ .../modules/tracking/test_tracker_quality.py | 370 ++++++++++++++++++ tests/persistence/test_sqlite_store.py | 36 ++ tests/unit/test_base_module_persistence.py | 51 +++ 19 files changed, 650 insertions(+), 287 deletions(-) create mode 100644 tests/modules/ingest/test_loader_netcdf_save.py create mode 100644 tests/modules/tracking/test_tracker_quality.py create mode 100644 tests/persistence/test_sqlite_store.py create mode 100644 tests/unit/test_base_module_persistence.py diff --git a/pyproject.toml b/pyproject.toml index dc1698a..6584b5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,7 @@ namespaces = false [tool.setuptools.package-data] "adapt.data" = ["user_config.py"] -"adapt.schemas" = ["*.sql"] +"adapt.configuration.schemas" = ["*.sql"] "adapt.config" = ["*.yaml"] [tool.setuptools_scm] diff --git a/src/adapt/configuration/schemas/initialization.py b/src/adapt/configuration/schemas/initialization.py index 7a19ae4..e239043 100644 --- a/src/adapt/configuration/schemas/initialization.py +++ b/src/adapt/configuration/schemas/initialization.py @@ -97,7 +97,7 @@ def _load_user_config_dict(config_path: str) -> dict: raise ImportError( "PyYAML is required for YAML config files: pip install pyyaml" ) from err - with open(path) as f: + with open(path, encoding="utf-8") as f: data = yaml.safe_load(f) return data or {} @@ -185,7 +185,7 @@ def _persist_runtime_config( config_dict["run_id"] = run_id config_dict["created_at"] = datetime.now(UTC).isoformat() - with open(config_file, "w") as f: + with open(config_file, "w", encoding="utf-8") as f: json.dump(config_dict, f, indent=2, default=str) @@ -213,7 +213,7 @@ def _find_matching_run_id(new_config_dict: dict) -> str | None: for cfg_file in candidates: try: - with open(cfg_file) as f: + with open(cfg_file, encoding="utf-8") as f: saved = json.load(f) if _config_fingerprint(saved) == target: return saved.get("run_id") @@ -236,7 +236,7 @@ def _load_saved_runtime_config(base_dir: str, run_id: str) -> InternalConfig: if not cfg_path.exists(): raise FileNotFoundError(f"Saved runtime config not found for run_id '{run_id}': {cfg_path}") - with open(cfg_path) as f: + with open(cfg_path, encoding="utf-8") as f: cfg_dict = json.load(f) # Non-schema metadata persisted for audit only. diff --git a/src/adapt/configuration/schemas/internal.py b/src/adapt/configuration/schemas/internal.py index c902830..2357bf4 100644 --- a/src/adapt/configuration/schemas/internal.py +++ b/src/adapt/configuration/schemas/internal.py @@ -36,6 +36,7 @@ class InternalDownloaderConfig(AdaptBaseModel): latest_files: int latest_minutes: int poll_interval_sec: int + max_fetch_retries: int start_time: str | None end_time: str | None min_file_size: int @@ -50,6 +51,7 @@ class InternalRegridderConfig(AdaptBaseModel): min_radius: float weighting_function: Literal["cressman", "barnes", "nearest"] save_netcdf: bool + netcdf_save_retries: int class InternalSegmenterConfig(AdaptBaseModel): diff --git a/src/adapt/configuration/schemas/param.py b/src/adapt/configuration/schemas/param.py index 0dc4c96..5d95208 100644 --- a/src/adapt/configuration/schemas/param.py +++ b/src/adapt/configuration/schemas/param.py @@ -35,6 +35,9 @@ class DownloaderConfig(AdaptBaseModel): latest_files: int = Field(5, ge=1, description="Number of latest files to keep") latest_minutes: int = Field(60, ge=1, description="Time window in minutes") poll_interval_sec: int = Field(300, ge=1, description="Polling interval in seconds") + max_fetch_retries: int = Field( + 3, ge=1, description="AWS scan-fetch attempts before giving up for this poll" + ) start_time: str | None = None end_time: str | None = None min_file_size: int = Field( @@ -55,6 +58,9 @@ class RegridderConfig(AdaptBaseModel): min_radius: float = Field(1750.0, gt=0) weighting_function: Literal["cressman", "barnes", "nearest"] = "cressman" save_netcdf: bool = True + netcdf_save_retries: int = Field( + 3, ge=1, description="NetCDF write attempts before raising (when save_netcdf is set)" + ) class SegmenterConfig(AdaptBaseModel): diff --git a/src/adapt/configuration/schemas/user.py b/src/adapt/configuration/schemas/user.py index b1dc919..25aed14 100644 --- a/src/adapt/configuration/schemas/user.py +++ b/src/adapt/configuration/schemas/user.py @@ -121,6 +121,7 @@ class UserDownloaderConfig(_UserSection): latest_files: int | None = None latest_minutes: int | None = None poll_interval_sec: int | None = None + max_fetch_retries: int | None = None start_time: str | None = None end_time: str | None = None @@ -164,6 +165,7 @@ class UserConfig(AdaptBaseModel): latest_files: int | None = Field(None, alias="LATEST_FILES") latest_minutes: int | None = Field(None, alias="LATEST_MINUTES") poll_interval_sec: int | None = Field(None, alias="POLL_INTERVAL_SEC") + max_fetch_retries: int | None = Field(None, alias="MAX_FETCH_RETRIES") # Historical settings start_time: str | None = Field(None, alias="START_TIME") @@ -292,6 +294,8 @@ def to_internal_overrides(self) -> dict: downloader["latest_minutes"] = self.latest_minutes if self.poll_interval_sec is not None: downloader["poll_interval_sec"] = self.poll_interval_sec + if self.max_fetch_retries is not None: + downloader["max_fetch_retries"] = self.max_fetch_retries # Map base_dir to downloader.output_dir for convenience if self.base_dir is not None: diff --git a/src/adapt/execution/nodes/ingest.py b/src/adapt/execution/nodes/ingest.py index 4f6e759..ade28ee 100644 --- a/src/adapt/execution/nodes/ingest.py +++ b/src/adapt/execution/nodes/ingest.py @@ -62,6 +62,7 @@ def build_config(cls, cfg) -> IngestConfig: min_radius=cfg.regridder.min_radius, weighting_function=cfg.regridder.weighting_function, save_netcdf=cfg.regridder.save_netcdf, + netcdf_save_retries=cfg.regridder.netcdf_save_retries, radar=cfg.downloader.radar, z_level=cfg.global_.z_level, z_coord=cfg.global_.coord_names.z, diff --git a/src/adapt/execution/pipeline_builder.py b/src/adapt/execution/pipeline_builder.py index cd5fd66..6687863 100644 --- a/src/adapt/execution/pipeline_builder.py +++ b/src/adapt/execution/pipeline_builder.py @@ -39,7 +39,7 @@ def _ensure_modules_registered(extensions: list[str] | None = None) -> None: If any declared core module or extension fails to import. """ try: - with open(_DEFAULTS_YAML) as f: + with open(_DEFAULTS_YAML, encoding="utf-8") as f: cfg = yaml.safe_load(f) except OSError as e: raise RuntimeError(f"Cannot read pipeline module list '{_DEFAULTS_YAML}': {e}") from e diff --git a/src/adapt/modules/acquisition/module.py b/src/adapt/modules/acquisition/module.py index ac6eee3..a94a200 100644 --- a/src/adapt/modules/acquisition/module.py +++ b/src/adapt/modules/acquisition/module.py @@ -125,6 +125,7 @@ def __init__( # Legacy support: if output_dir provided but not output_dirs, use old behavior self.output_dir = Path(output_dir) if output_dir else None self.poll_interval_sec = config.downloader.poll_interval_sec + self.max_fetch_retries = config.downloader.max_fetch_retries self.latest_files = config.downloader.latest_files self.latest_minutes = config.downloader.latest_minutes self.start_time = config.downloader.start_time @@ -382,19 +383,38 @@ def _parse_time_range(self) -> tuple: return start, end def _fetch_scans(self, start: datetime, end: datetime) -> list: - """Fetch available scans from AWS.""" - try: - scans = self.conn.get_avail_scans_in_range(start, end, self.radar) - scans = sorted(scans, key=lambda s: s.scan_time) + """Fetch available scans from AWS, retrying transient failures. - # Filter out MDM files - scans = [s for s in scans if not s.key.endswith("_MDM")] + Retries up to ``max_fetch_retries`` times with linear backoff. After the + final attempt fails, returns ``[]`` so a transient AWS error does not + stop the downloader — the next poll retries from scratch. + """ + for attempt in range(1, self.max_fetch_retries + 1): + try: + scans = self.conn.get_avail_scans_in_range(start, end, self.radar) + scans = sorted(scans, key=lambda s: s.scan_time) - logger.debug("Found %d scans for %s", len(scans), self.radar) - return scans - except Exception as e: - logger.error("Failed to fetch scans: %s", e) - return [] + # Filter out MDM files + scans = [s for s in scans if not s.key.endswith("_MDM")] + + logger.debug("Found %d scans for %s", len(scans), self.radar) + return scans + except Exception as e: + if attempt < self.max_fetch_retries: + logger.warning( + "Fetch scans attempt %d/%d failed: %s; retrying", + attempt, + self.max_fetch_retries, + e, + ) + self._sleep(attempt) + else: + logger.error( + "Failed to fetch scans after %d attempts: %s", + self.max_fetch_retries, + e, + ) + return [] # ======================================================================== # Radar availability helpers diff --git a/src/adapt/modules/ingest/config.py b/src/adapt/modules/ingest/config.py index 8329297..84cf6e7 100644 --- a/src/adapt/modules/ingest/config.py +++ b/src/adapt/modules/ingest/config.py @@ -20,6 +20,7 @@ class IngestConfig(BaseModel): min_radius: float weighting_function: str save_netcdf: bool + netcdf_save_retries: int radar: str z_level: float z_coord: str diff --git a/src/adapt/modules/ingest/module.py b/src/adapt/modules/ingest/module.py index 8f55fae..fd43517 100644 --- a/src/adapt/modules/ingest/module.py +++ b/src/adapt/modules/ingest/module.py @@ -107,6 +107,7 @@ def __init__(self, config): self.min_radius = config.min_radius self.weighting_function = config.weighting_function self.save_netcdf = config.save_netcdf + self.netcdf_save_retries = config.netcdf_save_retries def read(self, filepath: Path | str) -> object: """Read a NEXRAD archive file into a Py-ART Radar object. @@ -252,30 +253,51 @@ def regrid( ds.attrs["radar_longitude"] = float(radar.longitude["data"][0]) ds.attrs["radar_altitude"] = float(radar.altitude["data"][0]) - # NetCDF save is best-effort: failure is logged but does not abort. + # When requested, the NetCDF file is a downstream input (registered as a + # gridded3d artifact), so a persistent write failure must raise. self._write_netcdf(ds, output_dir, source_filepath) return ds def _write_netcdf(self, ds, output_dir, source_filepath): - """Internal writer for netcdf output.""" - try: - if output_dir is None: - output_dir = "." + """Write the regridded grid to NetCDF, retrying then raising on failure. - output_dir_path = Path(output_dir) - output_dir_path.mkdir(parents=True, exist_ok=True) - - nc_filename = Path(source_filepath).stem + ".nc" - - nc_path = output_dir_path / nc_filename - - encoding = {var: {"zlib": True, "complevel": 9} for var in ds.data_vars} - ds.to_netcdf(nc_path, encoding=encoding, compute=True) - - logger.info("Saved regridded NetCDF: %s", nc_path) - - except Exception as e: - logger.warning("Failed to save NetCDF: %s", e) + The saved file is a real downstream input (the ingest node registers it + as a ``gridded3d`` artifact and ``cell_volume_stats`` reads it), so a + persistent write failure raises rather than being swallowed — a missing + file must never be reported as produced. Retries ``netcdf_save_retries`` + times before re-raising. + """ + if output_dir is None: + output_dir = "." + + output_dir_path = Path(output_dir) + output_dir_path.mkdir(parents=True, exist_ok=True) + + nc_path = output_dir_path / (Path(source_filepath).stem + ".nc") + encoding = {var: {"zlib": True, "complevel": 9} for var in ds.data_vars} + + for attempt in range(1, self.netcdf_save_retries + 1): + try: + ds.to_netcdf(nc_path, encoding=encoding, compute=True) + logger.info("Saved regridded NetCDF: %s", nc_path) + return + except Exception as e: + if attempt < self.netcdf_save_retries: + logger.warning( + "NetCDF save attempt %d/%d failed for %s: %s; retrying", + attempt, + self.netcdf_save_retries, + nc_path, + e, + ) + else: + logger.error( + "NetCDF save failed after %d attempts for %s: %s", + self.netcdf_save_retries, + nc_path, + e, + ) + raise def load_and_regrid( self, @@ -312,14 +334,19 @@ def load_and_regrid( Cartesian grid xarray.Dataset if successful, None if: - File does not exist or cannot be read - Regridding fails - - NetCDF save fails (if save_netcdf=True) + + Raises + ------ + Exception + If save_netcdf=True and the NetCDF write fails on every attempt + (after netcdf_save_retries). An enabled save must succeed because the + file is a registered downstream artifact. Notes ----- - Preferred method over separate read() + regrid() calls - - Two failure points: read stage and regrid stage - - NetCDF save is optional; set save_netcdf=False for memory-only - processing (avoids disk I/O) + - Set save_netcdf=False for memory-only processing (avoids disk I/O); + when set, the write is mandatory and raises on persistent failure - Returns same xarray.Dataset regardless of save_netcdf setting Examples diff --git a/src/adapt/persistence/catalog.py b/src/adapt/persistence/catalog.py index 8aa80f5..f9fa81c 100644 --- a/src/adapt/persistence/catalog.py +++ b/src/adapt/persistence/catalog.py @@ -17,20 +17,20 @@ import json import logging -import sqlite3 -import threading from datetime import UTC, datetime from pathlib import Path from typing import Any import pandas as pd +from adapt.persistence.sqlite_store import SqliteStore + __all__ = ["RadarCatalog"] logger = logging.getLogger(__name__) -class RadarCatalog: +class RadarCatalog(SqliteStore): """Radar-level catalog manager. Manages catalog.db at {radar_dir}/catalog.db. @@ -61,117 +61,9 @@ def __init__(self, radar_dir: str | Path): """ self.radar_dir = Path(radar_dir).resolve() self.radar = self.radar_dir.name - self.db_path = self.radar_dir / "catalog.db" - - # Thread safety - self._lock = threading.RLock() - self._conn: sqlite3.Connection | None = None - - # Initialize database - self._init_database() - + super().__init__(self.radar_dir / "catalog.db", "radar_catalog_schema.sql", checkpoint=True) logger.info(f"RadarCatalog initialized for {self.radar} at {self.db_path}") - def _get_connection(self) -> sqlite3.Connection: - """Get thread-safe database connection.""" - if self._conn is None: - self._conn = sqlite3.connect( - str(self.db_path), check_same_thread=False, isolation_level="DEFERRED" - ) - self._conn.row_factory = sqlite3.Row - # Enable WAL mode for concurrent access - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA foreign_keys=ON") - return self._conn - - def _init_database(self) -> None: - """Initialize database schema from SQL file.""" - schema_path = ( - Path(__file__).resolve().parents[1] - / "configuration" - / "schemas" - / "radar_catalog_schema.sql" - ) - - if not schema_path.exists(): - # Fallback to embedded schema - self._create_schema_inline() - return - - with open(schema_path) as f: - schema_sql = f.read() - - conn = self._get_connection() - with self._lock: - conn.executescript(schema_sql) - conn.commit() - # Checkpoint WAL so readonly readers (immutable=1) see the schema. - conn.execute("PRAGMA wal_checkpoint(PASSIVE)") - - logger.debug(f"Radar catalog schema initialized from {schema_path}") - - def _create_schema_inline(self) -> None: - """Create schema inline (fallback).""" - conn = self._get_connection() - with self._lock: - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA foreign_keys=ON") - - # Items table - conn.execute(""" - CREATE TABLE IF NOT EXISTS items ( - item_id TEXT PRIMARY KEY, - run_id TEXT NOT NULL, - item_type TEXT NOT NULL, - scan_time TEXT NOT NULL, - file_path TEXT NOT NULL, - parent_ids TEXT, - processing_stage TEXT NOT NULL, - status TEXT NOT NULL, - error_message TEXT, - metadata TEXT, - file_size_bytes INTEGER, - file_hash TEXT, - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL - ) - """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_items_run ON items(run_id)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_items_type ON items(item_type)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_items_scan_time ON items(scan_time DESC)") - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_items_type_time ON items(item_type, scan_time DESC)" - ) - - # Progress table - conn.execute(""" - CREATE TABLE IF NOT EXISTS progress ( - run_id TEXT PRIMARY KEY, - latest_downloaded_time TEXT, - latest_gridded_time TEXT, - latest_segmented_time TEXT, - latest_analyzed_time TEXT, - num_items_complete INTEGER DEFAULT 0, - num_items_failed INTEGER DEFAULT 0, - queue_depth INTEGER DEFAULT 0, - last_updated TEXT NOT NULL - ) - """) - - # Schemas table - conn.execute(""" - CREATE TABLE IF NOT EXISTS schemas ( - item_type TEXT PRIMARY KEY, - columns_json TEXT NOT NULL, - schema_version INTEGER DEFAULT 1, - updated_at TEXT NOT NULL - ) - """) - - conn.commit() - # Checkpoint WAL so readonly readers (immutable=1) see the schema. - conn.execute("PRAGMA wal_checkpoint(PASSIVE)") - # ========================================================================= # Item Management # ========================================================================= @@ -791,11 +683,3 @@ def get_latest_scan(self, run_id: str | None = None) -> dict | None: """).fetchone() return dict(row) if row else None - - def close(self) -> None: - """Close database connection.""" - if self._conn: - with self._lock: - self._conn.close() - self._conn = None - logger.debug(f"Radar catalog connection closed for {self.radar}") diff --git a/src/adapt/persistence/registry.py b/src/adapt/persistence/registry.py index e2c0da2..7dab94a 100644 --- a/src/adapt/persistence/registry.py +++ b/src/adapt/persistence/registry.py @@ -16,13 +16,14 @@ """ import logging -import sqlite3 import threading from datetime import UTC, datetime from pathlib import Path import pandas as pd +from adapt.persistence.sqlite_store import SqliteStore + __all__ = ["RepositoryRegistry"] logger = logging.getLogger(__name__) @@ -32,7 +33,7 @@ _cache_lock = threading.Lock() -class RepositoryRegistry: +class RepositoryRegistry(SqliteStore): """Root-level registry for Adapt repository. Manages adapt_registry.db at {root_dir}/adapt_registry.db. @@ -57,15 +58,7 @@ def __init__(self, root_dir: str | Path): Root directory for the Adapt repository """ self.root_dir = Path(root_dir).resolve() - self.db_path = self.root_dir / "adapt_registry.db" - - # Thread safety - self._lock = threading.RLock() - self._conn: sqlite3.Connection | None = None - - # Initialize database - self._init_database() - + super().__init__(self.root_dir / "adapt_registry.db", "registry_schema.sql") logger.debug("RepositoryRegistry initialized at %s", self.db_path) @classmethod @@ -89,109 +82,6 @@ def get_instance(cls, root_dir: str | Path) -> "RepositoryRegistry": _registry_cache[root_path] = cls(root_dir) return _registry_cache[root_path] - def _get_connection(self) -> sqlite3.Connection: - """Get thread-safe database connection.""" - if self._conn is None: - self._conn = sqlite3.connect( - str(self.db_path), check_same_thread=False, isolation_level="DEFERRED" - ) - self._conn.row_factory = sqlite3.Row - # Enable WAL mode for concurrent access - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA foreign_keys=ON") - return self._conn - - def _init_database(self) -> None: - """Initialize database schema from SQL file.""" - schema_path = Path(__file__).parent / "schemas" / "registry_schema.sql" - - if not schema_path.exists(): - # Fallback to embedded schema if file not found - self._create_schema_inline() - return - - with open(schema_path) as f: - schema_sql = f.read() - - conn = self._get_connection() - with self._lock: - conn.executescript(schema_sql) - conn.commit() - - logger.debug(f"Registry schema initialized from {schema_path}") - - def _create_schema_inline(self) -> None: - """Create schema inline (fallback).""" - conn = self._get_connection() - with self._lock: - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA foreign_keys=ON") - - # Runs table - conn.execute(""" - CREATE TABLE IF NOT EXISTS runs ( - run_id TEXT PRIMARY KEY, - radar TEXT NOT NULL, - start_time TEXT NOT NULL, - end_time TEXT, - status TEXT NOT NULL, - mode TEXT, - config_path TEXT, - repository_version TEXT NOT NULL, - created_at TEXT NOT NULL - ) - """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_runs_start_time ON runs(start_time DESC)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_runs_radar ON runs(radar)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_runs_status ON runs(status)") - - # Radars table - conn.execute(""" - CREATE TABLE IF NOT EXISTS radars ( - radar TEXT PRIMARY KEY, - catalog_path TEXT NOT NULL, - data_path TEXT NOT NULL, - location_lat REAL, - location_lon REAL, - created_at TEXT NOT NULL, - last_updated TEXT NOT NULL - ) - """) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_radars_updated ON radars(last_updated DESC)" - ) - - # Item types table - conn.execute(""" - CREATE TABLE IF NOT EXISTS item_types ( - item_type TEXT PRIMARY KEY, - description TEXT NOT NULL, - storage_format TEXT NOT NULL, - dimensionality TEXT NOT NULL, - created_at TEXT NOT NULL - ) - """) - - # Prepopulate item types - now = datetime.now(UTC).isoformat() - item_types_data = [ - ("gridded3d", "Gridded reflectivity volume", "netcdf", "3d", now), - ("segmentation2d", "Cell segmentation masks", "netcdf", "2d", now), - ("projection2d", "Cell motion projections", "netcdf", "2d", now), - ("analysis2d", "Cell-level analysis metrics", "parquet", "table", now), - ] - - conn.executemany( - """ - INSERT OR IGNORE INTO item_types - (item_type, description, storage_format, dimensionality, created_at) - VALUES (?, ?, ?, ?, ?) - """, - item_types_data, - ) - - conn.commit() - # ========================================================================= # Radar Management # ========================================================================= @@ -470,11 +360,3 @@ def get_item_type_info(self, item_type: str) -> dict | None: ).fetchone() return dict(row) if row else None - - def close(self) -> None: - """Close database connection.""" - if self._conn: - with self._lock: - self._conn.close() - self._conn = None - logger.debug("Registry connection closed") diff --git a/src/adapt/persistence/repository.py b/src/adapt/persistence/repository.py index 34e8ee4..93b8250 100644 --- a/src/adapt/persistence/repository.py +++ b/src/adapt/persistence/repository.py @@ -177,7 +177,7 @@ def _register_in_new_catalog(self) -> None: config_file = self.catalog.radar_dir / f"config_run_{self.run_id}.json" if not config_file.exists(): config_json = self.config.model_dump_json() - with open(config_file, "w") as f: + with open(config_file, "w", encoding="utf-8") as f: f.write(config_json) logger.debug(f"Saved runtime config: {config_file}") config_path = str(config_file.relative_to(self.base_dir)) diff --git a/src/adapt/runtime/postprocessor.py b/src/adapt/runtime/postprocessor.py index 5638fce..1b04e7d 100644 --- a/src/adapt/runtime/postprocessor.py +++ b/src/adapt/runtime/postprocessor.py @@ -53,7 +53,7 @@ def _ensure_postprocess_modules_registered(extensions: list[str] | None = None) """ module_paths: list[str] = [] if _POSTPROCESS_DEFAULTS_YAML.exists(): - with open(_POSTPROCESS_DEFAULTS_YAML) as f: + with open(_POSTPROCESS_DEFAULTS_YAML, encoding="utf-8") as f: cfg = yaml.safe_load(f) or {} module_paths = cfg.get("postprocess", {}).get("modules", []) or [] diff --git a/tests/modules/acquisition/test_downloader_failures.py b/tests/modules/acquisition/test_downloader_failures.py index 7749c89..2705971 100644 --- a/tests/modules/acquisition/test_downloader_failures.py +++ b/tests/modules/acquisition/test_downloader_failures.py @@ -27,13 +27,44 @@ def iter_success(self): assert downloads == [] -def test_fetch_scans_exception_returns_empty(tmp_path, make_config): +def test_fetch_scans_retries_then_returns_empty(tmp_path, make_config): + """Persistent fetch failure is retried max_fetch_retries times, then [] (no crash).""" + calls = {"fetch": 0} + class ExplodingConn: def get_avail_scans_in_range(self, *a): + calls["fetch"] += 1 raise RuntimeError("AWS down") - config = make_config() - d = AwsNexradDownloader(config, output_dir=tmp_path, conn=ExplodingConn()) + sleeps: list = [] + config = make_config(max_fetch_retries=4) + d = AwsNexradDownloader( + config, output_dir=tmp_path, conn=ExplodingConn(), sleeper=sleeps.append + ) scans = d._fetch_scans(datetime.now(UTC), datetime.now(UTC)) + assert scans == [] + assert calls["fetch"] == 4 # one attempt per retry + assert sleeps == [1, 2, 3] # backoff between attempts, none after the last + + +def test_fetch_scans_recovers_on_retry(tmp_path, fake_scan, make_config): + """A transient failure followed by success returns the scans.""" + calls = {"fetch": 0} + good = [fake_scan("KLOT20250305_120000", datetime.now(UTC))] + + class FlakyConn: + def get_avail_scans_in_range(self, *a): + calls["fetch"] += 1 + if calls["fetch"] < 3: + raise RuntimeError("transient AWS error") + return good + + config = make_config(max_fetch_retries=3) + d = AwsNexradDownloader(config, output_dir=tmp_path, conn=FlakyConn(), sleeper=lambda _: None) + + scans = d._fetch_scans(datetime.now(UTC), datetime.now(UTC)) + + assert scans == good + assert calls["fetch"] == 3 diff --git a/tests/modules/ingest/test_loader_netcdf_save.py b/tests/modules/ingest/test_loader_netcdf_save.py new file mode 100644 index 0000000..bc9297e --- /dev/null +++ b/tests/modules/ingest/test_loader_netcdf_save.py @@ -0,0 +1,48 @@ +"""NetCDF save retry/raise behaviour for RadarDataLoader._write_netcdf. + +The saved NetCDF is a registered downstream artifact, so a persistent write +failure must raise (after retries) rather than be silently swallowed. +""" + +import pytest + +from adapt.modules.ingest.module import RadarDataLoader + +pytestmark = pytest.mark.unit + + +class _FakeDataset: + """Minimal stand-in exposing the attributes _write_netcdf touches.""" + + def __init__(self, fail_times: int): + self.data_vars = {"reflectivity": None} + self._fail_times = fail_times + self.calls = 0 + + def to_netcdf(self, *args, **kwargs): + self.calls += 1 + if self.calls <= self._fail_times: + raise OSError("disk full") + + +def test_write_netcdf_raises_after_retries(tmp_path, make_ingest_config): + """Persistent write failure raises after exactly netcdf_save_retries attempts.""" + config = make_ingest_config(regridder={"netcdf_save_retries": 3}) + loader = RadarDataLoader(config) + ds = _FakeDataset(fail_times=99) + + with pytest.raises(OSError, match="disk full"): + loader._write_netcdf(ds, str(tmp_path), "KLOT_20250305_120000.gz") + + assert ds.calls == 3 + + +def test_write_netcdf_recovers_on_retry(tmp_path, make_ingest_config): + """A transient failure followed by success does not raise.""" + config = make_ingest_config(regridder={"netcdf_save_retries": 3}) + loader = RadarDataLoader(config) + ds = _FakeDataset(fail_times=1) + + loader._write_netcdf(ds, str(tmp_path), "KLOT_20250305_120000.gz") + + assert ds.calls == 2 diff --git a/tests/modules/tracking/test_tracker_quality.py b/tests/modules/tracking/test_tracker_quality.py new file mode 100644 index 0000000..dc67f15 --- /dev/null +++ b/tests/modules/tracking/test_tracker_quality.py @@ -0,0 +1,370 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Phase-B tracking-quality behaviours: hard gap limits, physical motion +constraints, deterministic overlap-first matching, and persisted diagnostics. + +Synthetic inputs with analytically known outcomes; no stored fixtures. +""" + +import shutil +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from adapt.configuration.schemas.param import ParamConfig +from adapt.configuration.schemas.resolve import resolve_config +from adapt.configuration.schemas.user import UserConfig +from adapt.execution.nodes.tracking import TrackingModule +from adapt.modules.tracking.module import RadarCellTracker +from adapt.modules.tracking.projection import select_registration_labels + +pytestmark = pytest.mark.unit + + +def _make_config(**overrides): + d = tempfile.mkdtemp() + try: + param = ParamConfig() + param.tracker.split_overlap_threshold = 0.4 + for key, val in overrides.items(): + setattr(param.tracker, key, val) + user = UserConfig(base_dir=str(Path(d)), radar="TEST_RADAR") + internal = resolve_config(param, user, None) + return TrackingModule.build_config(internal) + finally: + shutil.rmtree(d, ignore_errors=True) + + +def _synthetic_ds(time, labels, refl=None, proj_labels=None): + H, W = labels.shape + if refl is None: + refl = np.zeros((H, W), dtype=np.float32) + refl[labels > 0] = 40.0 + if proj_labels is None: + proj_labels = labels + projections = np.stack([proj_labels.astype(np.int32)], axis=0) + ds = xr.Dataset( + { + "cell_labels": (["y", "x"], labels.astype(np.int32)), + "reflectivity": (["y", "x"], refl.astype(np.float32)), + "cell_projections": (["frame_offset", "y", "x"], projections), + "heading_x": (["y", "x"], np.zeros((H, W), dtype=np.float32)), + "heading_y": (["y", "x"], np.zeros((H, W), dtype=np.float32)), + }, + coords={"y": np.arange(H) * 1000.0, "x": np.arange(W) * 1000.0, "frame_offset": [0]}, + ) + return ds.assign_coords(time=time) + + +def _cell_stats(time, rows): + return pd.DataFrame( + [ + { + "time": time, + "time_volume_start": time, + "cell_label": r["id"], + "cell_area_sqkm": r["area"], + "area_40dbz_km2": r.get("area40", r["area"]), + "cell_centroid_geom_x": r["cx"], + "cell_centroid_geom_y": r["cy"], + "cell_centroid_mass_lat": r.get("lat", 35.0), + "cell_centroid_mass_lon": r.get("lon", -97.0), + "radar_reflectivity_mean": r["mean_refl"], + "radar_reflectivity_max": r["max_refl"], + "radar_differential_reflectivity_max": r.get("max_zdr", 1.0), + } + for r in rows + ] + ) + + +def _one_cell_scan(time, x_pix, proj_labels=None): + """A single 2x2 cell placed at column x_pix on an 8x8 grid (1000 m pixels).""" + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:4, x_pix : x_pix + 2] = 1 + cx = (x_pix + 0.5) * 1000.0 + stats = _cell_stats( + time, + [{"id": 1, "area": 4.0, "cx": cx, "cy": 2500.0, "mean_refl": 40.0, "max_refl": 45.0}], + ) + ds = _synthetic_ds(time, labels, proj_labels=proj_labels if proj_labels is not None else labels) + return ds, stats + + +# --------------------------------------------------------------------------- +# B1 — hard scan-gap limits +# --------------------------------------------------------------------------- + + +def test_gap_exceeded_terminates_and_restarts(): + """dt above max_tracking_gap_minutes terminates active tracks and starts fresh.""" + cfg = _make_config(max_tracking_gap_minutes=10.0) + tracker = RadarCellTracker(cfg) + + t0 = np.datetime64("2024-01-01T12:00:00") + t1 = np.datetime64("2024-01-01T12:20:00") # 20 min > 10-min hard limit + + ds0, stats0 = _one_cell_scan(t0, 2) + _, events0 = tracker.track(ds0, stats0) + uid0 = str(events0[events0["event_type"] == "INITIATION"].iloc[0]["target_cell_uid"]) + + ds1, stats1 = _one_cell_scan(t1, 2) + tracked1, events1 = tracker.track(ds1, stats1) + + assert (events1["event_type"] == "CONTINUE").sum() == 0, "No match may cross the hard gap" + assert uid0 in set(events1[events1["event_type"] == "TERMINATION"]["source_cell_uid"]), ( + "Active track must be terminated when the gap is exceeded" + ) + assert (events1["event_type"] == "INITIATION").sum() == 1, "A fresh track must start" + new_uid = str(events1[events1["event_type"] == "INITIATION"].iloc[0]["target_cell_uid"]) + assert new_uid != uid0 + + +def test_non_monotonic_time_resets_without_crash(): + """A backwards scan time must not raise; it resets tracks instead.""" + cfg = _make_config() + tracker = RadarCellTracker(cfg) + + t0 = np.datetime64("2024-01-01T12:05:00") + t_back = np.datetime64("2024-01-01T12:00:00") # earlier than t0 + + ds0, stats0 = _one_cell_scan(t0, 2) + _, events0 = tracker.track(ds0, stats0) + uid0 = str(events0[events0["event_type"] == "INITIATION"].iloc[0]["target_cell_uid"]) + + ds1, stats1 = _one_cell_scan(t_back, 2) + _, events1 = tracker.track(ds1, stats1) # must not raise + + assert (events1["event_type"] == "CONTINUE").sum() == 0 + assert uid0 in set(events1[events1["event_type"] == "TERMINATION"]["source_cell_uid"]) + assert (events1["event_type"] == "INITIATION").sum() == 1 + + +def test_normal_gap_still_continues(): + """A within-limit gap with an exact projection still produces CONTINUE.""" + cfg = _make_config(max_tracking_gap_minutes=20.0, match_cost_threshold=0.0) + tracker = RadarCellTracker(cfg) + + t0 = np.datetime64("2024-01-01T12:00:00") + t1 = np.datetime64("2024-01-01T12:05:00") # 5 min < 20-min limit + + ds0, stats0 = _one_cell_scan(t0, 2) + tracker.track(ds0, stats0) + ds1, stats1 = _one_cell_scan(t1, 2, proj_labels=ds0["cell_labels"].values) + _, events1 = tracker.track(ds1, stats1) + + assert (events1["event_type"] == "CONTINUE").sum() == 1 + + +# --------------------------------------------------------------------------- +# B3 — physical motion constraints (hard reject before matching) +# --------------------------------------------------------------------------- + + +def test_velocity_exceeded_rejects_match(): + """An over-speed candidate is rejected even with a perfect projection overlap.""" + # x=2 → x=6 is 4000 m in 300 s = 13.3 m/s; cap at 5 m/s rejects it. + cfg = _make_config(max_speed_ms=5.0, match_cost_threshold=0.0, max_tracking_gap_minutes=60.0) + tracker = RadarCellTracker(cfg) + + t0 = np.datetime64("2024-01-01T12:00:00") + t1 = np.datetime64("2024-01-01T12:05:00") + + ds0, stats0 = _one_cell_scan(t0, 2) + _, events0 = tracker.track(ds0, stats0) + uid0 = str(events0[events0["event_type"] == "INITIATION"].iloc[0]["target_cell_uid"]) + + # Projection predicts the jumped position exactly → overlap exists, but speed cap bites. + ds1, stats1 = _one_cell_scan(t1, 6) + _, events1 = tracker.track(ds1, stats1) + + assert (events1["event_type"] == "CONTINUE").sum() == 0 + assert uid0 in set(events1[events1["event_type"] == "TERMINATION"]["source_cell_uid"]) + assert (events1["event_type"] == "INITIATION").sum() == 1 + + +def test_acceleration_exceeded_rejects_match(): + """A candidate far faster than the track's own prior speed is rejected.""" + # scan0→1: x2→x3 (3.33 m/s sets prior). scan1→2: x3→x6 (10 m/s) > 2×3.33. + cfg = _make_config( + max_speed_ms=40.0, + max_speed_multiplier=2.0, + match_cost_threshold=0.0, + max_tracking_gap_minutes=60.0, + ) + tracker = RadarCellTracker(cfg) + + t0 = np.datetime64("2024-01-01T12:00:00") + t1 = np.datetime64("2024-01-01T12:05:00") + t2 = np.datetime64("2024-01-01T12:10:00") + + ds0, stats0 = _one_cell_scan(t0, 2) + tracker.track(ds0, stats0) + ds1, stats1 = _one_cell_scan(t1, 3) + _, events1 = tracker.track(ds1, stats1) + assert (events1["event_type"] == "CONTINUE").sum() == 1, "slow step must continue" + + ds2, stats2 = _one_cell_scan(t2, 6) + _, events2 = tracker.track(ds2, stats2) + assert (events2["event_type"] == "CONTINUE").sum() == 0, "accelerating step must be rejected" + + +# --------------------------------------------------------------------------- +# B4 — deterministic overlap-first matching +# --------------------------------------------------------------------------- + + +def test_unique_overlap_matches_despite_high_cost(): + """A unique projected-hull overlap continues a track even when the cost-based + Hungarian post-filter (keep_cost) would have rejected it.""" + # keep_cost is tiny so a non-zero-cost pair fails the Hungarian filter, but the + # projection covers the child fully (overlap 1.0 ≥ 0.3) → overlap-first matches. + cfg = _make_config(keep_cost_threshold=0.01, match_cost_threshold=0.0) + tracker = RadarCellTracker(cfg) + + t0 = np.datetime64("2024-01-01T12:00:00") + t1 = np.datetime64("2024-01-01T12:05:00") + + ds0, stats0 = _one_cell_scan(t0, 2) + _, events0 = tracker.track(ds0, stats0) + uid0 = str(events0[events0["event_type"] == "INITIATION"].iloc[0]["target_cell_uid"]) + + # Cell shifts one pixel and brightens → small but non-zero 4-term cost. + labels1 = np.zeros((8, 8), dtype=np.int32) + labels1[2:4, 3:5] = 1 + stats1 = _cell_stats( + t1, + [{"id": 1, "area": 4.0, "cx": 3500.0, "cy": 2500.0, "mean_refl": 50.0, "max_refl": 55.0}], + ) + ds1 = _synthetic_ds(t1, labels1, proj_labels=labels1) # projection predicts new position + tracked1, events1 = tracker.track(ds1, stats1) + + assert (events1["event_type"] == "CONTINUE").sum() == 1, "overlap-first must continue the track" + assert str(tracked1.iloc[0]["cell_uid"]) == uid0, "continued cell keeps its uid" + + +# --------------------------------------------------------------------------- +# B6 — per-match diagnostics on accepted matches +# --------------------------------------------------------------------------- + + +def test_continue_event_carries_diagnostics(): + """An accepted CONTINUE records overlap, iou, distance, speed, cost, method.""" + cfg = _make_config(match_cost_threshold=0.0, max_tracking_gap_minutes=60.0) + tracker = RadarCellTracker(cfg) + + t0 = np.datetime64("2024-01-01T12:00:00") + t1 = np.datetime64("2024-01-01T12:05:00") + + ds0, stats0 = _one_cell_scan(t0, 2) + tracker.track(ds0, stats0) + # cell moves one pixel; projection predicts it → unique overlap match + ds1, stats1 = _one_cell_scan(t1, 3) + _, events1 = tracker.track(ds1, stats1) + + cont = events1[events1["event_type"] == "CONTINUE"] + assert len(cont) == 1 + row = cont.iloc[0] + assert row["match_method"] in {"OVERLAP", "HUNGARIAN"} + assert 0.0 <= float(row["candidate_overlap"]) <= 1.0 + assert 0.0 <= float(row["candidate_iou"]) <= 1.0 + # cell moved 1000 m in 300 s ≈ 3.33 m/s + assert float(row["candidate_centroid_distance_m"]) == pytest.approx(1000.0, abs=1e-6) + assert float(row["candidate_speed_ms"]) == pytest.approx(1000.0 / 300.0, abs=1e-6) + assert pd.notna(row["candidate_final_cost"]) + + +def test_initiation_event_has_null_diagnostics(): + """INITIATION rows carry no candidate diagnostics.""" + cfg = _make_config() + tracker = RadarCellTracker(cfg) + ds0, stats0 = _one_cell_scan(np.datetime64("2024-01-01T12:00:00"), 2) + _, events0 = tracker.track(ds0, stats0) + init = events0[events0["event_type"] == "INITIATION"].iloc[0] + assert pd.isna(init["candidate_overlap"]) + assert pd.isna(init["match_method"]) + + +# --------------------------------------------------------------------------- +# B5 — motion-state crossing prevention (heading-consistency penalty) +# --------------------------------------------------------------------------- + + +def test_heading_penalty_breaks_ambiguous_match_toward_consistent_track(): + """With an established +x velocity, a heading penalty steers an ambiguous + match to the heading-consistent candidate instead of the reversed one.""" + cfg = _make_config( + heading_change_penalty_weight=0.5, + match_cost_threshold=0.0, + max_tracking_gap_minutes=60.0, + ) + tracker = RadarCellTracker(cfg) + + t0 = np.datetime64("2024-01-01T12:00:00") + t1 = np.datetime64("2024-01-01T12:05:00") + t2 = np.datetime64("2024-01-01T12:10:00") + + # scan0→1 establish a +x velocity (heading 0) for track label 1. + ds0, stats0 = _one_cell_scan(t0, 2) + tracker.track(ds0, stats0) + ds1, stats1 = _one_cell_scan(t1, 4) # perfect projection → CONTINUE, vx>0 + tracker.track(ds1, stats1) + + # scan2: the registration hull (label 1) fills the whole row band, so it + # overlaps P and Q identically (equal IoU). P and Q are equidistant from the + # prev centroid (x=4) — 2000 m each. Q is made *cheaper* on base cost (its + # reflectivity matches the track, P's differs) so that WITHOUT the heading + # penalty Q wins. Only the +x/−x heading asymmetry can flip it back to P. + labels2 = np.zeros((8, 8), dtype=np.int32) + labels2[2:4, 6:8] = 1 # P ahead (centroid x = 6500, +2000, +x consistent) + labels2[2:4, 2:4] = 2 # Q behind (centroid x = 2500, −2000, −x reversed) + proj2 = np.zeros((8, 8), dtype=np.int32) + proj2[2:4, 0:8] = 1 # symmetric hull spanning the full row band + stats2 = _cell_stats( + t2, + [ + {"id": 1, "area": 4.0, "cx": 6500.0, "cy": 2500.0, "mean_refl": 42.0, "max_refl": 45.0}, + {"id": 2, "area": 4.0, "cx": 2500.0, "cy": 2500.0, "mean_refl": 40.0, "max_refl": 45.0}, + ], + ) + ds2 = _synthetic_ds(t2, labels2, proj_labels=proj2) + _, events2 = tracker.track(ds2, stats2) + + cont = events2[events2["event_type"] == "CONTINUE"] + assert len(cont) == 1, "the track should continue to exactly one cell" + assert int(cont.iloc[0]["target_cell_label"]) == 1, "heading penalty must steer to the +x cell" + + +# --------------------------------------------------------------------------- +# B2 — registration-based multi-step projection +# --------------------------------------------------------------------------- + + +def test_registration_selects_nearest_minute(): + """The minute frame closest to the real gap is chosen for the hull.""" + f1 = np.full((4, 4), 11, dtype=np.int32) + f2 = np.full((4, 4), 22, dtype=np.int32) + f3 = np.full((4, 4), 33, dtype=np.int32) + ds = xr.Dataset( + { + "registration_minutes": (["minute", "y", "x"], np.stack([f1, f2, f3])), + "cell_projections": (["frame_offset", "y", "x"], np.zeros((1, 4, 4), dtype=np.int32)), + }, + coords={"minute": [1, 2, 3]}, + ) + out = select_registration_labels(ds, dt_s=130.0) # 2.17 min → nearest minute 2 + assert int(out[0, 0]) == 22 + + +def test_registration_falls_back_to_cell_projections(): + """Without minute frames the whole-step cell_projections[0] is used.""" + ds = xr.Dataset( + {"cell_projections": (["frame_offset", "y", "x"], np.full((1, 4, 4), 7, dtype=np.int32))} + ) + out = select_registration_labels(ds, dt_s=300.0) + assert int(out[0, 0]) == 7 diff --git a/tests/persistence/test_sqlite_store.py b/tests/persistence/test_sqlite_store.py new file mode 100644 index 0000000..620f7e1 --- /dev/null +++ b/tests/persistence/test_sqlite_store.py @@ -0,0 +1,36 @@ +"""Tests for the shared SqliteStore base behaviour.""" + +import pytest + +from adapt.persistence.sqlite_store import SqliteStore + +pytestmark = pytest.mark.unit + + +def test_missing_schema_file_raises(tmp_path): + """A nonexistent schema file fails loudly — there is no inline fallback.""" + with pytest.raises(FileNotFoundError, match="nonexistent_schema.sql"): + SqliteStore(tmp_path / "x.db", "nonexistent_schema.sql") + + +def test_registry_schema_loads_from_file(tmp_path): + """The real registry schema loads, including the schema_registry table.""" + store = SqliteStore(tmp_path / "reg.db", "registry_schema.sql") + try: + tables = { + row[0] + for row in store._get_connection().execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ) + } + finally: + store.close() + + assert {"runs", "radars", "item_types", "schema_registry"} <= tables + + +def test_close_is_idempotent(tmp_path): + """close() may be called repeatedly without error.""" + store = SqliteStore(tmp_path / "reg.db", "registry_schema.sql") + store.close() + store.close() diff --git a/tests/unit/test_base_module_persistence.py b/tests/unit/test_base_module_persistence.py new file mode 100644 index 0000000..4c341e5 --- /dev/null +++ b/tests/unit/test_base_module_persistence.py @@ -0,0 +1,51 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Tests for BaseModule.persistence declaration. + +Modules that persist outputs declare a tuple of persistence specs (from +adapt.contracts); modules that persist nothing leave the default empty tuple. +""" + +import pytest + +pytestmark = pytest.mark.unit + +from adapt.modules.base import BaseModule # noqa: E402 + + +class _NoPersistenceModule(BaseModule): + name = "no_persistence" + + def run(self, context: dict) -> dict: + return {} + + +class TestBaseModulePersistence: + def test_default_persistence_is_empty(self): + assert _NoPersistenceModule.persistence == () + + def test_pure_compute_modules_declare_no_persistence(self): + from adapt.execution.nodes.detection import DetectModule + from adapt.execution.nodes.projection import ProjectionModule + + assert DetectModule.persistence == () + assert ProjectionModule.persistence == () + + def test_persisting_modules_declare_specs(self): + from adapt.contracts import ( + NetcdfArtifact, + ParquetArtifact, + RegisterFileArtifact, + SqliteTable, + TrackTablesWrite, + ) + from adapt.execution.nodes.analysis import AnalysisModule + from adapt.execution.nodes.cell_volume_stats import CellVolumeStatsModule + from adapt.execution.nodes.ingest import LoadModule + from adapt.execution.nodes.tracking import TrackingModule + + assert {type(s) for s in LoadModule.persistence} == {RegisterFileArtifact} + assert {type(s) for s in AnalysisModule.persistence} == {ParquetArtifact} + assert {type(s) for s in TrackingModule.persistence} == {NetcdfArtifact, TrackTablesWrite} + assert {type(s) for s in CellVolumeStatsModule.persistence} == {SqliteTable} From 1032e60970a9e6f570240927eaaae398316e6950 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Sun, 28 Jun 2026 14:33:24 -0500 Subject: [PATCH 5/8] REF:better logging;better status updates --- src/adapt/__init__.py | 9 + src/adapt/configuration/schemas/internal.py | 15 +- src/adapt/contracts/execution_history.py | 83 +++++ src/adapt/contracts/observability.py | 68 ++++ src/adapt/execution/graph/executor.py | 20 +- src/adapt/modules/acquisition/module.py | 9 +- src/adapt/modules/ingest/module.py | 3 + src/adapt/modules/projection/module.py | 12 +- src/adapt/persistence/execution_history.py | 284 ++++++++++++++++ src/adapt/persistence/repository.py | 8 +- src/adapt/persistence/sqlite_store.py | 70 ++++ src/adapt/runtime/diagnostics.py | 29 ++ src/adapt/runtime/history_handler.py | 75 +++++ src/adapt/runtime/logging_setup.py | 116 +++++++ src/adapt/runtime/observability.py | 304 ++++++++++++++++++ src/adapt/runtime/orchestrator.py | 143 ++++++++ src/adapt/runtime/processor.py | 88 ++++- src/adapt/runtime/provenance.py | 56 ++++ src/adapt/runtime/run_reporter.py | 98 ++++++ .../test_observability_config.py | 17 + tests/contracts/test_observability.py | 67 ++++ tests/graph/test_executor_observability.py | 121 +++++++ .../test_integration_execution_history.py | 112 +++++++ .../acquisition/test_downloader_quiet.py | 39 +++ tests/modules/ingest/test_ingest_quiet.py | 33 ++ .../test_projector_internal_utils.py | 23 ++ tests/persistence/test_data_repository.py | 34 ++ tests/persistence/test_execution_history.py | 150 +++++++++ tests/runtime/test_diagnostics.py | 62 ++++ tests/runtime/test_history_handler.py | 64 ++++ tests/runtime/test_logging_setup.py | 143 ++++++++ tests/runtime/test_observability.py | 260 +++++++++++++++ tests/runtime/test_orchestrator.py | 41 +++ tests/runtime/test_processor_observability.py | 180 +++++++++++ tests/runtime/test_provenance.py | 25 ++ tests/runtime/test_run_reporter.py | 147 +++++++++ tests/runtime/test_thirdparty_quiet.py | 43 +++ tests/test_architecture.py | 23 ++ 38 files changed, 3061 insertions(+), 13 deletions(-) create mode 100644 src/adapt/contracts/execution_history.py create mode 100644 src/adapt/contracts/observability.py create mode 100644 src/adapt/persistence/execution_history.py create mode 100644 src/adapt/persistence/sqlite_store.py create mode 100644 src/adapt/runtime/diagnostics.py create mode 100644 src/adapt/runtime/history_handler.py create mode 100644 src/adapt/runtime/logging_setup.py create mode 100644 src/adapt/runtime/observability.py create mode 100644 src/adapt/runtime/provenance.py create mode 100644 src/adapt/runtime/run_reporter.py create mode 100644 tests/configuration/test_observability_config.py create mode 100644 tests/contracts/test_observability.py create mode 100644 tests/graph/test_executor_observability.py create mode 100644 tests/integration/test_integration_execution_history.py create mode 100644 tests/modules/acquisition/test_downloader_quiet.py create mode 100644 tests/modules/ingest/test_ingest_quiet.py create mode 100644 tests/persistence/test_execution_history.py create mode 100644 tests/runtime/test_diagnostics.py create mode 100644 tests/runtime/test_history_handler.py create mode 100644 tests/runtime/test_logging_setup.py create mode 100644 tests/runtime/test_observability.py create mode 100644 tests/runtime/test_processor_observability.py create mode 100644 tests/runtime/test_provenance.py create mode 100644 tests/runtime/test_run_reporter.py create mode 100644 tests/runtime/test_thirdparty_quiet.py diff --git a/src/adapt/__init__.py b/src/adapt/__init__.py index 3854017..8e19fe7 100644 --- a/src/adapt/__init__.py +++ b/src/adapt/__init__.py @@ -11,6 +11,15 @@ Authors: Bhupendra Raut and Sid Gupta """ +import os as _os + +# Quiet third-party import-time chatter before any submodule (and its transitive +# deps) load. Py-ART prints a citation banner on import unless PYART_QUIET is set, +# and it is pulled in transitively by nexradaws via the acquisition source — earlier +# than any Adapt module could set this. The package root is the one place guaranteed +# to run first. setdefault preserves a user-provided override. +_os.environ.setdefault("PYART_QUIET", "1") + import importlib.metadata as _importlib_metadata # Get the version diff --git a/src/adapt/configuration/schemas/internal.py b/src/adapt/configuration/schemas/internal.py index 2357bf4..fe9c278 100644 --- a/src/adapt/configuration/schemas/internal.py +++ b/src/adapt/configuration/schemas/internal.py @@ -178,9 +178,22 @@ class InternalOutputConfig(AdaptBaseModel): class InternalLoggingConfig(AdaptBaseModel): - """Runtime logging configuration.""" + """Runtime logging + observability configuration. + + ``level`` governs the full file/JSON log; ``console_level`` keeps the console + quiet independently. The remaining toggles enable/disable the observability + subsystem and its pillars; the orchestrator translates these into an + ``ObsSettings`` when it builds the provider. + """ level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + console_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "WARNING" + enabled: bool = True + traces: bool = True + metrics: bool = True + json_logs: bool = False + console_logs: bool = True + progress_every: float = Field(default=30.0, gt=0.0) class InternalProcessorConfig(AdaptBaseModel): diff --git a/src/adapt/contracts/execution_history.py b/src/adapt/contracts/execution_history.py new file mode 100644 index 0000000..856c1cc --- /dev/null +++ b/src/adapt/contracts/execution_history.py @@ -0,0 +1,83 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Execution-history data contracts: the DTOs the runtime hands to persistence. + +Pure frozen types, stdlib only. The runtime (which holds both telemetry and the +repository) builds these; ``adapt.persistence.ExecutionHistory`` consumes them. +``SpanRecord`` (from ``contracts.observability``) is reused for per-module rows. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime + + +@dataclass(frozen=True, slots=True) +class RunProvenance: + """Environment snapshot for reproducibility. ``git_commit`` is None outside a checkout.""" + + git_commit: str | None + hostname: str + username: str + python_version: str + platform: str + software_version: str + + +@dataclass(frozen=True, slots=True) +class RunStart: + """Everything known at run start — drives both run_history and the console header.""" + + run_id: str + pipeline: str + pipeline_version: str + site: str # radar id + dataset: str # dataset id (= radar today) + instrument: str + mode: str + start_time: datetime + configuration_hash: str + configuration_file: str + provenance: RunProvenance + enabled_modules: tuple[str, ...] + + +@dataclass(frozen=True, slots=True) +class RunSummary: + """End-of-run aggregates — drives run_history finalize and the console summary.""" + + run_id: str + status: str # success | failed | cancelled + end_time: datetime + duration_seconds: float + files_processed: int + scans_processed: int + objects_detected: int + warnings: int + errors: int + average_scan_time: float + maximum_scan_time: float + slowest_stages: tuple[tuple[str, float], ...] # (module, total_seconds) desc + + +@dataclass(frozen=True, slots=True) +class WarningEvent: + scan_id: str + module: str + category: str + message: str + logger: str + timestamp: datetime + + +@dataclass(frozen=True, slots=True) +class ErrorEvent: + scan_id: str + module: str + exception_type: str + message: str + traceback: str + logger: str + timestamp: datetime diff --git a/src/adapt/contracts/observability.py b/src/adapt/contracts/observability.py new file mode 100644 index 0000000..c3111ff --- /dev/null +++ b/src/adapt/contracts/observability.py @@ -0,0 +1,68 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Observability contract: the injectable seam shared by execution and runtime. + +Pure types only — zero logic, stdlib imports. ``ObsContext`` is the correlation +identity carried out-of-band in contextvars (never in the science context dict). +``SpanRecord`` is a finished-span snapshot handed to persistence (execution +history). ``Observability`` is the Protocol that ``GraphExecutor`` and the +runtime depend on, so the concrete provider can be injected — and later +relocated — without changing any call site. +""" + +from __future__ import annotations + +from contextlib import AbstractContextManager +from dataclasses import dataclass, field +from typing import Protocol, runtime_checkable + + +@dataclass(frozen=True, slots=True) +class ObsContext: + """Immutable correlation identity. Travels via contextvars, never mutated. + + These ids never enter the science ``context`` dict, so they cannot collide + with module inputs/outputs or affect determinism. + """ + + pipeline_id: str + trace_id: str + span_id: str = "" + scan_id: str = "" + dataset_id: str = "" # radar / site id, e.g. "KDIX" + experiment_id: str = "" + worker_id: str = "" + stage: str = "" + + +@dataclass(frozen=True, slots=True) +class SpanRecord: + """Finished-span snapshot emitted to consumers (e.g. module_history).""" + + name: str + trace_id: str + span_id: str + parent_span_id: str + start: float + finish: float + duration_s: float + error: str # "" when no error + metadata: dict[str, str] = field(default_factory=dict) + + +@runtime_checkable +class Observability(Protocol): + """The injected observability provider. Real or disabled, same interface.""" + + metrics: object + + def span(self, name: str, **ctx: object) -> AbstractContextManager: ... + + def bind(self, **ctx: str) -> AbstractContextManager: ... + + def current(self) -> ObsContext: ... + + def drain_spans(self) -> list[SpanRecord]: ... + + def install_logging(self, log_path: object) -> None: ... diff --git a/src/adapt/execution/graph/executor.py b/src/adapt/execution/graph/executor.py index ae4a411..e049cf6 100644 --- a/src/adapt/execution/graph/executor.py +++ b/src/adapt/execution/graph/executor.py @@ -16,7 +16,9 @@ """ import logging +from contextlib import nullcontext +from adapt.contracts.observability import Observability from adapt.contracts.pipeline import require from adapt.execution.graph.node import Node @@ -38,8 +40,11 @@ class GraphExecutor: result_context = executor.run(initial_context={}) """ - def __init__(self, nodes: list[Node]) -> None: + def __init__(self, nodes: list[Node], observability: Observability | None = None) -> None: self.nodes = nodes + # Optional, injected by the runtime composition root. Absent -> no telemetry + # (the documented default for this optional subsystem); the node still runs. + self._obs = observability def run(self, context: dict) -> dict: """Execute all nodes in dependency order. @@ -94,7 +99,18 @@ def run(self, context: dict) -> dict: ) validator(context[key]) - outputs = node.module.run(context) + # One span per node — the single auto-instrumentation seam. The span's + # exit captures any propagating exception and records errors_total. + span_cm = self._obs.span(node.name) if self._obs is not None else nullcontext() + with span_cm: + try: + outputs = node.module.run(context) + except Exception as exc: + # Name the failing stage on the exception so the scan-level + # handler reports *which* stage broke. The span still records + # the error; we re-raise unchanged (no swallowing, no wrapping). + exc.add_note(f"adapt: failing stage = {node.name}") + raise # Validate outputs declared by the module if outputs: diff --git a/src/adapt/modules/acquisition/module.py b/src/adapt/modules/acquisition/module.py index a94a200..093bfbd 100644 --- a/src/adapt/modules/acquisition/module.py +++ b/src/adapt/modules/acquisition/module.py @@ -7,6 +7,8 @@ in realtime or historical batches. Deduplicates files to avoid re-downloading. """ +import contextlib +import io import logging import threading import time @@ -549,7 +551,12 @@ def _download_scan(self, scan, local_path: Path) -> bool: temp_dir = base_dir / "_temp" temp_dir.mkdir(exist_ok=True) - results = self.conn.download([scan], temp_dir, keep_aws_folders=False) + # nexradaws prints "Downloaded ..." / "n out of m files downloaded..." + # straight to stdout with no quiet option. Contain it to this call (we log + # a controlled "Downloaded: " below); logging uses stderr, so this + # narrow stdout redirect never swallows our own output. + with contextlib.redirect_stdout(io.StringIO()): + results = self.conn.download([scan], temp_dir, keep_aws_folders=False) success = list(results.iter_success()) if success: diff --git a/src/adapt/modules/ingest/module.py b/src/adapt/modules/ingest/module.py index fd43517..b1aa976 100644 --- a/src/adapt/modules/ingest/module.py +++ b/src/adapt/modules/ingest/module.py @@ -24,6 +24,9 @@ import pyart import xarray as xr +# NB: the Py-ART citation banner is suppressed at the package root (adapt/__init__ +# sets PYART_QUIET) because nexradaws imports pyart before this module ever loads. + __all__ = ["RadarDataLoader"] logger = logging.getLogger(__name__) diff --git a/src/adapt/modules/projection/module.py b/src/adapt/modules/projection/module.py index 754e2ee..b6c5814 100644 --- a/src/adapt/modules/projection/module.py +++ b/src/adapt/modules/projection/module.py @@ -600,6 +600,14 @@ def _fill_concave_hull(self, label_mask, alpha=0.1): # Swap to (x, y) for Delaunay points = points[:, [1, 0]] + # A collinear / single-row / single-column point set is < 2-D, so Delaunay + # would raise QhullError ("input is less than 3-dimensional"). This is an + # expected case for thin cells, not an error: detect it up front and fall back + # to dilation silently instead of logging a verbose qhull dump per cell. + if np.linalg.matrix_rank(points - points.mean(axis=0)) < 2: + kernel = np.ones((3, 3), dtype=np.uint8) + return binary_dilation(label_mask, structure=kernel).astype(np.uint8) + try: # Compute Delaunay triangulation tri = Delaunay(points) @@ -633,6 +641,8 @@ def _fill_concave_hull(self, label_mask, alpha=0.1): return filled.astype(np.uint8) except Exception as e: - logger.warning(f"Concave hull failed: {e}, falling back to dilation") + # Unexpected failure (degenerate input is handled above). Keep it to one + # concise line — never dump the full multi-line qhull diagnostic. + logger.warning("Concave hull fill failed (%s); using dilation", type(e).__name__) kernel = np.ones((3, 3), dtype=np.uint8) return binary_dilation(label_mask, structure=kernel).astype(np.uint8) diff --git a/src/adapt/persistence/execution_history.py b/src/adapt/persistence/execution_history.py new file mode 100644 index 0000000..43c1b36 --- /dev/null +++ b/src/adapt/persistence/execution_history.py @@ -0,0 +1,284 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Execution History — a permanent, queryable scientific metadata product. + +Records every run (run_history), each module execution (module_history), and +structured warnings/errors into the radar catalog database, so questions like +"which runs failed?", "which module got slower?", "what config produced these +outputs?" are answered with SQL, never by parsing log files. + +Owns its own connection + tables (idempotent ``CREATE TABLE IF NOT EXISTS``), +like FileProcessingTracker. Consumes only ``contracts`` DTOs — never the +observability implementation. +""" + +from __future__ import annotations + +import json +import sqlite3 +import threading +from datetime import UTC, datetime +from pathlib import Path + +from adapt.contracts.execution_history import ErrorEvent, RunStart, RunSummary, WarningEvent +from adapt.contracts.observability import SpanRecord + +__all__ = ["ExecutionHistory"] + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS run_history ( + run_id TEXT PRIMARY KEY, + pipeline TEXT NOT NULL, pipeline_version TEXT, git_commit TEXT, + hostname TEXT, username TEXT, + start_time TEXT NOT NULL, end_time TEXT, duration_seconds REAL, + configuration_hash TEXT, configuration_file TEXT, + dataset TEXT, site TEXT, instrument TEXT, + files_processed INTEGER DEFAULT 0, scans_processed INTEGER DEFAULT 0, + objects_detected INTEGER DEFAULT 0, + warnings INTEGER DEFAULT 0, errors INTEGER DEFAULT 0, + status TEXT NOT NULL, + average_scan_time REAL, maximum_scan_time REAL, + slowest_stage TEXT, slowest_stage_duration REAL, + software_version TEXT, python_version TEXT, platform TEXT, + created_at TEXT NOT NULL, updated_at TEXT NOT NULL +); +CREATE TABLE IF NOT EXISTS module_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT NOT NULL, scan_id TEXT, module TEXT NOT NULL, + duration_seconds REAL NOT NULL, status TEXT NOT NULL, error TEXT, + trace_id TEXT, span_id TEXT, recorded_at TEXT NOT NULL +); +CREATE TABLE IF NOT EXISTS warning_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT NOT NULL, scan_id TEXT, module TEXT, + category TEXT, message TEXT NOT NULL, logger TEXT, timestamp TEXT NOT NULL +); +CREATE TABLE IF NOT EXISTS error_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT NOT NULL, scan_id TEXT, module TEXT, + exception_type TEXT, message TEXT NOT NULL, traceback TEXT, + logger TEXT, timestamp TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_module_history_run ON module_history(run_id); +CREATE INDEX IF NOT EXISTS idx_module_history_module ON module_history(module); +CREATE INDEX IF NOT EXISTS idx_error_history_run ON error_history(run_id); +CREATE INDEX IF NOT EXISTS idx_warning_history_run ON warning_history(run_id); +""" + + +class ExecutionHistory: + """Writer/reader for the four execution-history tables in the catalog db.""" + + def __init__(self, db_path: Path | str) -> None: + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._lock = threading.Lock() + self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self._conn.row_factory = sqlite3.Row + with self._lock: + self._conn.executescript(_SCHEMA) + self._conn.commit() + + # ── writes ──────────────────────────────────────────────────────────────── + def start_run(self, start: RunStart) -> None: + now = datetime.now(UTC).isoformat() + p = start.provenance + with self._lock: + self._conn.execute( + """ + INSERT INTO run_history ( + run_id, pipeline, pipeline_version, git_commit, hostname, username, + start_time, configuration_hash, configuration_file, + dataset, site, instrument, status, + software_version, python_version, platform, created_at, updated_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) + """, + ( + start.run_id, + start.pipeline, + start.pipeline_version, + p.git_commit, + p.hostname, + p.username, + start.start_time.isoformat(), + start.configuration_hash, + start.configuration_file, + start.dataset, + start.site, + start.instrument, + "running", + p.software_version, + p.python_version, + p.platform, + now, + now, + ), + ) + self._conn.commit() + + def record_modules( + self, run_id: str, scan_id: str, spans: list[SpanRecord], recorded_at: datetime + ) -> None: + ts = recorded_at.isoformat() + rows = [ + ( + run_id, + scan_id, + s.name, + s.duration_s, + "error" if s.error else "ok", + s.error or None, + s.trace_id, + s.span_id, + ts, + ) + for s in spans + ] + with self._lock: + self._conn.executemany( + """ + INSERT INTO module_history + (run_id, scan_id, module, duration_seconds, status, error, + trace_id, span_id, recorded_at) + VALUES (?,?,?,?,?,?,?,?,?) + """, + rows, + ) + self._conn.commit() + + def record_warnings(self, run_id: str, events: list[WarningEvent]) -> None: + rows = [ + (run_id, e.scan_id, e.module, e.category, e.message, e.logger, e.timestamp.isoformat()) + for e in events + ] + with self._lock: + self._conn.executemany( + "INSERT INTO warning_history " + "(run_id, scan_id, module, category, message, logger, timestamp) " + "VALUES (?,?,?,?,?,?,?)", + rows, + ) + self._conn.commit() + + def record_errors(self, run_id: str, events: list[ErrorEvent]) -> None: + rows = [ + ( + run_id, + e.scan_id, + e.module, + e.exception_type, + e.message, + e.traceback, + e.logger, + e.timestamp.isoformat(), + ) + for e in events + ] + with self._lock: + self._conn.executemany( + "INSERT INTO error_history " + "(run_id, scan_id, module, exception_type, message, traceback, logger, timestamp) " + "VALUES (?,?,?,?,?,?,?,?)", + rows, + ) + self._conn.commit() + + def finalize_run(self, summary: RunSummary) -> None: + slowest, slowest_dur = summary.slowest_stages[0] if summary.slowest_stages else (None, None) + with self._lock: + self._conn.execute( + """ + UPDATE run_history SET + status=?, end_time=?, duration_seconds=?, + files_processed=?, scans_processed=?, objects_detected=?, + warnings=?, errors=?, + average_scan_time=?, maximum_scan_time=?, + slowest_stage=?, slowest_stage_duration=?, updated_at=? + WHERE run_id=? + """, + ( + summary.status, + summary.end_time.isoformat(), + summary.duration_seconds, + summary.files_processed, + summary.scans_processed, + summary.objects_detected, + summary.warnings, + summary.errors, + summary.average_scan_time, + summary.maximum_scan_time, + slowest, + slowest_dur, + datetime.now(UTC).isoformat(), + summary.run_id, + ), + ) + self._conn.commit() + + def export_run_report(self, run_id: str, path: Path | str) -> None: + runs = self.query_runs() + run = next((r for r in runs if r["run_id"] == run_id), None) + report = { + "run": run, + "modules": self.query_modules(run_id=run_id), + "warnings": self.query_warnings(run_id=run_id), + "errors": self.query_errors(run_id=run_id), + } + Path(path).write_text(json.dumps(report, indent=2, default=str)) + + # ── reads ───────────────────────────────────────────────────────────────── + def query_runs(self, *, site: str | None = None, status: str | None = None) -> list[dict]: + sql = "SELECT * FROM run_history" + clauses, params = [], [] + if site is not None: + clauses.append("site = ?") + params.append(site) + if status is not None: + clauses.append("status = ?") + params.append(status) + if clauses: + sql += " WHERE " + " AND ".join(clauses) + sql += " ORDER BY start_time DESC" + return self._fetch(sql, params) + + def query_modules(self, *, run_id: str | None = None, module: str | None = None) -> list[dict]: + sql = "SELECT * FROM module_history" + clauses, params = [], [] + if run_id is not None: + clauses.append("run_id = ?") + params.append(run_id) + if module is not None: + clauses.append("module = ?") + params.append(module) + if clauses: + sql += " WHERE " + " AND ".join(clauses) + return self._fetch(sql, params) + + def query_warnings(self, *, run_id: str | None = None) -> list[dict]: + return self._fetch_by_run("warning_history", run_id) + + def query_errors(self, *, run_id: str | None = None) -> list[dict]: + return self._fetch_by_run("error_history", run_id) + + def failure_rate_by_module(self) -> dict[str, float]: + rows = self._fetch( + "SELECT module, " + "AVG(CASE WHEN status = 'error' THEN 1.0 ELSE 0.0 END) AS rate " + "FROM module_history GROUP BY module", + [], + ) + return {r["module"]: r["rate"] for r in rows} + + def close(self) -> None: + self._conn.close() + + # ── helpers ─────────────────────────────────────────────────────────────── + def _fetch(self, sql: str, params: list) -> list[dict]: + with self._lock: + return [dict(row) for row in self._conn.execute(sql, params).fetchall()] + + def _fetch_by_run(self, table: str, run_id: str | None) -> list[dict]: + if run_id is None: + return self._fetch(f"SELECT * FROM {table}", []) + return self._fetch(f"SELECT * FROM {table} WHERE run_id = ?", [run_id]) diff --git a/src/adapt/persistence/repository.py b/src/adapt/persistence/repository.py index 93b8250..fbf8113 100644 --- a/src/adapt/persistence/repository.py +++ b/src/adapt/persistence/repository.py @@ -29,6 +29,7 @@ import xarray as xr from adapt.persistence.catalog import RadarCatalog +from adapt.persistence.execution_history import ExecutionHistory from adapt.persistence.registry import RepositoryRegistry if TYPE_CHECKING: @@ -130,6 +131,8 @@ def __init__( # Catalog system self.registry = RepositoryRegistry.get_instance(self.base_dir) self.catalog = RadarCatalog(self.base_dir / radar) + # Execution history (run/module/warning/error records) shares the catalog db. + self.history = ExecutionHistory(self.catalog.db_path) # Thread safety self._lock = threading.RLock() @@ -337,7 +340,10 @@ def open_dataset(self, artifact_id: str) -> xr.Dataset: if not file_path.exists(): raise FileNotFoundError(f"File not found: {file_path}") - return xr.open_dataset(file_path) + # Pin the engine: Adapt writes with netcdf4, so name it on read too. Letting + # xarray sniff backends makes the h5netcdf probe emit HDF5-DIAG noise to stderr + # ("Not an HDF5 file") for every classic-format file. + return xr.open_dataset(file_path, engine="netcdf4") def open_table(self, artifact_id: str, table_name: str | None = None) -> pd.DataFrame: """Open SQLite or Parquet artifact as DataFrame. diff --git a/src/adapt/persistence/sqlite_store.py b/src/adapt/persistence/sqlite_store.py new file mode 100644 index 0000000..18dccbc --- /dev/null +++ b/src/adapt/persistence/sqlite_store.py @@ -0,0 +1,70 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Shared SQLite store base for Adapt persistence. + +Owns the thread-safe WAL connection and schema bootstrap from the canonical SQL +files in ``adapt/configuration/schemas/``. Subclasses (``RadarCatalog``, +``RepositoryRegistry``) supply their database path and schema filename and add +their table-specific methods. + +A missing schema file raises immediately — there is no inline fallback. +""" + +import sqlite3 +import threading +from pathlib import Path + +__all__ = ["SqliteStore"] + +_SCHEMA_DIR = Path(__file__).resolve().parents[1] / "configuration" / "schemas" + + +class SqliteStore: + """Thread-safe SQLite connection and schema initialisation. + + Provides the shared WAL connection, locking, and schema bootstrap. The + database path and schema file are chosen by the subclass. + """ + + def __init__( + self, db_path: str | Path, schema_filename: str, *, checkpoint: bool = False + ) -> None: + self.db_path = Path(db_path) + self._lock = threading.RLock() + self._conn: sqlite3.Connection | None = None + self._init_schema(schema_filename, checkpoint=checkpoint) + + def _get_connection(self) -> sqlite3.Connection: + """Return the shared thread-safe connection, opening it on first use.""" + if self._conn is None: + self._conn = sqlite3.connect( + str(self.db_path), check_same_thread=False, isolation_level="DEFERRED" + ) + self._conn.row_factory = sqlite3.Row + # Enable WAL mode for concurrent access + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA foreign_keys=ON") + return self._conn + + def _init_schema(self, schema_filename: str, *, checkpoint: bool) -> None: + """Create tables from the canonical SQL file. Raise if it is missing.""" + schema_path = _SCHEMA_DIR / schema_filename + if not schema_path.exists(): + raise FileNotFoundError(f"SQL schema not found: {schema_path}") + + schema_sql = schema_path.read_text(encoding="utf-8") + conn = self._get_connection() + with self._lock: + conn.executescript(schema_sql) + conn.commit() + if checkpoint: + # Checkpoint WAL so readonly readers (immutable=1) see the schema. + conn.execute("PRAGMA wal_checkpoint(PASSIVE)") + + def close(self) -> None: + """Close the connection if open.""" + if self._conn: + with self._lock: + self._conn.close() + self._conn = None diff --git a/src/adapt/runtime/diagnostics.py b/src/adapt/runtime/diagnostics.py new file mode 100644 index 0000000..d91d249 --- /dev/null +++ b/src/adapt/runtime/diagnostics.py @@ -0,0 +1,29 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Quiet redundant compiled-library diagnostics through their supported controls. + +These libraries write directly to the process stderr/stdout, bypassing Python +logging entirely, so they cannot be routed by the logging configuration. We use +each library's own supported mechanism — never stdout/stderr redirection — and +apply it on the thread that does the I/O (HDF5 error stacks are thread-local). +""" + +from __future__ import annotations + +import h5py + +__all__ = ["silence_hdf5_errors"] + + +def silence_hdf5_errors() -> None: + """Turn off libhdf5's automatic stderr error dump on the calling thread. + + libhdf5 prints multi-line ``HDF5-DIAG`` blocks to stderr when an operation + fails — e.g. while probing a file's format — even though the real failure is + already raised to Python as an exception. The dump is never actionable, only + clutter. ``h5py._errors.silence_errors`` is h5py's wrapper over the official + HDF5 ``H5Eset_auto`` control; because HDF5 error stacks are thread-local, this + must run on each worker thread that touches HDF5, not once at import. + """ + h5py._errors.silence_errors() diff --git a/src/adapt/runtime/history_handler.py b/src/adapt/runtime/history_handler.py new file mode 100644 index 0000000..1c7083c --- /dev/null +++ b/src/adapt/runtime/history_handler.py @@ -0,0 +1,75 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Capture WARNING/ERROR log records into structured execution-history events. + +Installed on the root logger by the orchestrator. ``emit`` only buffers in +memory (fast, no DB I/O, no reentrancy risk on the hot path); the processor and +orchestrator ``drain`` the buffers and write them via ``ExecutionHistory``. This +turns the existing ~50 ``logger.warning``/``logger.exception`` call sites into +searchable warning/error events with zero new instrumentation in those modules. +""" + +from __future__ import annotations + +import logging +import threading +from datetime import UTC, datetime + +from adapt.contracts.execution_history import ErrorEvent, WarningEvent +from adapt.runtime.observability import current_context + +__all__ = ["HistoryLogHandler"] + + +class HistoryLogHandler(logging.Handler): + """Buffers WARNING -> WarningEvent and ERROR/CRITICAL -> ErrorEvent.""" + + def __init__(self) -> None: + super().__init__() + self._lock = threading.Lock() + self._warnings: list[WarningEvent] = [] + self._errors: list[ErrorEvent] = [] + + def emit(self, record: logging.LogRecord) -> None: + ctx = current_context() + now = datetime.now(UTC) + if record.levelno >= logging.ERROR: + exc_type = "" + if record.exc_info and record.exc_info[0] is not None: + exc_type = record.exc_info[0].__name__ + event = ErrorEvent( + scan_id=ctx.scan_id, + module=ctx.stage, + exception_type=exc_type, + message=record.getMessage(), + traceback=self._format_traceback(record), + logger=record.name, + timestamp=now, + ) + with self._lock: + self._errors.append(event) + elif record.levelno >= logging.WARNING: + event = WarningEvent( + scan_id=ctx.scan_id, + module=ctx.stage, + category=getattr(record, "category", "general"), + message=record.getMessage(), + logger=record.name, + timestamp=now, + ) + with self._lock: + self._warnings.append(event) + + def drain(self) -> tuple[list[WarningEvent], list[ErrorEvent]]: + """Return buffered warnings + errors, then clear (called per scan / at stop).""" + with self._lock: + warnings, errors = self._warnings, self._errors + self._warnings, self._errors = [], [] + return warnings, errors + + @staticmethod + def _format_traceback(record: logging.LogRecord) -> str: + if record.exc_info: + return logging.Formatter().formatException(record.exc_info) + return "" diff --git a/src/adapt/runtime/logging_setup.py b/src/adapt/runtime/logging_setup.py new file mode 100644 index 0000000..9d6f6b8 --- /dev/null +++ b/src/adapt/runtime/logging_setup.py @@ -0,0 +1,116 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""The single logging-configuration site for Adapt. + +Replaces the per-orchestrator handler setup and the CLI ``basicConfig``. Installs +a ``ContextFilter`` that stamps the current correlation ids onto every record, a +``JsonFormatter`` for the full file log, and a ``ConsoleFilter`` so the console +stays quiet (warnings/errors + explicitly console-tagged lines only) while the +file/JSON log keeps everything. Scientific modules keep ``getLogger(__name__)`` +and never import this — they gain context automatically. +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path + +from adapt.runtime.observability import ObsSettings, current_context + +__all__ = ["ContextFilter", "JsonFormatter", "ConsoleFilter", "configure_logging"] + +_CONSOLE_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +# Standard LogRecord attributes to exclude when surfacing structured extras. +_STD_ATTRS = set(logging.makeLogRecord({}).__dict__) | {"message", "asctime"} +_CONTEXT_FIELDS = ( + "pipeline_id", + "trace_id", + "span_id", + "scan_id", + "dataset_id", + "experiment_id", + "worker_id", + "stage", +) + + +class ContextFilter(logging.Filter): + """Stamp the current ObsContext onto every record so modules never repeat ids.""" + + def filter(self, record: logging.LogRecord) -> bool: + ctx = current_context() + for field in _CONTEXT_FIELDS: + setattr(record, field, getattr(ctx, field)) + return True + + +class JsonFormatter(logging.Formatter): + """One JSON object per record: standard fields + context ids + any extra=.""" + + def format(self, record: logging.LogRecord) -> str: + # NB: the trailing-Z scan-time literal is fitness-pinned to utils/time; keep this no-Z form. + payload: dict[str, object] = { + "ts": self.formatTime(record, "%Y-%m-%dT%H:%M:%S"), + "level": record.levelname, + "logger": record.name, + "msg": record.getMessage(), + } + for key, value in record.__dict__.items(): + if key not in _STD_ATTRS and not key.startswith("_"): + payload[key] = value + if record.exc_info: + payload["exc"] = self.formatException(record.exc_info) + return json.dumps(payload, default=str) + + +class ConsoleFilter(logging.Filter): + """Pass only what a human needs to watch: records at/above the console threshold, + plus explicitly console-tagged lines (the run header, progress, and summary). + + The threshold is carried here, not on the handler's level: a handler's level is + checked *before* its filters, so an INFO threshold-bypassing console-tagged line + would be dropped by a WARNING handler level before this filter ever ran. + """ + + def __init__(self, level: int = logging.WARNING) -> None: + super().__init__() + self._level = level + + def filter(self, record: logging.LogRecord) -> bool: + return record.levelno >= self._level or bool(getattr(record, "console", False)) + + +def configure_logging(settings: ObsSettings, log_path: Path | None) -> None: + """Configure the root logger. The one place handlers are constructed. + + Fails loudly: ``json_logs`` with no ``log_path`` raises (no silent default). + Idempotent — clears existing handlers before re-adding, so repeated calls (or + a prior ``basicConfig``) never accumulate handlers. + """ + root = logging.getLogger() + root.setLevel(getattr(logging, settings.level.upper(), logging.INFO)) + for handler in root.handlers[:]: + root.removeHandler(handler) + + context_filter = ContextFilter() + + if settings.json_logs: + if log_path is None: + raise ValueError("json_logs=True requires a log_path") + log_path.parent.mkdir(parents=True, exist_ok=True) + file_handler = logging.FileHandler(log_path) + file_handler.setFormatter(JsonFormatter()) + file_handler.addFilter(context_filter) + root.addHandler(file_handler) + + if settings.console_logs: + # No handler-level gate: ConsoleFilter does all gating (see its docstring), so + # console-tagged INFO lines survive a WARNING console threshold. + console_threshold = getattr(logging, settings.console_level.upper(), logging.WARNING) + console = logging.StreamHandler() + console.setFormatter(logging.Formatter(_CONSOLE_FORMAT)) + console.addFilter(context_filter) + console.addFilter(ConsoleFilter(console_threshold)) + root.addHandler(console) diff --git a/src/adapt/runtime/observability.py b/src/adapt/runtime/observability.py new file mode 100644 index 0000000..3322d05 --- /dev/null +++ b/src/adapt/runtime/observability.py @@ -0,0 +1,304 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Observability provider: structured-logging context, tracing, and metrics. + +Lives in ``runtime`` because the orchestrator (composition root) is its only +importer; ``execution``/``runtime`` receive it injected as the +``adapt.contracts.observability.Observability`` Protocol, so it can be relocated +later without touching call sites. All wall-clock/RNG reads are injected and +required — never defaulted — so durations and ids stay deterministic under test +and no hidden clock read can slip in. Scientific modules never import this. +""" + +from __future__ import annotations + +import contextvars +import threading +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass, replace +from datetime import UTC, datetime +from random import Random + +from adapt.contracts.observability import ObsContext, SpanRecord + +__all__ = [ + "ObsSettings", + "Observability", + "build_observability", + "current_context", + "disabled_observability", +] + +_EPOCH = datetime(1970, 1, 1, tzinfo=UTC) + +# The single correlation context. Per-thread isolated (contextvars do not cross +# threads); each worker re-binds at the top of its run(). Default is blank so a +# fresh/unbound context is harmless. +_BLANK_CONTEXT = ObsContext(pipeline_id="", trace_id="") +_CTX: contextvars.ContextVar[ObsContext] = contextvars.ContextVar( + "adapt_obs", default=_BLANK_CONTEXT +) + + +def current_context() -> ObsContext: + """The current correlation context (blank if nothing is bound on this thread).""" + return _CTX.get() + + +def _restore_context(token: contextvars.Token, parent: ObsContext) -> None: + """Restore the context on scope exit, surviving a cross-Context teardown. + + ``Token.reset`` raises ValueError when the token was created in a different + contextvars Context — the Ctrl+C shutdown path, where a long-lived span is + entered and exited across a copied Context. The span record is already + emitted by then, so the only job left is to put the parent value back; do that + directly (valid in any Context) instead of letting cleanup crash the shutdown. + """ + try: + _CTX.reset(token) + except ValueError: + _CTX.set(parent) + + +@dataclass(frozen=True, slots=True) +class ObsSettings: + """Resolved observability toggles. Built by the orchestrator from InternalConfig. + + Field defaults are the dev profile; profiles/CLI override explicitly upstream. + """ + + enabled: bool = True + traces: bool = True + metrics: bool = True + json_logs: bool = False + console_logs: bool = True + level: str = "INFO" + console_level: str = "WARNING" + progress_every: float = 30.0 + + +class Observability: + """Injectable provider for context, spans, and metrics.""" + + def __init__( + self, + settings: ObsSettings, + *, + clock: Callable[[], float], + wall_clock: Callable[[], datetime], + rng: Random, + ) -> None: + self._settings = settings + self._clock = clock + self._wall = wall_clock + self._rng = rng + self._enabled = settings.enabled + self._spans: list[SpanRecord] = [] + self.metrics = _Meter(enabled=settings.enabled and settings.metrics) + + # ── context propagation ─────────────────────────────────────────────────── + def current(self) -> ObsContext: + return current_context() + + @contextmanager + def bind(self, **fields: str): + """Push correlation fields onto the current context for the block's scope.""" + if not self._enabled: + yield + return + parent = _CTX.get() + token = _CTX.set(replace(parent, **fields)) + try: + yield + finally: + _restore_context(token, parent) + + # ── tracing ─────────────────────────────────────────────────────────────── + def span(self, name: str, **seed: object): + """Open a span; use as a context manager. Nesting is automatic via contextvars.""" + if not self._enabled: + return _NULL_SPAN + return _Span(self, name, {k: str(v) for k, v in seed.items()}) + + def drain_spans(self) -> list[SpanRecord]: + """Return finished spans recorded since the last drain, then clear the buffer.""" + spans, self._spans = self._spans, [] + return spans + + def _trace_id(self) -> str: + return f"{self._rng.getrandbits(128):032x}" + + def _span_id(self) -> str: + return f"{self._rng.getrandbits(64):016x}" + + def _record(self, record: SpanRecord) -> None: + if self._settings.traces: + self._spans.append(record) + + +class _Span: + """A single span. Reads the current context as its parent; restores on exit.""" + + def __init__(self, obs: Observability, name: str, seed: dict[str, str]) -> None: + self._obs = obs + self._name = name + self._seed = seed + self._error = "" + self._meta: dict[str, str] = {} + self.duration_s = 0.0 + + def set(self, **metadata: object) -> None: + self._meta.update({k: str(v) for k, v in metadata.items()}) + + def record_error(self, exc: BaseException) -> None: + self._error = f"{type(exc).__name__}: {exc}" + + def __enter__(self) -> _Span: + parent = _CTX.get() + self._parent = parent + self._parent_id = parent.span_id + self._trace_id = parent.trace_id or self._obs._trace_id() + self._span_id = self._obs._span_id() + self._start = self._obs._clock() + self._token = _CTX.set( + replace( + parent, + trace_id=self._trace_id, + span_id=self._span_id, + stage=self._name, + **self._seed, + ) + ) + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + finish = self._obs._clock() + if exc is not None and not self._error: + self._error = f"{exc_type.__name__}: {exc}" + self.duration_s = finish - self._start + self._obs.metrics.observe("module_duration_seconds", self.duration_s, stage=self._name) + if self._error: + self._obs.metrics.incr("errors_total", stage=self._name) + self._obs._record( + SpanRecord( + name=self._name, + trace_id=self._trace_id, + span_id=self._span_id, + parent_span_id=self._parent_id, + start=self._start, + finish=finish, + duration_s=self.duration_s, + error=self._error, + metadata=dict(self._meta), + ) + ) + _restore_context(self._token, self._parent) + return False # never swallow — fail loudly + + +class _NullSpan: + """Shared zero-cost span used when observability is disabled.""" + + duration_s = 0.0 + + def set(self, **metadata: object) -> None: ... + + def record_error(self, exc: BaseException) -> None: ... + + def __enter__(self) -> _NullSpan: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +_NULL_SPAN = _NullSpan() + + +def _key(labels: dict[str, str]) -> tuple: + return tuple(sorted(labels.items())) + + +class _Meter: + """Lock-guarded, accumulate-only in-memory metrics. Counters/histograms are + order-insensitive; gauges are last-write-wins (observational, not asserted for + determinism). Never read back into the science path. + """ + + def __init__(self, *, enabled: bool) -> None: + self._enabled = enabled + self._lock = threading.Lock() + self._counters: dict[tuple, float] = {} + self._gauges: dict[tuple, float] = {} + self._hist: dict[tuple, list[float]] = {} + + def incr(self, name: str, value: float = 1.0, **labels: str) -> None: + if not self._enabled: + return + k = (name, _key(labels)) + with self._lock: + self._counters[k] = self._counters.get(k, 0.0) + value + + def gauge(self, name: str, value: float, **labels: str) -> None: + if not self._enabled: + return + with self._lock: + self._gauges[(name, _key(labels))] = value + + def observe(self, name: str, value: float, **labels: str) -> None: + if not self._enabled: + return + with self._lock: + self._hist.setdefault((name, _key(labels)), []).append(value) + + def counter_total(self, name: str) -> float: + with self._lock: + return sum(v for (n, _), v in self._counters.items() if n == name) + + def gauge_value(self, name: str, **labels: str) -> float | None: + with self._lock: + return self._gauges.get((name, _key(labels))) + + def histogram_values(self, name: str) -> list[float]: + with self._lock: + out: list[float] = [] + for (n, _), values in self._hist.items(): + if n == name: + out.extend(values) + return out + + def histogram_totals_by_label(self, name: str, label: str) -> dict[str, float]: + with self._lock: + out: dict[str, float] = {} + for (n, lk), values in self._hist.items(): + if n != name: + continue + key = dict(lk).get(label) + if key is None: + continue + out[key] = out.get(key, 0.0) + sum(values) + return out + + +def build_observability( + settings: ObsSettings, + *, + clock: Callable[[], float], + wall_clock: Callable[[], datetime], + rng: Random, +) -> Observability: + """The one constructor (DI entry). ``clock``/``wall_clock``/``rng`` are required.""" + return Observability(settings, clock=clock, wall_clock=wall_clock, rng=rng) + + +def disabled_observability() -> Observability: + """A real provider with telemetry OFF — a clean explicit default for runtime + components when none is injected (not a silent fallback: telemetry is optional).""" + return build_observability( + ObsSettings(enabled=False), + clock=lambda: 0.0, + wall_clock=lambda: _EPOCH, + rng=Random(0), + ) diff --git a/src/adapt/runtime/orchestrator.py b/src/adapt/runtime/orchestrator.py index c82c1a9..f5a5a7e 100644 --- a/src/adapt/runtime/orchestrator.py +++ b/src/adapt/runtime/orchestrator.py @@ -13,17 +13,26 @@ import logging import queue +import random import time +from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING +from adapt.contracts.execution_history import RunStart, RunSummary from adapt.persistence import DataRepository from adapt.runtime.file_tracker import FileProcessingTracker +from adapt.runtime.history_handler import HistoryLogHandler +from adapt.runtime.logging_setup import configure_logging +from adapt.runtime.observability import ObsSettings, build_observability from adapt.runtime.processor import RadarProcessor +from adapt.runtime.provenance import capture_provenance, config_hash +from adapt.runtime.run_reporter import RunReporter from adapt.runtime.sources import source_registry if TYPE_CHECKING: from adapt.configuration.schemas.internal import InternalConfig + from adapt.contracts.observability import Observability from adapt.contracts.source import ScanSource __all__ = ["PipelineOrchestrator"] @@ -141,6 +150,108 @@ def __init__( self._max_duration: float | None = None self._close_repository_on_stop = close_repository_on_stop + # Telemetry (built in start()); kept here so stop() is safe before start(). + self._obs: Observability | None = None + self._root_span = None + self._root_trace_id = "" + self._history_handler: HistoryLogHandler | None = None + self._reporter: RunReporter | None = None + + def _obs_settings(self) -> ObsSettings: + """Translate the resolved logging/observability config into ObsSettings.""" + lc = self.config.logging + return ObsSettings( + enabled=lc.enabled, + traces=lc.traces, + metrics=lc.metrics, + json_logs=lc.json_logs, + console_logs=lc.console_logs, + level=lc.level, + console_level=lc.console_level, + progress_every=lc.progress_every, + ) + + def _build_observability(self) -> "Observability": + """Build the telemetry provider from the resolved config. + + Injects production clocks/rng here (the only place wall-clock reads are + legal for telemetry). Trace ids are random per run. + """ + return build_observability( + self._obs_settings(), + clock=time.perf_counter, + wall_clock=lambda: datetime.now(UTC), + rng=random.Random(), + ) + + def _record_run_start(self, radar: str) -> None: + """Open the execution-history run record and print the console run header.""" + prov = capture_provenance() + modules = tuple(m.name for m in self.processor._pipeline_modules) if self.processor else () + start = RunStart( + run_id=self.run_id or "", + pipeline=self.config.source, + pipeline_version=prov.software_version, + site=radar, + dataset=radar, + instrument="NEXRAD", + mode=self.config.mode, + start_time=datetime.now(UTC), + configuration_hash=config_hash(self.config.model_dump_json()), + configuration_file=str( + self.repository.catalog.radar_dir / f"config_run_{self.run_id}.json" + ) + if self.repository + else "", + provenance=prov, + enabled_modules=modules, + ) + assert self.repository is not None + self.repository.history.start_run(start) + self._reporter.header(start) + + def _finalize_history(self) -> None: + """Flush captured warnings/errors, finalize the run record, print the summary.""" + if self._obs is None or self.repository is None: + return + if self._history_handler is not None: + warnings, errors = self._history_handler.drain() + if warnings: + self.repository.history.record_warnings(self.run_id, warnings) + if errors: + self.repository.history.record_errors(self.run_id, errors) + summary = self._build_run_summary("cancelled" if self._interrupted else "success") + self.repository.history.finalize_run(summary) + if self._reporter is not None: + self._reporter.summary(summary) + + def _build_run_summary(self, status: str) -> RunSummary: + """Aggregate the end-of-run summary from telemetry metrics + history counts.""" + m = self._obs.metrics + scan_times = m.histogram_values("scan_processing_time") + slowest = sorted( + m.histogram_totals_by_label("module_duration_seconds", "stage").items(), + key=lambda kv: kv[1], + reverse=True, + )[:3] + warnings = len(self.repository.history.query_warnings(run_id=self.run_id)) + errors = len(self.repository.history.query_errors(run_id=self.run_id)) + duration = (time.time() - self._start_time) if self._start_time else 0.0 + return RunSummary( + run_id=self.run_id or "", + status=status, + end_time=datetime.now(UTC), + duration_seconds=duration, + files_processed=int(m.counter_total("files_processed_total")), + scans_processed=len(scan_times), + objects_detected=int(m.counter_total("cells_detected_total")), + warnings=warnings, + errors=errors, + average_scan_time=(sum(scan_times) / len(scan_times)) if scan_times else 0.0, + maximum_scan_time=max(scan_times) if scan_times else 0.0, + slowest_stages=tuple(slowest), + ) + def _setup_logging(self): """Configure logging and file tracking systems. @@ -258,6 +369,23 @@ def start(self, max_runtime: int | None = None): config=self.config, ) + # Build telemetry and open the root pipeline span. The root trace id is handed + # to the processor thread (contextvars do not cross threads) so the whole run + # shares one trace. + self._obs = self._build_observability() + self._root_span = self._obs.span( + "pipeline", pipeline_id=self.run_id or "", dataset_id=radar + ) + self._root_span.__enter__() + self._root_trace_id = self._obs.current().trace_id + + # Structured logging (JSON file + quiet console) + capture warnings/errors. + log_path = Path(self.output_dirs["logs"]) / f"pipeline_{radar}.log" + configure_logging(self._obs_settings(), log_path) + self._history_handler = HistoryLogHandler() + logging.getLogger().addHandler(self._history_handler) + self._reporter = RunReporter() + self._start_time = time.time() self._max_duration = max_runtime * 60 if max_runtime else None @@ -279,9 +407,15 @@ def start(self, max_runtime: int | None = None): output_dirs=self.output_dirs, file_tracker=self.tracker, repository=self.repository, + observability=self._obs, + root_trace_id=self._root_trace_id, + reporter=self._reporter, ) self.processor.start() + # Open the execution-history run record and print the one-shot console header. + self._record_run_start(radar) + mode = self.config.mode logger.debug("Pipeline running in %s mode. Press Ctrl+C to stop.", mode.upper()) @@ -427,6 +561,11 @@ def stop(self): self._stop_event = True + # Close the root pipeline span if it was opened in start(). + if self._root_span is not None: + self._root_span.__exit__(None, None, None) + self._root_span = None + # Stop threads for name, thread in [ ("Downloader", self.downloader), @@ -443,6 +582,10 @@ def stop(self): self.processor.save_results() self.processor.close_database() + # Execution history: flush captured warnings/errors, finalize the run record, + # and print the one-shot console summary (before the repository closes). + self._finalize_history() + # Finalize repository if self.repository: final_status = "cancelled" if self._interrupted else "completed" diff --git a/src/adapt/runtime/processor.py b/src/adapt/runtime/processor.py index 1745c87..8908235 100644 --- a/src/adapt/runtime/processor.py +++ b/src/adapt/runtime/processor.py @@ -23,7 +23,7 @@ import queue import threading import time -from datetime import UTC +from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING @@ -31,12 +31,15 @@ from adapt.configuration.schemas.module_resolver import resolve_module_configs from adapt.contracts import ContractViolation, PersistenceMeta +from adapt.contracts.observability import Observability from adapt.execution.graph.builder import GraphBuilder from adapt.execution.graph.executor import GraphExecutor from adapt.execution.module_registry import registry from adapt.execution.pipeline_builder import _ensure_modules_registered, resolve_enabled_modules from adapt.persistence import DataRepository, ProductType from adapt.persistence.output_router import OutputRouter +from adapt.runtime.diagnostics import silence_hdf5_errors +from adapt.runtime.observability import disabled_observability if TYPE_CHECKING: from adapt.configuration.schemas.internal import InternalConfig @@ -82,6 +85,9 @@ def __init__( file_tracker=None, repository: DataRepository | None = None, name: str = "RadarProcessor", + observability: Observability | None = None, + root_trace_id: str = "", + reporter=None, ): super().__init__(daemon=True, name=name) @@ -90,6 +96,12 @@ def __init__( self.output_dirs = {k: Path(v) for k, v in output_dirs.items()} self.file_tracker = file_tracker self.repository = repository + # Telemetry provider, injected by the orchestrator. Absent -> disabled (off), + # so the rest of this class calls it unconditionally with no `if obs` branches. + self._obs = observability if observability is not None else disabled_observability() + self._root_trace_id = root_trace_id + # Console reporter (injected). Absent -> no per-scan progress line. + self._reporter = reporter self._stop_event = threading.Event() self.output_lock = threading.Lock() @@ -121,13 +133,15 @@ def __init__( history_groups.setdefault(m.required_history, []).append(m) self._executors: dict[int, GraphExecutor] = { - req: GraphExecutor(GraphBuilder(mods).build()) + req: GraphExecutor(GraphBuilder(mods).build(), observability=self._obs) for req, mods in sorted(history_groups.items()) } self._post_modules = post_persist self._post_executor: GraphExecutor | None = ( - GraphExecutor(GraphBuilder(post_persist).build()) if post_persist else None + GraphExecutor(GraphBuilder(post_persist).build(), observability=self._obs) + if post_persist + else None ) self._module_configs = resolve_module_configs(config) @@ -161,10 +175,27 @@ def stopped(self) -> bool: return self._stop_event.is_set() def run(self): - """Main processor loop (runs in thread).""" + """Main processor loop (runs in thread). + + Binds this thread's correlation context once — contextvars do not cross + threads, so the root trace id is handed in from the orchestrator and + re-bound here so every log/span emitted on this thread shares one trace. + """ + # HDF5 error stacks are thread-local: silence libhdf5's stderr dumps on THIS + # worker thread, where all the NetCDF/HDF5 I/O happens. + silence_hdf5_errors() logger.info("Processor started, waiting for files...") - _skip_count = 0 + with self._obs.bind( + trace_id=self._root_trace_id, + pipeline_id=self.repository.run_id, + dataset_id=self.config.downloader.radar, + worker_id="processor", + ): + self._run_loop() + logger.info("Processor stopped") + def _run_loop(self): + _skip_count = 0 while not self.stopped(): try: filepath = self.input_queue.get(timeout=1) @@ -189,7 +220,6 @@ def run(self): if _skip_count: logger.info("Skipped %d already-analyzed files", _skip_count) - logger.info("Processor stopped") # ── Per-file processing ─────────────────────────────────────────────────── @@ -233,6 +263,30 @@ def process_file(self, filepath) -> bool: queue_wait_s = (time.time() - queued_at) if queued_at else None logger.info("Processing: %s", Path(filepath).name) + # Bind scan context and open the scan span; module spans nest under it and + # every log on this path carries scan_id. Disabled provider -> no-ops. + with self._obs.bind(scan_id=file_id): + with self._obs.span("scan") as scan_span: + ok = self._run_scan(filepath, file_id, queue_wait_s, scan_span) + # The scan span has closed; split this scan's drained spans into the module + # spans (one module_history batch) and the scan span itself (carries n_cells). + all_spans = self._obs.drain_spans() + modules = [s for s in all_spans if s.name != "scan"] + if modules: + self.repository.history.record_modules( + self.repository.run_id, file_id, modules, recorded_at=datetime.now(UTC) + ) + # One controlled console line per scan, built from the captured telemetry + # (stage timings + cell count) — never printed from inside a module. + if self._reporter is not None and modules: + scan_rec = next((s for s in all_spans if s.name == "scan"), None) + n_cells = int(scan_rec.metadata.get("n_cells", 0)) if scan_rec else 0 + self._reporter.scan(file_id, modules, n_cells) + return ok + + def _run_scan(self, filepath, file_id, queue_wait_s, scan_span) -> bool: + """Execute the scientific pipeline for one scan (inside the scan span).""" + tracker = self.file_tracker try: t0 = time.perf_counter() @@ -345,6 +399,14 @@ def process_file(self, filepath) -> bool: timings["queue_wait_seconds"] = queue_wait_s tracker.mark_stage_complete(file_id, "analyzed", num_cells=n_cells, timings=timings) + radar = self.config.downloader.radar + self._obs.metrics.incr("files_processed_total", dataset_id=radar) + self._obs.metrics.incr("cells_detected_total", value=float(n_cells)) + self._obs.metrics.observe("scan_processing_time", elapsed_s) + self._obs.metrics.gauge("queue_depth", self.input_queue.qsize()) + if queue_wait_s is not None: + self._obs.metrics.observe("download_latency", queue_wait_s) + scan_span.set(n_cells=n_cells) return True except ContractViolation as e: @@ -355,7 +417,19 @@ def process_file(self, filepath) -> bool: return False except Exception as e: - logger.exception("Error processing %s", filepath) + # One context-rich failure line + exactly one traceback. scan_id rides the + # bound context; the failing stage is on the exception's notes (added by + # GraphExecutor). elapsed/error_type are structured for the JSON log. + elapsed = time.perf_counter() - t0 + logger.error( + "Scan failed | scan=%s elapsed=%.2fs %s: %s", + file_id, + elapsed, + type(e).__name__, + e, + exc_info=True, + extra={"elapsed_s": round(elapsed, 3), "error_type": type(e).__name__}, + ) if tracker: tracker.mark_stage_complete(file_id, "analyzed", error=str(e)) return False diff --git a/src/adapt/runtime/provenance.py b/src/adapt/runtime/provenance.py new file mode 100644 index 0000000..251c013 --- /dev/null +++ b/src/adapt/runtime/provenance.py @@ -0,0 +1,56 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Capture run provenance (git/host/user/python/platform/version) + config hashing. + +Lives in the runtime composition layer: the orchestrator captures this once at +run start and hands it to ``ExecutionHistory`` via a ``RunStart`` DTO. Fields +that genuinely do not exist in the environment (e.g. no git checkout) are +recorded as ``None`` — the true value, never a fabricated default. +""" + +from __future__ import annotations + +import getpass +import hashlib +import platform +import socket +import subprocess + +from adapt import __version__ +from adapt.contracts.execution_history import RunProvenance + +__all__ = ["capture_provenance", "config_hash"] + + +def _git_commit() -> str | None: + """Current commit hash, or None when not in a git checkout / git unavailable.""" + try: + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + timeout=2, + ) + except (OSError, subprocess.SubprocessError): + return None + if result.returncode != 0: + return None + return result.stdout.strip() or None + + +def capture_provenance() -> RunProvenance: + """Snapshot the execution environment for reproducibility.""" + return RunProvenance( + git_commit=_git_commit(), + hostname=socket.gethostname(), + username=getpass.getuser(), + python_version=platform.python_version(), + platform=platform.platform(), + software_version=__version__, + ) + + +def config_hash(resolved_config_json: str) -> str: + """Stable SHA-256 of the resolved configuration JSON (provenance fingerprint).""" + return hashlib.sha256(resolved_config_json.encode("utf-8")).hexdigest() diff --git a/src/adapt/runtime/run_reporter.py b/src/adapt/runtime/run_reporter.py new file mode 100644 index 0000000..e61af3e --- /dev/null +++ b/src/adapt/runtime/run_reporter.py @@ -0,0 +1,98 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""The quiet console: a one-shot run header, a periodic progress line, and an +end-of-run summary. Plain ASCII, no colours/spinners. Lines are emitted at INFO +with ``extra={"console": True}`` so the ConsoleFilter lets them through while +routine per-scan logs stay in the file/JSON log only. +""" + +from __future__ import annotations + +import logging + +from adapt.contracts.execution_history import RunStart, RunSummary +from adapt.contracts.observability import SpanRecord + +__all__ = ["RunReporter", "format_header", "format_summary", "format_duration", "format_scan"] + +_RULE = "─" * 60 + + +def format_duration(seconds: float) -> str: + """Human-compact duration, e.g. 862 -> '14m22s', 45 -> '45s', 3725 -> '1h2m5s'.""" + total = int(round(seconds)) + hours, rem = divmod(total, 3600) + minutes, secs = divmod(rem, 60) + parts = [] + if hours: + parts.append(f"{hours}h") + if minutes or hours: + parts.append(f"{minutes}m") + parts.append(f"{secs}s") + return "".join(parts) + + +def format_header(start: RunStart) -> str: + """One compact block summarising the run at a glance.""" + p = start.provenance + commit = p.git_commit[:7] if p.git_commit else "—" + return "\n".join( + [ + f"── ADAPT run {start.run_id} {_RULE[: max(0, 48 - len(start.run_id))]}", + f"site {start.site} · {start.instrument} · mode {start.mode}", + f"pipeline {start.pipeline} v{start.pipeline_version} · commit {commit} " + f"· python {p.python_version} · {p.platform}", + f"config {start.configuration_file} (sha {start.configuration_hash[:8]}) " + f"· modules: {', '.join(start.enabled_modules)}", + _RULE, + ] + ) + + +def format_summary(summary: RunSummary) -> str: + """One compact block of end-of-run stats.""" + slowest = " · ".join(f"{name} {format_duration(secs)}" for name, secs in summary.slowest_stages) + return "\n".join( + [ + f"── run {summary.run_id} {summary.status.upper()} " + f"{format_duration(summary.duration_seconds)} {_RULE[:20]}", + f"files {summary.files_processed:,} · scans {summary.scans_processed:,} " + f"· objects {summary.objects_detected:,} · warnings {summary.warnings} " + f"· errors {summary.errors}", + f"scan time avg {summary.average_scan_time:.2f}s " + f"· max {summary.maximum_scan_time:.2f}s", + f"slowest: {slowest}" if slowest else "slowest: —", + _RULE, + ] + ) + + +def format_scan(scan_id: str, spans: list[SpanRecord], n_cells: int) -> str: + """One compact per-scan line, built from the scan's drained module spans. + + Reads ``stage duration`` straight off the captured telemetry — so the console + reports what ran and how long without any module emitting its own progress. + """ + stages = " ".join(f"{s.name} {s.duration_s:.1f}s" for s in spans) + total = sum(s.duration_s for s in spans) + return f"scan {scan_id} │ {stages} │ {n_cells} cells {total:.1f}s" + + +class RunReporter: + """Emits the console-tagged line groups: header, per-scan progress, summary.""" + + def __init__(self, logger: logging.Logger | None = None) -> None: + self._log = logger or logging.getLogger("adapt.run") + + def header(self, start: RunStart) -> None: + self._log.info(format_header(start), extra={"console": True}) + + def progress(self, text: str) -> None: + self._log.info(text, extra={"console": True}) + + def scan(self, scan_id: str, spans: list[SpanRecord], n_cells: int) -> None: + self._log.info(format_scan(scan_id, spans, n_cells), extra={"console": True}) + + def summary(self, summary: RunSummary) -> None: + self._log.info(format_summary(summary), extra={"console": True}) diff --git a/tests/configuration/test_observability_config.py b/tests/configuration/test_observability_config.py new file mode 100644 index 0000000..1173633 --- /dev/null +++ b/tests/configuration/test_observability_config.py @@ -0,0 +1,17 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Observability toggles live on the existing logging config section.""" + +from adapt.configuration.schemas.internal import InternalLoggingConfig + + +def test_logging_config_carries_observability_defaults() -> None: + cfg = InternalLoggingConfig(level="INFO") + assert cfg.enabled is True + assert cfg.traces is True + assert cfg.metrics is True + assert cfg.json_logs is False + assert cfg.console_logs is True + assert cfg.console_level == "WARNING" + assert cfg.progress_every == 30.0 diff --git a/tests/contracts/test_observability.py b/tests/contracts/test_observability.py new file mode 100644 index 0000000..43f2bc7 --- /dev/null +++ b/tests/contracts/test_observability.py @@ -0,0 +1,67 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Contract types for observability: the injectable DI seam. + +ObsContext/SpanRecord are pure frozen data; Observability is the Protocol that +execution/runtime depend on so the concrete provider can be injected (and later +relocated) without touching call sites. +""" + +import dataclasses + +import pytest + +from adapt.contracts.observability import ObsContext, Observability, SpanRecord + + +def test_obs_context_is_frozen_and_defaults_blank_ids() -> None: + ctx = ObsContext(pipeline_id="run-1", trace_id="abc") + assert ctx.pipeline_id == "run-1" + assert ctx.trace_id == "abc" + # everything else defaults to empty so a fresh context is harmless + assert ctx.span_id == "" + assert ctx.scan_id == "" + assert ctx.dataset_id == "" + assert ctx.stage == "" + with pytest.raises(dataclasses.FrozenInstanceError): + ctx.scan_id = "512" # type: ignore[misc] + + +def test_span_record_carries_timing_and_error() -> None: + rec = SpanRecord( + name="detection", + trace_id="t", + span_id="s", + parent_span_id="p", + start=10.0, + finish=12.5, + duration_s=2.5, + error="", + metadata={"n_cells": "42"}, + ) + assert rec.duration_s == 2.5 + assert rec.metadata["n_cells"] == "42" + with pytest.raises(dataclasses.FrozenInstanceError): + rec.error = "boom" # type: ignore[misc] + + +def test_observability_protocol_is_runtime_checkable() -> None: + class _Impl: + metrics = object() + + def span(self, name, **ctx): ... + + def bind(self, **ctx): ... + + def current(self): ... + + def drain_spans(self): ... + + def install_logging(self, log_path): ... + + class _NotImpl: + def span(self, name, **ctx): ... + + assert isinstance(_Impl(), Observability) + assert not isinstance(_NotImpl(), Observability) diff --git a/tests/graph/test_executor_observability.py b/tests/graph/test_executor_observability.py new file mode 100644 index 0000000..3b7dc9d --- /dev/null +++ b/tests/graph/test_executor_observability.py @@ -0,0 +1,121 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""GraphExecutor auto-instruments every node with a span — zero module code. + +Behaviour under test: one span per node, real durations from the injected clock, +the module runs inside its own stage context, and a node failure is recorded on +the span and re-raised (never swallowed). A graph with no provider still runs. +""" + +import random +from datetime import UTC, datetime + +import pytest + +from adapt.execution.graph.builder import GraphBuilder +from adapt.execution.graph.executor import GraphExecutor +from adapt.modules.base import BaseModule +from adapt.runtime.observability import ObsSettings, build_observability + + +class _Stub(BaseModule): + def __init__(self, name, inputs, outputs, fn=None): + self._name, self._inputs, self._outputs, self._fn = name, inputs, outputs, fn + + @property + def name(self): + return self._name + + @property + def inputs(self): + return self._inputs + + @property + def outputs(self): + return self._outputs + + def run(self, context): + return self._fn(context) if self._fn else {k: 1 for k in self._outputs} + + +def _obs(**settings): + return build_observability( + ObsSettings(**settings), + clock=lambda: 0.0, + wall_clock=lambda: datetime(2026, 1, 1, tzinfo=UTC), + rng=random.Random(0), + ) + + +def _obs_seq_clock(values): + it = iter(values) + return build_observability( + ObsSettings(), + clock=lambda: next(it), + wall_clock=lambda: datetime(2026, 1, 1, tzinfo=UTC), + rng=random.Random(0), + ) + + +def test_executor_opens_one_span_per_node_with_real_durations() -> None: + obs = _obs_seq_clock([0.0, 1.0, 1.0, 3.0]) # ingest enter/exit, detection enter/exit + ingest = _Stub("ingest", [], ["g"]) + detection = _Stub("detection", ["g"], ["c"]) + GraphExecutor(GraphBuilder([detection, ingest]).build(), observability=obs).run({}) + spans = {s.name: s for s in obs.drain_spans()} + assert set(spans) == {"ingest", "detection"} + assert spans["ingest"].duration_s == 1.0 + assert spans["detection"].duration_s == 2.0 + + +def test_node_runs_inside_its_own_stage_context() -> None: + obs = _obs() + seen: dict[str, str] = {} + + def fn(ctx): + seen["stage"] = obs.current().stage + return {"g": 1} + + GraphExecutor(GraphBuilder([_Stub("ingest", [], ["g"], fn)]).build(), observability=obs).run({}) + assert seen["stage"] == "ingest" + + +def test_node_failure_records_error_on_span_and_propagates() -> None: + obs = _obs() + + def boom(ctx): + raise ValueError("kaboom") + + nodes = GraphBuilder([_Stub("detection", [], ["c"], boom)]).build() + with pytest.raises(ValueError, match="kaboom"): + GraphExecutor(nodes, observability=obs).run({}) + span = obs.drain_spans()[0] + assert span.name == "detection" + assert "ValueError: kaboom" in span.error + assert obs.metrics.counter_total("errors_total") == 1.0 + + +def test_node_failure_annotates_failing_stage() -> None: + """A propagating module exception must name the failing stage. + + The scan-level handler logs one failure line; without this annotation it cannot + say *which* stage broke. Works with or without a provider (it identifies the + stage, independent of telemetry), so assert it on the no-provider path too. + """ + + def boom(ctx): + raise ValueError("kaboom") + + nodes = GraphBuilder([_Stub("detection", [], ["c"], boom)]).build() + with pytest.raises(ValueError) as excinfo: + GraphExecutor(nodes).run({}) + + notes = getattr(excinfo.value, "__notes__", []) + assert any("detection" in note for note in notes) + + +def test_graph_runs_without_a_provider() -> None: + nodes = GraphBuilder([_Stub("ingest", [], ["g"])]).build() + result = GraphExecutor(nodes).run({}) # no observability -> no telemetry, still runs + assert result["g"] == 1 diff --git a/tests/integration/test_integration_execution_history.py b/tests/integration/test_integration_execution_history.py new file mode 100644 index 0000000..15e0af7 --- /dev/null +++ b/tests/integration/test_integration_execution_history.py @@ -0,0 +1,112 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""End-to-end: modules run through the real executor populate execution history. + +Acceptance for the observability + execution-history chain — span seam -> drain +-> ExecutionHistory -> SQL query — with ZERO instrumentation code in the modules. +""" + +import random +from datetime import UTC, datetime + +import pytest + +from adapt.contracts.execution_history import RunProvenance, RunStart +from adapt.execution.graph.builder import GraphBuilder +from adapt.execution.graph.executor import GraphExecutor +from adapt.modules.base import BaseModule +from adapt.persistence.execution_history import ExecutionHistory +from adapt.runtime.observability import ObsSettings, build_observability + +pytestmark = pytest.mark.integration + + +class _Stub(BaseModule): + def __init__(self, name, inputs, outputs, fn=None): + self._name, self._inputs, self._outputs, self._fn = name, inputs, outputs, fn + + @property + def name(self): + return self._name + + @property + def inputs(self): + return self._inputs + + @property + def outputs(self): + return self._outputs + + def run(self, context): + return self._fn(context) if self._fn else {k: 1 for k in self._outputs} + + +def _obs(): + return build_observability( + ObsSettings(), + clock=lambda: 0.0, + wall_clock=lambda: datetime(2026, 1, 1, tzinfo=UTC), + rng=random.Random(0), + ) + + +def _started_history(tmp_path): + history = ExecutionHistory(tmp_path / "catalog.db") + history.start_run( + RunStart( + run_id="R1", + pipeline="nexrad", + pipeline_version="0.4.1", + site="KDIX", + dataset="KDIX", + instrument="NEXRAD", + mode="historical", + start_time=datetime(2026, 6, 28, tzinfo=UTC), + configuration_hash="abc", + configuration_file="cfg.yaml", + provenance=RunProvenance(None, "h", "u", "3.11", "linux", "0.4.1"), + enabled_modules=("ingest", "detection"), + ) + ) + return history + + +def test_modules_run_populate_module_history_and_share_one_trace(tmp_path): + obs = _obs() + history = _started_history(tmp_path) + ingest = _Stub("ingest", [], ["g"]) + detection = _Stub("detection", ["g"], ["c"]) + + with obs.span("scan"): + GraphExecutor(GraphBuilder([detection, ingest]).build(), observability=obs).run({}) + + module_spans = [s for s in obs.drain_spans() if s.name != "scan"] + history.record_modules( + "R1", "scan1", module_spans, recorded_at=datetime(2026, 6, 28, tzinfo=UTC) + ) + + rows = {m["module"]: m for m in history.query_modules(run_id="R1")} + assert set(rows) == {"ingest", "detection"} + assert all(r["status"] == "ok" for r in rows.values()) + assert len({s.trace_id for s in module_spans}) == 1 # one trace spans the scan + + +def test_failing_module_records_error_row_and_failure_rate(tmp_path): + obs = _obs() + history = _started_history(tmp_path) + + def boom(_ctx): + raise ValueError("kaboom") + + nodes = GraphBuilder([_Stub("detection", [], ["c"], boom)]).build() + with pytest.raises(ValueError, match="kaboom"): + GraphExecutor(nodes, observability=obs).run({}) + + module_spans = [s for s in obs.drain_spans() if s.name != "scan"] + history.record_modules( + "R1", "scan1", module_spans, recorded_at=datetime(2026, 6, 28, tzinfo=UTC) + ) + + assert history.query_modules(run_id="R1")[0]["status"] == "error" + assert history.failure_rate_by_module()["detection"] == 1.0 diff --git a/tests/modules/acquisition/test_downloader_quiet.py b/tests/modules/acquisition/test_downloader_quiet.py new file mode 100644 index 0000000..7554c4d --- /dev/null +++ b/tests/modules/acquisition/test_downloader_quiet.py @@ -0,0 +1,39 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""The nexradaws download call must not leak its own print() chatter to the console. + +nexradaws prints "Downloaded " and " out of files downloaded..." straight +to stdout with no quiet option. The acquisition module already logs a controlled +"Downloaded: " line, so the library's duplicate prints are pure clutter and must +be contained at the one call site (no supported quiet flag exists). +""" + +from datetime import UTC, datetime + +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + +pytestmark = pytest.mark.unit + + +def test_download_scan_suppresses_nexradaws_stdout(tmp_path, fake_scan, make_config, capsys): + class PrintingConn: + def download(self, files, basepath, keep_aws_folders=False): + print("Downloaded KOHX_TEST") # nexradaws chatter + print("1 out of 1 files downloaded...0 errors") + + class _Results: + def iter_success(self): + return [] + + return _Results() + + d = AwsNexradDownloader(make_config(), output_dir=tmp_path, conn=PrintingConn()) + + d._download_scan(fake_scan("KOHX_TEST", datetime.now(UTC)), tmp_path / "out.nc") + + out = capsys.readouterr().out + assert "out of 1 files downloaded" not in out + assert "Downloaded KOHX_TEST" not in out diff --git a/tests/modules/ingest/test_ingest_quiet.py b/tests/modules/ingest/test_ingest_quiet.py new file mode 100644 index 0000000..362468a --- /dev/null +++ b/tests/modules/ingest/test_ingest_quiet.py @@ -0,0 +1,33 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Importing the ingest module must not splash the Py-ART citation banner to stdout. + +Py-ART prints a multi-line citation block on import unless ``PYART_QUIET`` is set. +That block is pure console clutter for an Adapt run. The ingest module owns the only +``import pyart`` in the package, so it must set the supported suppression env var +*before* importing pyart. Verified in a clean subprocess so the assertion reflects a +fresh interpreter, not one where pyart was already imported by the test session. +""" + +import os +import subprocess +import sys + +import pytest + +pytestmark = pytest.mark.unit + + +def test_importing_ingest_does_not_print_pyart_citation(): + env = {k: v for k, v in os.environ.items() if k != "PYART_QUIET"} + result = subprocess.run( + [sys.executable, "-c", "import adapt.modules.ingest.module"], + capture_output=True, + text=True, + env=env, + ) + + assert result.returncode == 0, result.stderr + assert "Py-ART" not in result.stdout + assert "jors.119" not in result.stdout diff --git a/tests/modules/projection/test_projector_internal_utils.py b/tests/modules/projection/test_projector_internal_utils.py index 894db7b..f46d991 100644 --- a/tests/modules/projection/test_projector_internal_utils.py +++ b/tests/modules/projection/test_projector_internal_utils.py @@ -1,3 +1,5 @@ +import logging + import numpy as np import pytest @@ -31,3 +33,24 @@ def test_fill_concave_hull_small_object_falls_back(make_projection_config): filled = proj._fill_concave_hull(mask) assert filled.any() + + +def test_fill_concave_hull_collinear_points_no_qhull_warning(make_projection_config, caplog): + """Rank-deficient (collinear) points must not log a QhullError warning. + + A column of >=4 set pixels is < 2-D, so a Delaunay triangulation raises QH6013 + ("input is less than 3-dimensional"). The old code caught it and logged the full + multi-line qhull dump as a WARNING for every such cell — pure clutter. The + projector must detect degeneracy up front and fall back to dilation silently. + """ + config = make_projection_config() + proj = RadarCellProjector(config) + + mask = np.zeros((8, 8), dtype=bool) + mask[1:6, 3] = True # 5 collinear points in a single column -> rank 1 + + with caplog.at_level(logging.WARNING, logger="adapt.modules.projection.module"): + filled = proj._fill_concave_hull(mask) + + assert filled.any() # dilation fallback still fills the line + assert not [r for r in caplog.records if "oncave hull" in r.getMessage()] diff --git a/tests/persistence/test_data_repository.py b/tests/persistence/test_data_repository.py index bbd235b..a2654a0 100644 --- a/tests/persistence/test_data_repository.py +++ b/tests/persistence/test_data_repository.py @@ -296,6 +296,40 @@ def test_write_netcdf_with_string_variable(self, repository): finally: reopened.close() + def test_open_dataset_uses_explicit_netcdf4_engine( + self, repository, sample_dataset, monkeypatch + ): + """open_dataset must pin engine='netcdf4' instead of letting xarray sniff. + + Adapt writes its NetCDF with the netcdf4 engine, so reads should name it too. + Without an explicit engine xarray probes every backend, and the h5netcdf probe + makes libhdf5 splatter ``HDF5-DIAG`` blocks to stderr ("Not an HDF5 file") for + every classic-format file. Pinning the engine skips the probe entirely. + """ + from adapt.persistence import repository as repo_module + + artifact_id = repository.write_netcdf( + ds=sample_dataset, + product_type=ProductType.GRIDDED_NC, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="loader", + ) + + captured: dict = {} + real_open = repo_module.xr.open_dataset + + def spy(path, *args, **kwargs): + captured["engine"] = kwargs.get("engine") + return real_open(path, *args, **kwargs) + + monkeypatch.setattr(repo_module.xr, "open_dataset", spy) + + ds = repository.open_dataset(artifact_id) + try: + assert captured["engine"] == "netcdf4" + finally: + ds.close() + def test_write_parquet(self, repository, sample_dataframe): """Should write Parquet and register artifact.""" artifact_id = repository.write_parquet( diff --git a/tests/persistence/test_execution_history.py b/tests/persistence/test_execution_history.py new file mode 100644 index 0000000..e00a71a --- /dev/null +++ b/tests/persistence/test_execution_history.py @@ -0,0 +1,150 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Execution History: durable, queryable run/module/error/warning records. + +Behaviour under test (round-trips that can break): a run is recorded as +'running' then finalized with real aggregates; one module_history row per span +with the right status; errors/warnings are stored and queryable; failure rate by +module is computed from the stored rows. +""" + +from datetime import UTC, datetime + +from adapt.contracts.execution_history import ( + ErrorEvent, + RunProvenance, + RunStart, + RunSummary, + WarningEvent, +) +from adapt.contracts.observability import SpanRecord +from adapt.persistence.execution_history import ExecutionHistory + + +def _prov(): + return RunProvenance( + git_commit="abc123", + hostname="host", + username="user", + python_version="3.11", + platform="linux", + software_version="0.4.1", + ) + + +def _start(run_id="R1"): + return RunStart( + run_id=run_id, + pipeline="nexrad", + pipeline_version="0.4.1", + site="KDIX", + dataset="KDIX", + instrument="NEXRAD", + mode="historical", + start_time=datetime(2026, 6, 28, 0, 0, tzinfo=UTC), + configuration_hash="deadbeef", + configuration_file="cfg.yaml", + provenance=_prov(), + enabled_modules=("ingest", "detection"), + ) + + +def _history(tmp_path): + return ExecutionHistory(tmp_path / "catalog.db") + + +def test_start_run_inserts_running_row_with_provenance(tmp_path): + h = _history(tmp_path) + h.start_run(_start()) + runs = h.query_runs() + assert len(runs) == 1 + assert runs[0]["run_id"] == "R1" + assert runs[0]["status"] == "running" + assert runs[0]["git_commit"] == "abc123" + assert runs[0]["site"] == "KDIX" + + +def test_record_modules_one_row_per_span_with_status(tmp_path): + h = _history(tmp_path) + h.start_run(_start()) + spans = [ + SpanRecord("detection", "t", "s1", "p", 0.0, 2.0, 2.0, "", {}), + SpanRecord("tracking", "t", "s2", "p", 2.0, 2.5, 0.5, "ValueError: x", {}), + ] + h.record_modules("R1", "scan1", spans, recorded_at=datetime(2026, 6, 28, 0, 1, tzinfo=UTC)) + by = {m["module"]: m for m in h.query_modules(run_id="R1")} + assert by["detection"]["duration_seconds"] == 2.0 + assert by["detection"]["status"] == "ok" + assert by["tracking"]["status"] == "error" + assert by["tracking"]["scan_id"] == "scan1" + + +def test_finalize_run_updates_status_and_aggregates(tmp_path): + h = _history(tmp_path) + h.start_run(_start()) + h.finalize_run( + RunSummary( + run_id="R1", + status="success", + end_time=datetime(2026, 6, 28, 0, 14, tzinfo=UTC), + duration_seconds=862.0, + files_processed=842, + scans_processed=842, + objects_detected=18431, + warnings=7, + errors=0, + average_scan_time=1.03, + maximum_scan_time=3.92, + slowest_stages=(("detection", 371.0), ("tracking", 182.0)), + ) + ) + run = h.query_runs()[0] + assert run["status"] == "success" + assert run["duration_seconds"] == 862.0 + assert run["scans_processed"] == 842 + assert run["objects_detected"] == 18431 + assert run["slowest_stage"] == "detection" + assert run["slowest_stage_duration"] == 371.0 + + +def test_errors_and_warnings_stored_and_failure_rate(tmp_path): + h = _history(tmp_path) + h.start_run(_start()) + ok = [SpanRecord("detection", "t", "s", "p", 0, 1, 1.0, "", {})] + bad = [SpanRecord("detection", "t", "s", "p", 0, 1, 1.0, "E: x", {})] + h.record_modules("R1", "scan1", ok, recorded_at=datetime(2026, 6, 28, tzinfo=UTC)) + h.record_modules("R1", "scan2", bad, recorded_at=datetime(2026, 6, 28, tzinfo=UTC)) + h.record_errors( + "R1", + [ + ErrorEvent( + scan_id="scan2", + module="detection", + exception_type="ValueError", + message="x", + traceback="tb...", + logger="adapt.detection", + timestamp=datetime(2026, 6, 28, tzinfo=UTC), + ) + ], + ) + h.record_warnings( + "R1", + [ + WarningEvent( + scan_id="scan1", + module="detection", + category="slow_execution", + message="slow", + logger="adapt.detection", + timestamp=datetime(2026, 6, 28, tzinfo=UTC), + ) + ], + ) + assert h.failure_rate_by_module()["detection"] == 0.5 # 1 of 2 module runs errored + errs = h.query_errors(run_id="R1") + assert errs[0]["exception_type"] == "ValueError" + assert errs[0]["traceback"] == "tb..." + warns = h.query_warnings(run_id="R1") + assert warns[0]["category"] == "slow_execution" diff --git a/tests/runtime/test_diagnostics.py b/tests/runtime/test_diagnostics.py new file mode 100644 index 0000000..2a2bd74 --- /dev/null +++ b/tests/runtime/test_diagnostics.py @@ -0,0 +1,62 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""HDF5 diagnostics are silenced through h5py's supported control, on the I/O thread. + +We cannot reliably force libhdf5 to print a diagnostic in every environment, so the +behaviour under test is the contract: the helper invokes h5py's official silence, +is idempotent, and the processor applies it at worker-thread entry (HDF5 error +stacks are thread-local, so a single import-time call would miss the worker). +""" + +import queue +import random +from datetime import UTC, datetime + +import pytest + +from adapt.runtime import diagnostics +from adapt.runtime.observability import ObsSettings, build_observability +from adapt.runtime.processor import RadarProcessor + +pytestmark = [pytest.mark.unit, pytest.mark.pipeline] + + +def test_silence_hdf5_errors_calls_h5py_supported_control(monkeypatch): + calls = [] + monkeypatch.setattr(diagnostics.h5py._errors, "silence_errors", lambda: calls.append(True)) + + diagnostics.silence_hdf5_errors() + diagnostics.silence_hdf5_errors() # idempotent — safe to call per thread entry + + assert calls == [True, True] + + +def _obs(): + return build_observability( + ObsSettings(), + clock=lambda: 0.0, + wall_clock=lambda: datetime(2026, 1, 1, tzinfo=UTC), + rng=random.Random(0), + ) + + +def test_processor_silences_hdf5_at_worker_thread_entry( + monkeypatch, pipeline_config, pipeline_output_dirs, test_repository +): + import adapt.runtime.processor as processor_module + + called = [] + monkeypatch.setattr(processor_module, "silence_hdf5_errors", lambda: called.append(True)) + + proc = RadarProcessor( + queue.Queue(), + pipeline_config, + pipeline_output_dirs, + repository=test_repository, + observability=_obs(), + ) + proc.stop() # so _run_loop exits immediately; run() still hits the thread-entry call + proc.run() + + assert called == [True] # silenced once, at the top of the worker thread diff --git a/tests/runtime/test_history_handler.py b/tests/runtime/test_history_handler.py new file mode 100644 index 0000000..1e112a5 --- /dev/null +++ b/tests/runtime/test_history_handler.py @@ -0,0 +1,64 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""HistoryLogHandler turns existing warning/error logs into searchable events. + +Behaviour under test: a WARNING becomes a WarningEvent and an exception log +becomes an ErrorEvent carrying the bound scan_id/module + the real traceback; +emit only buffers (no DB), and drain returns then clears. +""" + +import logging +import random +from datetime import UTC, datetime + +from adapt.runtime.history_handler import HistoryLogHandler +from adapt.runtime.observability import ObsSettings, build_observability + + +def _obs(): + return build_observability( + ObsSettings(), + clock=lambda: 0.0, + wall_clock=lambda: datetime(2026, 1, 1, tzinfo=UTC), + rng=random.Random(0), + ) + + +def test_handler_captures_warning_and_error_with_context_and_traceback(): + obs = _obs() + handler = HistoryLogHandler() + log = logging.getLogger("adapt.test.history_handler") + log.handlers[:] = [handler] + log.setLevel(logging.DEBUG) + log.propagate = False + + with obs.bind(scan_id="scan9", stage="detection"): + log.warning("slow scan", extra={"category": "slow_execution"}) + try: + raise ValueError("boom") + except ValueError: + log.exception("processing failed") + + warnings, errors = handler.drain() + assert len(warnings) == 1 + assert warnings[0].scan_id == "scan9" + assert warnings[0].module == "detection" + assert warnings[0].category == "slow_execution" + + assert len(errors) == 1 + assert errors[0].scan_id == "scan9" + assert errors[0].exception_type == "ValueError" + assert "ValueError: boom" in errors[0].traceback + + assert handler.drain() == ([], []) # drain clears the buffers + + +def test_emit_only_buffers_without_db_and_defaults_category(): + handler = HistoryLogHandler() + record = logging.LogRecord("adapt.x", logging.WARNING, __file__, 1, "plain", None, None) + handler.emit(record) # no DB handle exists; pure in-memory append + warnings, errors = handler.drain() + assert len(warnings) == 1 + assert warnings[0].category == "general" # default when none provided + assert errors == [] diff --git a/tests/runtime/test_logging_setup.py b/tests/runtime/test_logging_setup.py new file mode 100644 index 0000000..73929d6 --- /dev/null +++ b/tests/runtime/test_logging_setup.py @@ -0,0 +1,143 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Structured logging setup: context injection, JSON output, quiet console.""" + +import json +import logging +import random +from datetime import UTC, datetime + +import pytest + +from adapt.runtime.logging_setup import ( + ConsoleFilter, + ContextFilter, + JsonFormatter, + configure_logging, +) +from adapt.runtime.observability import ObsSettings, build_observability + + +@pytest.fixture(autouse=True) +def _restore_root_logger(): + root = logging.getLogger() + saved = root.handlers[:] + saved_level = root.level + yield + root.handlers[:] = saved + root.setLevel(saved_level) + + +def _obs(**settings): + return build_observability( + ObsSettings(**settings), + clock=lambda: 0.0, + wall_clock=lambda: datetime(2026, 1, 1, tzinfo=UTC), + rng=random.Random(0), + ) + + +def _record(level=logging.INFO, msg="hello", **extra): + rec = logging.LogRecord("adapt.x", level, __file__, 1, msg, None, None) + for k, v in extra.items(): + setattr(rec, k, v) + return rec + + +def test_context_filter_attaches_current_ids() -> None: + obs = _obs() + rec = _record() + with obs.bind(scan_id="512", dataset_id="KDIX"): + ContextFilter().filter(rec) + assert rec.scan_id == "512" + assert rec.dataset_id == "KDIX" + + +def test_json_formatter_emits_structured_fields_without_trailing_z() -> None: + rec = _record(scan_id="512", n_cells=42) + rec.scan_id = "512" # as ContextFilter would set + out = json.loads(JsonFormatter().format(rec)) + assert out["level"] == "INFO" + assert out["logger"] == "adapt.x" + assert out["msg"] == "hello" + assert out["scan_id"] == "512" + assert out["n_cells"] == 42 # arbitrary extra field surfaces + assert "Z" not in out["ts"] # the "...Z" literal is fitness-pinned to utils/time + + +def test_json_formatter_includes_exception() -> None: + try: + raise ValueError("boom") + except ValueError: + import sys + + rec = logging.LogRecord("adapt.x", logging.ERROR, __file__, 1, "fail", None, sys.exc_info()) + out = json.loads(JsonFormatter().format(rec)) + assert "ValueError: boom" in out["exc"] + + +def test_console_filter_passes_warning_and_console_tagged_only() -> None: + cf = ConsoleFilter() + assert cf.filter(_record(level=logging.WARNING)) is True + assert cf.filter(_record(level=logging.INFO, console=True)) is True + assert cf.filter(_record(level=logging.INFO)) is False + assert cf.filter(_record(level=logging.DEBUG)) is False + + +def test_configure_logging_requires_log_path_for_json() -> None: + with pytest.raises(ValueError, match="log_path"): + configure_logging(ObsSettings(json_logs=True), log_path=None) + + +def test_configure_logging_creates_dir_and_is_idempotent(tmp_path) -> None: + log_path = tmp_path / "logs" / "pipeline_KDIX.log" + configure_logging(ObsSettings(json_logs=True, console_logs=True), log_path=log_path) + first = len(logging.getLogger().handlers) + configure_logging(ObsSettings(json_logs=True, console_logs=True), log_path=log_path) + second = len(logging.getLogger().handlers) + assert log_path.parent.is_dir() + assert first == second # handlers cleared+re-added, no accumulation + + +def _console_stream_output(settings, records): + """Run records through a freshly configured root logger; return console text.""" + import io + + configure_logging(settings, log_path=None) + root = logging.getLogger() + stream = io.StringIO() + console = next(h for h in root.handlers if isinstance(h, logging.StreamHandler)) + console.stream = stream + log = logging.getLogger("adapt.run") + for level, msg, console_tag in records: + log.log(level, msg, extra={"console": True} if console_tag else {}) + return stream.getvalue() + + +def test_console_shows_info_console_tagged_lines_despite_warning_threshold() -> None: + """Console-tagged INFO (run header/progress/summary) must reach the console even + when console_level is WARNING. The handler level must not pre-empt ConsoleFilter: + that gating bug made the run header and end-of-run summary vanish entirely. + """ + out = _console_stream_output( + ObsSettings(console_logs=True, console_level="WARNING"), + [ + (logging.INFO, "RUN HEADER", True), + (logging.INFO, "routine chatter", False), + (logging.WARNING, "a real warning", False), + ], + ) + assert "RUN HEADER" in out # console-tagged INFO survives + assert "a real warning" in out # warnings always survive + assert "routine chatter" not in out # plain INFO stays off the console + + +def test_verbose_console_level_lets_plain_info_through() -> None: + """console_level=INFO (the -v firehose) shows plain INFO too.""" + out = _console_stream_output( + ObsSettings(console_logs=True, console_level="INFO"), + [(logging.INFO, "routine chatter", False), (logging.DEBUG, "debug noise", False)], + ) + assert "routine chatter" in out + assert "debug noise" not in out diff --git a/tests/runtime/test_observability.py b/tests/runtime/test_observability.py new file mode 100644 index 0000000..6c242e0 --- /dev/null +++ b/tests/runtime/test_observability.py @@ -0,0 +1,260 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Observability provider: context propagation, spans, metrics, disabled path. + +The provider reads an injected clock/wall-clock/rng (never the wall clock by +default) so durations and ids are deterministic in tests. +""" + +import random +import threading +from datetime import UTC, datetime + +import pytest + +from adapt.runtime.observability import ( + ObsSettings, + build_observability, + disabled_observability, +) + + +def _obs(**settings): + return build_observability( + ObsSettings(**settings), + clock=lambda: 0.0, + wall_clock=lambda: datetime(2026, 1, 1, tzinfo=UTC), + rng=random.Random(0), + ) + + +def _obs_seq_clock(values, **settings): + it = iter(values) + return build_observability( + ObsSettings(**settings), + clock=lambda: next(it), + wall_clock=lambda: datetime(2026, 1, 1, tzinfo=UTC), + rng=random.Random(0), + ) + + +def test_fresh_provider_has_blank_context() -> None: + obs = _obs() + assert obs.current().pipeline_id == "" + assert obs.current().scan_id == "" + + +def test_bind_sets_then_restores() -> None: + obs = _obs() + with obs.bind(scan_id="512", dataset_id="KDIX"): + assert obs.current().scan_id == "512" + assert obs.current().dataset_id == "KDIX" + assert obs.current().scan_id == "" + assert obs.current().dataset_id == "" + + +def test_nested_bind_overrides_then_restores() -> None: + obs = _obs() + with obs.bind(scan_id="1"): + with obs.bind(scan_id="2"): + assert obs.current().scan_id == "2" + assert obs.current().scan_id == "1" + + +def test_fresh_thread_sees_blank_context_until_it_binds() -> None: + """contextvars do NOT cross threads — each worker must bind its own context.""" + obs = _obs() + seen: dict[str, str] = {} + with obs.bind(scan_id="outer"): + + def worker() -> None: + seen["before"] = obs.current().scan_id + with obs.bind(scan_id="inner"): + seen["during"] = obs.current().scan_id + + t = threading.Thread(target=worker) + t.start() + t.join() + assert seen["before"] == "" + assert seen["during"] == "inner" + + +def test_span_records_duration_from_injected_clock() -> None: + obs = _obs_seq_clock([10.0, 12.5]) + with obs.span("detection"): + pass + spans = obs.drain_spans() + assert len(spans) == 1 + assert spans[0].name == "detection" + assert spans[0].duration_s == 2.5 + + +def test_span_binds_stage_and_mints_trace_in_context() -> None: + obs = _obs_seq_clock([0.0, 1.0]) + with obs.span("ingest"): + assert obs.current().stage == "ingest" + assert obs.current().trace_id != "" + assert obs.current().span_id != "" + assert obs.current().stage == "" # restored on exit + + +def test_nested_spans_share_trace_and_link_parent() -> None: + obs = _obs_seq_clock([0.0, 1.0, 2.0, 3.0]) + with obs.span("parent"): + parent_ctx = obs.current() + with obs.span("child"): + child_ctx = obs.current() + assert child_ctx.trace_id == parent_ctx.trace_id + assert child_ctx.span_id != parent_ctx.span_id + spans = {s.name: s for s in obs.drain_spans()} + assert spans["child"].parent_span_id == spans["parent"].span_id + assert spans["child"].trace_id == spans["parent"].trace_id + + +def test_span_records_error_and_reraises() -> None: + obs = _obs_seq_clock([0.0, 1.0]) + with pytest.raises(ValueError), obs.span("boom"): + raise ValueError("nope") + spans = obs.drain_spans() + assert spans[0].error.startswith("ValueError") + + +def test_span_set_adds_metadata() -> None: + obs = _obs_seq_clock([0.0, 1.0]) + with obs.span("detection") as s: + s.set(n_cells=42) + assert obs.drain_spans()[0].metadata["n_cells"] == "42" + + +def test_drain_spans_returns_then_clears() -> None: + obs = _obs_seq_clock([0.0, 1.0]) + with obs.span("a"): + pass + assert len(obs.drain_spans()) == 1 + assert obs.drain_spans() == [] + + +def test_counter_accumulates_across_calls() -> None: + obs = _obs() + obs.metrics.incr("files_processed_total") + obs.metrics.incr("files_processed_total", 2.0) + assert obs.metrics.counter_total("files_processed_total") == 3.0 + + +def test_counter_total_sums_over_labels() -> None: + obs = _obs() + obs.metrics.incr("errors_total", stage="detection") + obs.metrics.incr("errors_total", stage="tracking") + assert obs.metrics.counter_total("errors_total") == 2.0 + + +def test_gauge_is_last_write_wins() -> None: + obs = _obs() + obs.metrics.gauge("queue_depth", 5) + obs.metrics.gauge("queue_depth", 2) + assert obs.metrics.gauge_value("queue_depth") == 2 + + +def test_histogram_collects_values() -> None: + obs = _obs() + obs.metrics.observe("scan_processing_time", 1.0) + obs.metrics.observe("scan_processing_time", 3.0) + assert obs.metrics.histogram_values("scan_processing_time") == [1.0, 3.0] + + +def test_histogram_totals_by_stage_label() -> None: + obs = _obs() + obs.metrics.observe("module_duration_seconds", 2.0, stage="detection") + obs.metrics.observe("module_duration_seconds", 1.0, stage="detection") + obs.metrics.observe("module_duration_seconds", 4.0, stage="tracking") + totals = obs.metrics.histogram_totals_by_label("module_duration_seconds", "stage") + assert totals == {"detection": 3.0, "tracking": 4.0} + + +def test_concurrent_counter_increments_sum_correctly() -> None: + obs = _obs() + + def bump() -> None: + for _ in range(1000): + obs.metrics.incr("c") + + threads = [threading.Thread(target=bump) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + assert obs.metrics.counter_total("c") == 4000.0 + + +def test_span_emits_duration_histogram_and_error_counter() -> None: + obs = _obs_seq_clock([0.0, 2.0]) + with pytest.raises(ValueError), obs.span("detection"): + raise ValueError("x") + assert obs.metrics.histogram_values("module_duration_seconds") == [2.0] + assert obs.metrics.counter_total("errors_total") == 1.0 + + +def test_disabled_provider_span_is_shared_noop() -> None: + obs = _obs(enabled=False) + with obs.span("a") as h: + h.set(x=1) + assert obs.drain_spans() == [] + assert obs.span("b") is obs.span("c") # shared no-op, no per-call allocation + + +def test_disabled_provider_span_does_not_touch_context() -> None: + obs = _obs(enabled=False) + with obs.span("a"): + assert obs.current().stage == "" + + +def test_disabled_provider_bind_and_metrics_are_noop() -> None: + obs = _obs(enabled=False) + with obs.bind(scan_id="z"): + assert obs.current().scan_id == "" + obs.metrics.incr("x") + assert obs.metrics.counter_total("x") == 0.0 + + +def test_disabled_observability_helper_records_nothing() -> None: + obs = disabled_observability() + with obs.span("a"): + obs.metrics.incr("c") + assert obs.drain_spans() == [] + assert obs.metrics.counter_total("c") == 0.0 + + +def test_span_exit_survives_cross_context_teardown(): + """Entering a span in one contextvars Context and exiting it in another must + restore the parent context without raising. + + This is the Ctrl+C shutdown path: the orchestrator opens the root "pipeline" + span in start() and closes it in stop(); if teardown straddles a copied + Context the raw Token.reset() raises ValueError. Teardown must stay + deterministic and crash-free, leaving a consistent (blank) context. + """ + import contextvars + + obs = _obs() + span = obs.span("pipeline", pipeline_id="R1") + + # Enter inside an isolated copied Context so the reset Token is foreign to the + # outer Context where we exit — exactly the straddle that crashes on shutdown. + contextvars.copy_context().run(span.__enter__) + span.__exit__(None, None, None) # must not raise + + assert obs.current().pipeline_id == "" # parent context restored, not corrupted + + +def test_bind_exit_survives_cross_context_teardown(): + """bind() teardown must also tolerate a cross-Context exit.""" + import contextvars + + obs = _obs() + cm = obs.bind(scan_id="s1") + + contextvars.copy_context().run(cm.__enter__) + cm.__exit__(None, None, None) # must not raise + + assert obs.current().scan_id == "" diff --git a/tests/runtime/test_orchestrator.py b/tests/runtime/test_orchestrator.py index ec1f17e..dda2d56 100644 --- a/tests/runtime/test_orchestrator.py +++ b/tests/runtime/test_orchestrator.py @@ -1,3 +1,5 @@ +import time + import pytest from adapt.runtime.orchestrator import PipelineOrchestrator @@ -5,6 +7,35 @@ pytestmark = [pytest.mark.unit, pytest.mark.pipeline] +def test_orchestrator_build_run_summary_aggregates_metrics_and_history( + pipeline_config, test_repository +): + orch = PipelineOrchestrator(pipeline_config) + orch._obs = orch._build_observability() + orch.repository = test_repository + orch.run_id = test_repository.run_id + orch._start_time = time.time() - 10 + + m = orch._obs.metrics + m.incr("files_processed_total") + m.incr("files_processed_total") + m.incr("cells_detected_total", value=18431) + m.observe("scan_processing_time", 1.0) + m.observe("scan_processing_time", 3.0) + m.observe("module_duration_seconds", 5.0, stage="detection") + m.observe("module_duration_seconds", 2.0, stage="tracking") + + summary = orch._build_run_summary("success") + assert summary.status == "success" + assert summary.files_processed == 2 + assert summary.objects_detected == 18431 + assert summary.scans_processed == 2 + assert summary.average_scan_time == 2.0 + assert summary.maximum_scan_time == 3.0 + assert summary.slowest_stages[0] == ("detection", 5.0) + assert summary.duration_seconds >= 10.0 + + def test_orchestrator_initialization(pipeline_config): """Orchestrator initializes with config.""" orch = PipelineOrchestrator(pipeline_config) @@ -19,6 +50,16 @@ def test_orchestrator_logging_and_tracker(pipeline_config): assert orch.tracker is not None +def test_orchestrator_builds_working_observability_from_config(pipeline_config): + """The provider built from config actually records spans and metrics.""" + orch = PipelineOrchestrator(pipeline_config) + obs = orch._build_observability() + with obs.span("ingest"): + obs.metrics.incr("files_processed_total") + assert [s.name for s in obs.drain_spans()] == ["ingest"] + assert obs.metrics.counter_total("files_processed_total") == 1.0 + + def test_orchestrator_queue_wiring(pipeline_config): """Orchestrator creates queues with correct size limits.""" orch = PipelineOrchestrator(pipeline_config) diff --git a/tests/runtime/test_processor_observability.py b/tests/runtime/test_processor_observability.py new file mode 100644 index 0000000..e14c032 --- /dev/null +++ b/tests/runtime/test_processor_observability.py @@ -0,0 +1,180 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""The processor injects the provider, binds scan context, and emits scan metrics. + +Behaviour under test: processing a file opens a "scan" span, increments +files_processed_total, records a scan_processing_time observation, and the +scan_id is visible in the bound context while the executors run. +""" + +import logging +import queue +import random +from datetime import UTC, datetime + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from adapt.runtime.observability import ObsSettings, build_observability +from adapt.runtime.processor import RadarProcessor + +pytestmark = [pytest.mark.unit, pytest.mark.pipeline] + + +def _fake_ds(): + return xr.Dataset( + { + "reflectivity": (("y", "x"), np.ones((4, 4))), + "cell_labels": (("y", "x"), np.zeros((4, 4), dtype=int)), + }, + coords={"x": np.arange(4), "y": np.arange(4)}, + attrs={"z_level_m": 2000}, + ) + + +def _obs(): + return build_observability( + ObsSettings(), + clock=lambda: 0.0, + wall_clock=lambda: datetime(2026, 1, 1, tzinfo=UTC), + rng=random.Random(0), + ) + + +def test_processor_emits_scan_metrics_and_binds_scan_id( + monkeypatch, pipeline_config, pipeline_output_dirs, test_repository +): + obs = _obs() + proc = RadarProcessor( + queue.Queue(), + pipeline_config, + pipeline_output_dirs, + repository=test_repository, + observability=obs, + ) + + seen_scan_ids: list[str] = [] + scan_times = [ + datetime(2024, 5, 18, 12, 0, 0, tzinfo=UTC), + datetime(2024, 5, 18, 12, 5, 0, tzinfo=UTC), + ] + + def _fake_single(context): + seen_scan_ids.append(obs.current().scan_id) # context bound by process_file + return { + "grid_ds": _fake_ds(), + "grid_ds_2d": _fake_ds(), + "segmented_ds": _fake_ds(), + "scan_time": scan_times.pop(0), + "num_cells": 0, + } + + fake_multi = { + "projected_ds": _fake_ds(), + "cell_stats": pd.DataFrame(), + "cell_adjacency": pd.DataFrame(), + } + monkeypatch.setattr(proc._executors[1], "run", _fake_single) + monkeypatch.setattr(proc._executors[2], "run", lambda ctx: fake_multi) + monkeypatch.setattr(proc._router, "persist", lambda modules, result, meta: None) + + assert proc.process_file("/fake/file_1") is True + assert proc.process_file("/fake/file_2") is True + + assert obs.metrics.counter_total("files_processed_total") == 2.0 + assert len(obs.metrics.histogram_values("scan_processing_time")) == 2 + assert seen_scan_ids == ["file_1", "file_2"] # scan_id bound while executors ran + assert obs.drain_spans() == [] # processor drained each scan's spans for history + + +def test_processor_logs_single_enriched_traceback_on_scan_failure( + monkeypatch, caplog, pipeline_config, pipeline_output_dirs, test_repository +): + """A failing scan must log exactly one stack trace, carrying scan/elapsed/type. + + The bare "Error processing " line gave no stage, scan, elapsed, or + exception type and risked a second trace if the error re-bubbled. The handler + must report the failure once, with enough context to act on without log-diving. + """ + obs = _obs() + proc = RadarProcessor( + queue.Queue(), + pipeline_config, + pipeline_output_dirs, + repository=test_repository, + observability=obs, + ) + + def _boom(context): + raise ValueError("kaboom") + + monkeypatch.setattr(proc._executors[1], "run", _boom) + + with caplog.at_level(logging.ERROR): + result = proc.process_file("/fake/file_9") + + assert result is False + traced = [r for r in caplog.records if r.exc_info] + assert len(traced) == 1 # one trace per failure, never duplicated + rec = traced[0] + assert "file_9" in rec.getMessage() + assert "ValueError" in rec.getMessage() + assert rec.error_type == "ValueError" + assert isinstance(rec.elapsed_s, float) + + +def test_processor_emits_per_scan_progress_from_spans( + monkeypatch, pipeline_config, pipeline_output_dirs, test_repository +): + """After each scan the processor hands the captured module spans to the reporter, + so the console progress line is driven by telemetry, not module-level prints. + """ + obs = _obs() + + calls: list[tuple] = [] + + class _Reporter: + def scan(self, scan_id, spans, n_cells): + calls.append((scan_id, [s.name for s in spans], n_cells)) + + proc = RadarProcessor( + queue.Queue(), + pipeline_config, + pipeline_output_dirs, + repository=test_repository, + observability=obs, + reporter=_Reporter(), + ) + + def _fake_single(context): + # Mimic GraphExecutor opening one span per module so the drained telemetry + # carries real stage records for the progress line. + for stage in ("ingest", "detection"): + with obs.span(stage): + pass + return { + "grid_ds": _fake_ds(), + "grid_ds_2d": _fake_ds(), + "segmented_ds": _fake_ds(), + "scan_time": datetime(2024, 5, 18, 12, 0, 0, tzinfo=UTC), + "num_cells": 0, + } + + monkeypatch.setattr(proc._executors[1], "run", _fake_single) + fake_multi = { + "projected_ds": _fake_ds(), + "cell_stats": pd.DataFrame(), + "cell_adjacency": pd.DataFrame(), + } + monkeypatch.setattr(proc._executors[2], "run", lambda ctx: fake_multi) + monkeypatch.setattr(proc._router, "persist", lambda modules, result, meta: None) + + assert proc.process_file("/fake/file_7") is True + + assert len(calls) == 1 + scan_id, stage_names, _ = calls[0] + assert scan_id == "file_7" + assert stage_names # the per-scan line carries the executed stages diff --git a/tests/runtime/test_provenance.py b/tests/runtime/test_provenance.py new file mode 100644 index 0000000..8204d9c --- /dev/null +++ b/tests/runtime/test_provenance.py @@ -0,0 +1,25 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Run provenance capture + config hashing for reproducibility.""" + +from adapt.runtime.provenance import capture_provenance, config_hash + + +def test_capture_provenance_populates_environment_fields(): + p = capture_provenance() + assert p.hostname + assert p.username + assert "." in p.python_version + assert p.platform + assert p.software_version + # git_commit is a real hex string OR None (faithful — never fabricated) + assert p.git_commit is None or ( + len(p.git_commit) >= 7 and all(c in "0123456789abcdef" for c in p.git_commit) + ) + + +def test_config_hash_is_stable_and_sensitive_to_change(): + assert config_hash('{"a": 1}') == config_hash('{"a": 1}') + assert config_hash('{"a": 1}') != config_hash('{"a": 2}') + assert len(config_hash("anything")) == 64 # sha256 hex digest diff --git a/tests/runtime/test_run_reporter.py b/tests/runtime/test_run_reporter.py new file mode 100644 index 0000000..5df18e1 --- /dev/null +++ b/tests/runtime/test_run_reporter.py @@ -0,0 +1,147 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""RunReporter: the quiet console's run header, progress line, and summary. + +Behaviour under test: the header/summary contain the run's identity and real +stats (durations formatted, thousands-separated counts, slowest stage), and the +reporter emits console-tagged records so the ConsoleFilter lets them through. +""" + +import logging +from datetime import UTC, datetime + +from adapt.contracts.execution_history import RunProvenance, RunStart, RunSummary +from adapt.runtime.run_reporter import RunReporter, format_duration, format_header, format_summary + + +def _start(): + return RunStart( + run_id="2026JUN28-0206-KDIX", + pipeline="nexrad", + pipeline_version="0.4.1", + site="KDIX", + dataset="KDIX", + instrument="NEXRAD", + mode="historical", + start_time=datetime(2026, 6, 28, 0, 0, tzinfo=UTC), + configuration_hash="9f2cdeadbeef", + configuration_file="cfg.yaml", + provenance=RunProvenance("abc1234def", "host", "user", "3.11.0", "linux", "0.4.1"), + enabled_modules=("ingest", "detection", "tracking"), + ) + + +def _summary(): + return RunSummary( + run_id="2026JUN28-0206-KDIX", + status="success", + end_time=datetime(2026, 6, 28, 0, 14, 22, tzinfo=UTC), + duration_seconds=862.0, + files_processed=842, + scans_processed=842, + objects_detected=18431, + warnings=7, + errors=0, + average_scan_time=1.03, + maximum_scan_time=3.92, + slowest_stages=(("detection", 371.0), ("tracking", 182.0)), + ) + + +def test_format_duration_is_human_compact(): + assert format_duration(862.0) == "14m22s" + assert format_duration(45.0) == "45s" + assert format_duration(3725.0) == "1h2m5s" + + +def test_header_shows_identity_mode_version_commit_modules(): + out = format_header(_start()) + assert "2026JUN28-0206-KDIX" in out + assert "KDIX" in out + assert "historical" in out + assert "nexrad" in out and "0.4.1" in out + assert "abc1234" in out # short commit + assert "ingest" in out and "detection" in out and "tracking" in out + + +def test_summary_shows_status_counts_and_slowest_stage(): + out = format_summary(_summary()) + assert "SUCCESS" in out + assert "14m22s" in out + assert "842" in out + assert "18,431" in out # thousands separator + assert "1.03" in out and "3.92" in out + assert "detection" in out + + +def test_reporter_emits_console_tagged_records(): + captured: list[logging.LogRecord] = [] + + class _Cap(logging.Handler): + def emit(self, record): + captured.append(record) + + log = logging.getLogger("adapt.test.run_reporter") + log.handlers[:] = [_Cap()] + log.setLevel(logging.DEBUG) + log.propagate = False + + reporter = RunReporter(logger=log) + reporter.progress("412/842 scans") + assert captured[0].console is True + assert "412/842 scans" in captured[0].getMessage() + + +def _span(name, duration_s): + from adapt.contracts.observability import SpanRecord + + return SpanRecord( + name=name, + trace_id="t", + span_id="s", + parent_span_id="p", + start=0.0, + finish=duration_s, + duration_s=duration_s, + error="", + metadata={}, + ) + + +def test_format_scan_shows_each_stage_with_timing_cells_and_total(): + """The per-scan progress line is built from the drained module spans — stage + names with real per-stage durations, the cell count, and the total — so the + console tells you what ran and how long without any module-level print(). + """ + spans = [_span("ingest", 1.2), _span("segmentation", 0.8), _span("detection", 0.4)] + from adapt.runtime.run_reporter import format_scan + + line = format_scan("KOHX_180851", spans, n_cells=5) + + assert "KOHX_180851" in line + assert "ingest 1.2s" in line + assert "segmentation 0.8s" in line + assert "detection 0.4s" in line + assert "5 cells" in line + assert "2.4s" in line # total of the stage durations + + +def test_reporter_scan_emits_one_console_tagged_line(): + from adapt.runtime.run_reporter import RunReporter + + captured: list[logging.LogRecord] = [] + + class _Cap(logging.Handler): + def emit(self, record): + captured.append(record) + + logger = logging.getLogger("adapt.test.scan") + logger.addHandler(_Cap()) + logger.setLevel(logging.INFO) + + RunReporter(logger).scan("KOHX_180851", [_span("ingest", 1.0)], n_cells=3) + + assert len(captured) == 1 + assert captured[0].console is True + assert "KOHX_180851" in captured[0].getMessage() diff --git a/tests/runtime/test_thirdparty_quiet.py b/tests/runtime/test_thirdparty_quiet.py new file mode 100644 index 0000000..810006a --- /dev/null +++ b/tests/runtime/test_thirdparty_quiet.py @@ -0,0 +1,43 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Third-party libraries must not splatter their own banners to the Adapt console. + +Py-ART prints a citation block on import; it is pulled in transitively by +``nexradaws`` (acquisition -> sources -> runtime), which imports pyart at its own +import time — *before* any Adapt module could set PYART_QUIET. So the suppression +must live in the package root (adapt/__init__), which runs before every submodule. +Verified in clean subprocesses importing the runtime path that triggers nexradaws. +""" + +import os +import subprocess +import sys + +import pytest + +pytestmark = pytest.mark.unit + + +def _import_clean(statement: str) -> subprocess.CompletedProcess: + env = {k: v for k, v in os.environ.items() if k != "PYART_QUIET"} + return subprocess.run( + [sys.executable, "-c", statement], + capture_output=True, + text=True, + env=env, + ) + + +def test_importing_runtime_does_not_print_pyart_citation(): + # adapt.runtime pulls nexradaws -> pyart; the banner must still be suppressed. + result = _import_clean("import adapt.runtime.processor") + assert result.returncode == 0, result.stderr + assert "Py-ART" not in result.stdout + assert "jors.119" not in result.stdout + + +def test_importing_top_level_adapt_sets_pyart_quiet(): + result = _import_clean("import adapt, os; print(os.environ.get('PYART_QUIET'))") + assert result.returncode == 0, result.stderr + assert result.stdout.strip().splitlines()[-1] != "None" diff --git a/tests/test_architecture.py b/tests/test_architecture.py index 14b7d14..8ac8a8c 100644 --- a/tests/test_architecture.py +++ b/tests/test_architecture.py @@ -351,3 +351,26 @@ def test_module_outputs_carry_contracts() -> None: f"\n_UNCONTRACTED_OUTPUTS contains entries that now have contracts: " f"{sorted(stale)}. Remove them so the ratchet only tightens." ) + + +# ── Telemetry ids stay out of the science context dict ──────────────────────── +# Observability correlation ids (trace/span/scan/pipeline/...) travel out-of-band +# in contextvars. If one ever appeared as a module input/output key it would couple +# telemetry to the science data contract and could perturb determinism. Pin it. + + +def test_obs_context_fields_never_in_module_io() -> None: + """No ObsContext field name may be a module input/output key.""" + from adapt.contracts.observability import ObsContext + + obs_fields = set(ObsContext.__dataclass_fields__) + leaks: list[str] = [] + for module in _registered_default_modules(): + for key in list(module.inputs) + list(module.outputs): + if key in obs_fields: + leaks.append(f" {module.name}: '{key}'") + + assert not leaks, ( + "\nTelemetry correlation ids leaked into the science context dict — keep " + "ObsContext fields out of module inputs/outputs:\n" + "\n".join(leaks) + ) From 41199898f1fa3c721e67ddda32c1019974ea507d Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Sun, 28 Jun 2026 17:05:51 -0500 Subject: [PATCH 6/8] ENH: clean, live status updates --- .pre-commit-config.yaml | 9 ++ docs/api/client.rst | 2 +- src/adapt/api/client.py | 6 +- src/adapt/api/domain.py | 2 +- src/adapt/cli.py | 25 ++-- src/adapt/consumers/live/__init__.py | 2 +- src/adapt/consumers/live/dashboard.py | 2 +- src/adapt/contracts/execution_history.py | 3 + src/adapt/contracts/observability.py | 25 +++- src/adapt/modules/cell_volume_stats/module.py | 2 +- src/adapt/modules/xlma_stat/geo.py | 4 +- src/adapt/runtime/console_status.py | 76 ++++++++++ src/adapt/runtime/history_handler.py | 8 +- src/adapt/runtime/logging_setup.py | 46 +++++- src/adapt/runtime/observability.py | 23 ++- src/adapt/runtime/orchestrator.py | 128 +++++++++++++---- src/adapt/runtime/processor.py | 58 +++++--- src/adapt/runtime/run_reporter.py | 54 +++++-- tests/contracts/test_observability.py | 2 - tests/modules/xlma_stat/test_geo.py | 2 +- tests/runtime/test_console_status.py | 136 ++++++++++++++++++ tests/runtime/test_observability.py | 10 ++ tests/runtime/test_orchestrator.py | 79 +++++++++- tests/runtime/test_processor_observability.py | 44 ++++++ tests/runtime/test_run_reporter.py | 45 +++++- tests/test_architecture.py | 15 ++ 26 files changed, 714 insertions(+), 94 deletions(-) create mode 100644 src/adapt/runtime/console_status.py create mode 100644 tests/runtime/test_console_status.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 497fdde..a1fab2b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,15 @@ repos: - repo: local hooks: + - id: forbid-adapt-allcaps + name: forbid all-caps product name (use "Adapt") + language: pygrep + entry: '\bADAPT\b' + types: [text] + # Source + docs only; the enforcement tests and this config must name the + # literal to forbid it, so they are intentionally out of scope. + files: ^(src/adapt/|docs/) + - id: mypy name: mypy entry: mypy diff --git a/docs/api/client.rst b/docs/api/client.rst index 4b7ef58..b6ff16a 100644 --- a/docs/api/client.rst +++ b/docs/api/client.rst @@ -1,7 +1,7 @@ Repository Client API ===================== -Read-only interface for querying ADAPT pipeline output from a repository. +Read-only interface for querying Adapt pipeline output from a repository. Initialise :class:`~adapt.api.RepositoryClient` with the repository root path; it auto-discovers runs, radars, scans, and data items through the two-tier database system (root-level registry + per-radar catalog). diff --git a/src/adapt/api/client.py b/src/adapt/api/client.py index bb0a31c..c8520d3 100644 --- a/src/adapt/api/client.py +++ b/src/adapt/api/client.py @@ -1,7 +1,7 @@ # Copyright © 2026, UChicago Argonne, LLC # See LICENSE for terms and disclaimer. -"""RepositoryClient — read-only access to an ADAPT repository. +"""RepositoryClient — read-only access to an Adapt repository. Discovers data through the two-tier database system: - Root-level registry (adapt_registry.db): runs and radars. @@ -53,7 +53,7 @@ class RepositoryClient: - """Read-only interface for an ADAPT repository. + """Read-only interface for an Adapt repository. Thread-safe for notebook usage. Discovers all data through catalog databases — no filesystem inspection. @@ -61,7 +61,7 @@ class RepositoryClient: Parameters ---------- repository_root : str or Path - Root directory of the ADAPT repository. + Root directory of the Adapt repository. """ def __init__(self, repository_root: str | Path) -> None: diff --git a/src/adapt/api/domain.py b/src/adapt/api/domain.py index 42cfaa8..c35537a 100644 --- a/src/adapt/api/domain.py +++ b/src/adapt/api/domain.py @@ -1,7 +1,7 @@ # Copyright © 2026, UChicago Argonne, LLC # See LICENSE for terms and disclaimer. -"""First-class domain objects for the ADAPT repository API.""" +"""First-class domain objects for the Adapt repository API.""" from __future__ import annotations diff --git a/src/adapt/cli.py b/src/adapt/cli.py index e718783..9c7a3c6 100644 --- a/src/adapt/cli.py +++ b/src/adapt/cli.py @@ -148,6 +148,14 @@ def _build_run_nexrad_parser(sub: argparse.ArgumentParser) -> None: def _run_nexrad(args: argparse.Namespace) -> None: """Execute the NEXRAD processing pipeline.""" + from adapt import __version__ + from adapt.runtime.run_reporter import format_banner + + # Plain banner, first line of output — printed before logging is configured so it + # carries no log prefix and precedes any library/catalog chatter. Trailing blank + # line separates it cleanly from the run logs that follow. + print(format_banner(__version__) + "\n") + if getattr(args, "only_modules", None) and getattr(args, "exclude_modules", None): raise SystemExit("error: --only and --not are mutually exclusive") _check_single_instance() @@ -164,12 +172,6 @@ def _run_nexrad(args: argparse.Namespace) -> None: stop_event = threading.Event() - def _safe_stop(orch: PipelineOrchestrator) -> None: - try: - orch.stop() - except Exception as exc: - print(f"[adapt] Stop cleanup error (ignored): {exc}") - def _run_orchestrator( orch: PipelineOrchestrator, max_runtime: int, done: threading.Event ) -> None: @@ -182,7 +184,9 @@ def _handle_sigterm(signum, frame) -> None: print("\n[adapt] SIGTERM received — stopping pipeline...") orchestrator._interrupted = True stop_event.set() - threading.Thread(target=_safe_stop, args=(orchestrator,), daemon=True).start() + # Ask the orchestrator's own (joined) thread to stop; it runs finalize + the + # run summary in its start() finally, so the summary completes before exit. + orchestrator.request_stop() signal.signal(signal.SIGTERM, _handle_sigterm) @@ -227,9 +231,10 @@ def _handle_sigterm(signum, frame) -> None: # Mark interrupted so the run is finalised as "cancelled" not "completed". orchestrator._interrupted = True stop_event.set() - # The orchestrator runs in a worker thread and never receives - # KeyboardInterrupt; set its stop flag explicitly. - threading.Thread(target=_safe_stop, args=(orchestrator,), daemon=True).start() + # The orchestrator runs in a worker thread and never receives KeyboardInterrupt. + # Ask it to break its loop; its own start() finally then runs stop() + the run + # summary on this (non-daemon, joined) thread — so the summary always prints. + orchestrator.request_stop() try: orchestrator_thread.join(timeout=20) except KeyboardInterrupt: diff --git a/src/adapt/consumers/live/__init__.py b/src/adapt/consumers/live/__init__.py index b841b99..b73676f 100644 --- a/src/adapt/consumers/live/__init__.py +++ b/src/adapt/consumers/live/__init__.py @@ -1,7 +1,7 @@ # Copyright © 2026, UChicago Argonne, LLC # See LICENSE for terms and disclaimer. -"""ADAPT Live — operational scan viewer (Tkinter dashboard). +"""Adapt Live — operational scan viewer (Tkinter dashboard). Usage:: diff --git a/src/adapt/consumers/live/dashboard.py b/src/adapt/consumers/live/dashboard.py index 639203a..c1ffa3c 100644 --- a/src/adapt/consumers/live/dashboard.py +++ b/src/adapt/consumers/live/dashboard.py @@ -681,7 +681,7 @@ def _close_plot_settings(self, win) -> None: self._plot_settings_win = None win.destroy() - # ── Run ADAPT wizard ────────────────────────────────────────────────────── + # ── Run Adapt wizard ────────────────────────────────────────────────────── def _open_run_wizard(self) -> None: import webbrowser diff --git a/src/adapt/contracts/execution_history.py b/src/adapt/contracts/execution_history.py index 856c1cc..c586c14 100644 --- a/src/adapt/contracts/execution_history.py +++ b/src/adapt/contracts/execution_history.py @@ -60,6 +60,9 @@ class RunSummary: average_scan_time: float maximum_scan_time: float slowest_stages: tuple[tuple[str, float], ...] # (module, total_seconds) desc + # Per-module aggregates for the console summary: (module, calls, total_seconds) desc. + module_stats: tuple[tuple[str, int, float], ...] = () + failures: int = 0 # module execution failures (errors_total counter) @dataclass(frozen=True, slots=True) diff --git a/src/adapt/contracts/observability.py b/src/adapt/contracts/observability.py index c3111ff..7c26e23 100644 --- a/src/adapt/contracts/observability.py +++ b/src/adapt/contracts/observability.py @@ -51,11 +51,32 @@ class SpanRecord: metadata: dict[str, str] = field(default_factory=dict) +@runtime_checkable +class Metrics(Protocol): + """In-process metrics sink + read-back accessors used to build the run summary.""" + + def incr(self, name: str, value: float = ..., **labels: str) -> None: ... + + def gauge(self, name: str, value: float, **labels: str) -> None: ... + + def observe(self, name: str, value: float, **labels: str) -> None: ... + + def counter_total(self, name: str) -> float: ... + + def gauge_value(self, name: str, **labels: str) -> float | None: ... + + def histogram_values(self, name: str) -> list[float]: ... + + def histogram_totals_by_label(self, name: str, label: str) -> dict[str, float]: ... + + def histogram_counts_by_label(self, name: str, label: str) -> dict[str, int]: ... + + @runtime_checkable class Observability(Protocol): """The injected observability provider. Real or disabled, same interface.""" - metrics: object + metrics: Metrics def span(self, name: str, **ctx: object) -> AbstractContextManager: ... @@ -64,5 +85,3 @@ def bind(self, **ctx: str) -> AbstractContextManager: ... def current(self) -> ObsContext: ... def drain_spans(self) -> list[SpanRecord]: ... - - def install_logging(self, log_path: object) -> None: ... diff --git a/src/adapt/modules/cell_volume_stats/module.py b/src/adapt/modules/cell_volume_stats/module.py index 29e7076..0d86676 100644 --- a/src/adapt/modules/cell_volume_stats/module.py +++ b/src/adapt/modules/cell_volume_stats/module.py @@ -3,7 +3,7 @@ """3D cell statistics — pure scientific functions + CellVolumeStatsAlgorithm. -No I/O, no ADAPT engine imports. Operates on numpy arrays extracted from the 3D +No I/O, no Adapt engine imports. Operates on numpy arrays extracted from the 3D gridded volume. A cell's 3D volume is its 2D detection footprint extruded through all altitude levels; every per-pixel column ("profile") is analysed for echo structure, and aggregates are reduced over the footprint. diff --git a/src/adapt/modules/xlma_stat/geo.py b/src/adapt/modules/xlma_stat/geo.py index 304bdf2..8224919 100644 --- a/src/adapt/modules/xlma_stat/geo.py +++ b/src/adapt/modules/xlma_stat/geo.py @@ -1,9 +1,9 @@ # Copyright © 2026, UChicago Argonne, LLC # See LICENSE for terms and disclaimer. -"""Project lightning lon/lat onto ADAPT's projected grid. +"""Project lightning lon/lat onto Adapt's projected grid. -ADAPT grids are azimuthal-equidistant in metres centred on the radar +Adapt grids are azimuthal-equidistant in metres centred on the radar (``+proj=aeqd +lat_0= +lon_0= +units=m``), so flash coordinates must use the same projection to align with the cell-mask grid. """ diff --git a/src/adapt/runtime/console_status.py b/src/adapt/runtime/console_status.py new file mode 100644 index 0000000..14501ca --- /dev/null +++ b/src/adapt/runtime/console_status.py @@ -0,0 +1,76 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""A single transient console status line that overwrites itself in place. + +Shows what the pipeline is doing between the permanent per-scan log lines — +``⠹ processing · 12s`` / ``⠋ waiting for next scan · 1m4s`` — animated by +the orchestrator's monitor loop and erased before each permanent console line prints +(via StatusAwareStreamHandler in logging_setup). TTY only: on a non-tty stream, or +when disabled, every method is a no-op so redirected output and file logs stay clean. + +Stdlib only — no tqdm/rich. The elapsed clock is injected (UI only; not science). +""" + +from __future__ import annotations + +import sys +import threading +from collections.abc import Callable +from typing import TextIO + +from adapt.runtime.run_reporter import format_seconds + +__all__ = ["ConsoleStatus"] + +_SPINNER = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" +_ERASE = "\r\x1b[K" # carriage return + ANSI erase-to-end-of-line + + +class ConsoleStatus: + """Thread-safe manager for one self-overwriting status line on a TTY stream.""" + + def __init__( + self, stream: TextIO | None = None, *, enabled: bool, clock: Callable[[], float] + ) -> None: + self.stream: TextIO = stream if stream is not None else sys.stderr + self._clock = clock + # Active only on a real terminal — otherwise the carriage returns would + # corrupt piped output and log files. + self._active = enabled and bool(getattr(self.stream, "isatty", lambda: False)()) + self.lock = threading.RLock() # re-entrant: the log handler clears while holding it + self._text = "" + self._t0 = clock() + self._frame = 0 + self._drawn = False + + def set(self, text: str) -> None: + """Set the current activity text; reset the elapsed timer only if it changed.""" + if not self._active: + return + with self.lock: + if text != self._text: + self._text = text + self._t0 = self._clock() + + def tick(self) -> None: + """Redraw the line in place, advancing the spinner and the elapsed counter.""" + if not self._active: + return + with self.lock: + frame = _SPINNER[self._frame % len(_SPINNER)] + self._frame += 1 + elapsed = format_seconds(self._clock() - self._t0) + self.stream.write(f"\r{frame} {self._text} · {elapsed}\x1b[K") + self.stream.flush() + self._drawn = True + + def clear(self) -> None: + """Erase the transient line if one is drawn (before a permanent line, or on stop).""" + if not self._active: + return + with self.lock: + if self._drawn: + self.stream.write(_ERASE) + self.stream.flush() + self._drawn = False diff --git a/src/adapt/runtime/history_handler.py b/src/adapt/runtime/history_handler.py index 1c7083c..c56be64 100644 --- a/src/adapt/runtime/history_handler.py +++ b/src/adapt/runtime/history_handler.py @@ -38,7 +38,7 @@ def emit(self, record: logging.LogRecord) -> None: exc_type = "" if record.exc_info and record.exc_info[0] is not None: exc_type = record.exc_info[0].__name__ - event = ErrorEvent( + error = ErrorEvent( scan_id=ctx.scan_id, module=ctx.stage, exception_type=exc_type, @@ -48,9 +48,9 @@ def emit(self, record: logging.LogRecord) -> None: timestamp=now, ) with self._lock: - self._errors.append(event) + self._errors.append(error) elif record.levelno >= logging.WARNING: - event = WarningEvent( + warning = WarningEvent( scan_id=ctx.scan_id, module=ctx.stage, category=getattr(record, "category", "general"), @@ -59,7 +59,7 @@ def emit(self, record: logging.LogRecord) -> None: timestamp=now, ) with self._lock: - self._warnings.append(event) + self._warnings.append(warning) def drain(self) -> tuple[list[WarningEvent], list[ErrorEvent]]: """Return buffered warnings + errors, then clear (called per scan / at stop).""" diff --git a/src/adapt/runtime/logging_setup.py b/src/adapt/runtime/logging_setup.py index 9d6f6b8..7670bdc 100644 --- a/src/adapt/runtime/logging_setup.py +++ b/src/adapt/runtime/logging_setup.py @@ -16,10 +16,20 @@ import json import logging from pathlib import Path +from typing import TYPE_CHECKING from adapt.runtime.observability import ObsSettings, current_context -__all__ = ["ContextFilter", "JsonFormatter", "ConsoleFilter", "configure_logging"] +if TYPE_CHECKING: + from adapt.runtime.console_status import ConsoleStatus + +__all__ = [ + "ContextFilter", + "JsonFormatter", + "ConsoleFilter", + "StatusAwareStreamHandler", + "configure_logging", +] _CONSOLE_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" # Standard LogRecord attributes to exclude when surfacing structured extras. @@ -82,7 +92,29 @@ def filter(self, record: logging.LogRecord) -> bool: return record.levelno >= self._level or bool(getattr(record, "console", False)) -def configure_logging(settings: ObsSettings, log_path: Path | None) -> None: +class StatusAwareStreamHandler(logging.StreamHandler): + """Console handler that erases the transient status line before each record. + + Holds the ConsoleStatus lock across clear + emit so the orchestrator's ticker + can't redraw the status line in between; the ticker repaints on its next tick. + """ + + def __init__(self, status: ConsoleStatus) -> None: + super().__init__(status.stream) + self._status = status + + def emit(self, record: logging.LogRecord) -> None: + with self._status.lock: + self._status.clear() + super().emit(record) + + +def configure_logging( + settings: ObsSettings, + log_path: Path | None, + *, + console_status: ConsoleStatus | None = None, +) -> None: """Configure the root logger. The one place handlers are constructed. Fails loudly: ``json_logs`` with no ``log_path`` raises (no silent default). @@ -107,9 +139,15 @@ def configure_logging(settings: ObsSettings, log_path: Path | None) -> None: if settings.console_logs: # No handler-level gate: ConsoleFilter does all gating (see its docstring), so - # console-tagged INFO lines survive a WARNING console threshold. + # console-tagged INFO lines survive a WARNING console threshold. With a + # console_status, use the status-aware handler so each permanent line erases + # the transient spinner line first. console_threshold = getattr(logging, settings.console_level.upper(), logging.WARNING) - console = logging.StreamHandler() + console: logging.StreamHandler = ( + StatusAwareStreamHandler(console_status) + if console_status is not None + else logging.StreamHandler() + ) console.setFormatter(logging.Formatter(_CONSOLE_FORMAT)) console.addFilter(context_filter) console.addFilter(ConsoleFilter(console_threshold)) diff --git a/src/adapt/runtime/observability.py b/src/adapt/runtime/observability.py index 3322d05..8d13c9b 100644 --- a/src/adapt/runtime/observability.py +++ b/src/adapt/runtime/observability.py @@ -20,8 +20,9 @@ from dataclasses import dataclass, replace from datetime import UTC, datetime from random import Random +from typing import Literal -from adapt.contracts.observability import ObsContext, SpanRecord +from adapt.contracts.observability import Metrics, ObsContext, SpanRecord __all__ = [ "ObsSettings", @@ -96,7 +97,9 @@ def __init__( self._rng = rng self._enabled = settings.enabled self._spans: list[SpanRecord] = [] - self.metrics = _Meter(enabled=settings.enabled and settings.metrics) + # Typed as the Protocol (not _Meter) so the provider satisfies the invariant + # ``metrics: Metrics`` attribute of contracts.observability.Observability. + self.metrics: Metrics = _Meter(enabled=settings.enabled and settings.metrics) # ── context propagation ─────────────────────────────────────────────────── def current(self) -> ObsContext: @@ -173,7 +176,7 @@ def __enter__(self) -> _Span: ) return self - def __exit__(self, exc_type, exc, tb) -> bool: + def __exit__(self, exc_type, exc, tb) -> Literal[False]: finish = self._obs._clock() if exc is not None and not self._error: self._error = f"{exc_type.__name__}: {exc}" @@ -210,7 +213,7 @@ def record_error(self, exc: BaseException) -> None: ... def __enter__(self) -> _NullSpan: return self - def __exit__(self, exc_type, exc, tb) -> bool: + def __exit__(self, exc_type, exc, tb) -> Literal[False]: return False @@ -281,6 +284,18 @@ def histogram_totals_by_label(self, name: str, label: str) -> dict[str, float]: out[key] = out.get(key, 0.0) + sum(values) return out + def histogram_counts_by_label(self, name: str, label: str) -> dict[str, int]: + with self._lock: + out: dict[str, int] = {} + for (n, lk), values in self._hist.items(): + if n != name: + continue + key = dict(lk).get(label) + if key is None: + continue + out[key] = out.get(key, 0) + len(values) + return out + def build_observability( settings: ObsSettings, diff --git a/src/adapt/runtime/orchestrator.py b/src/adapt/runtime/orchestrator.py index f5a5a7e..7a51f18 100644 --- a/src/adapt/runtime/orchestrator.py +++ b/src/adapt/runtime/orchestrator.py @@ -14,13 +14,16 @@ import logging import queue import random +import sys import time +from contextlib import AbstractContextManager from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING from adapt.contracts.execution_history import RunStart, RunSummary from adapt.persistence import DataRepository +from adapt.runtime.console_status import ConsoleStatus from adapt.runtime.file_tracker import FileProcessingTracker from adapt.runtime.history_handler import HistoryLogHandler from adapt.runtime.logging_setup import configure_logging @@ -39,6 +42,13 @@ logger = logging.getLogger(__name__) +# Span names that wrap modules rather than being modules themselves; excluded from +# the per-module summary breakdown (they share the module_duration_seconds histogram). +_WRAPPER_SPANS = frozenset({"pipeline", "scan"}) + +# Monitor-loop cadence; also the live status-line tick (spinner + elapsed) interval. +_STATUS_TICK_SECONDS = 0.5 + class PipelineOrchestrator: """Manages the multi-threaded radar processing pipeline. @@ -145,6 +155,7 @@ def __init__( # Lifecycle state self._stop_event = False + self._stop_requested = False # external break request (separate from the stop() guard) self._interrupted = False # Track user interrupt (Ctrl+C) vs normal completion self._start_time: float | None = None self._max_duration: float | None = None @@ -152,10 +163,11 @@ def __init__( # Telemetry (built in start()); kept here so stop() is safe before start(). self._obs: Observability | None = None - self._root_span = None + self._root_span: AbstractContextManager | None = None self._root_trace_id = "" self._history_handler: HistoryLogHandler | None = None self._reporter: RunReporter | None = None + self._status: ConsoleStatus | None = None def _obs_settings(self) -> ObsSettings: """Translate the resolved logging/observability config into ObsSettings.""" @@ -184,10 +196,21 @@ def _build_observability(self) -> "Observability": rng=random.Random(), ) + def _enabled_module_names(self) -> tuple[str, ...]: + """All enabled module names — in-pipeline AND post-persistence (phase-3). + + The header is built from this, so phase-3 enrichment modules (e.g. + cell_volume_stats) are surfaced, not just the phase 0-2 pipeline. + """ + if not self.processor: + return () + mods = (*self.processor._pipeline_modules, *self.processor._post_modules) + return tuple(m.name for m in mods) + def _record_run_start(self, radar: str) -> None: """Open the execution-history run record and print the console run header.""" prov = capture_provenance() - modules = tuple(m.name for m in self.processor._pipeline_modules) if self.processor else () + modules = self._enabled_module_names() start = RunStart( run_id=self.run_id or "", pipeline=self.config.source, @@ -207,35 +230,63 @@ def _record_run_start(self, radar: str) -> None: enabled_modules=modules, ) assert self.repository is not None + assert self._reporter is not None self.repository.history.start_run(start) self._reporter.header(start) def _finalize_history(self) -> None: - """Flush captured warnings/errors, finalize the run record, print the summary.""" + """Print the run summary, then persist history. + + The console summary is built from in-memory telemetry + the drained handler + counts and printed FIRST, so it never depends on a database write succeeding + (a DB failure during shutdown is logged loudly but cannot hide the summary). + """ if self._obs is None or self.repository is None: return + warnings: list = [] + errors: list = [] if self._history_handler is not None: warnings, errors = self._history_handler.drain() - if warnings: - self.repository.history.record_warnings(self.run_id, warnings) - if errors: - self.repository.history.record_errors(self.run_id, errors) - summary = self._build_run_summary("cancelled" if self._interrupted else "success") - self.repository.history.finalize_run(summary) + + summary = self._build_run_summary( + "cancelled" if self._interrupted else "success", + warnings=len(warnings), + errors=len(errors), + ) if self._reporter is not None: self._reporter.summary(summary) - def _build_run_summary(self, status: str) -> RunSummary: - """Aggregate the end-of-run summary from telemetry metrics + history counts.""" + # Persist after the user-facing summary is out. Failures here are reported + # loudly but must not abort the orderly shutdown. + run_id = self.run_id or "" + try: + if warnings: + self.repository.history.record_warnings(run_id, warnings) + if errors: + self.repository.history.record_errors(run_id, errors) + self.repository.history.finalize_run(summary) + except Exception: + logger.exception("Failed to persist execution history on shutdown") + + def _build_run_summary(self, status: str, *, warnings: int, errors: int) -> RunSummary: + """Aggregate the end-of-run summary from in-memory telemetry metrics. + + warnings/errors are passed in (drained from the history handler) so the + summary never reads the database — keeping it robust during shutdown. + """ + assert self._obs is not None m = self._obs.metrics scan_times = m.histogram_values("scan_processing_time") - slowest = sorted( - m.histogram_totals_by_label("module_duration_seconds", "stage").items(), - key=lambda kv: kv[1], - reverse=True, - )[:3] - warnings = len(self.repository.history.query_warnings(run_id=self.run_id)) - errors = len(self.repository.history.query_errors(run_id=self.run_id)) + totals = m.histogram_totals_by_label("module_duration_seconds", "stage") + counts = m.histogram_counts_by_label("module_duration_seconds", "stage") + # "pipeline" (root) and "scan" are wrapper spans, not modules — exclude them + # from the per-module breakdown so the table shows only real stages. + module_stats = tuple( + (name, counts.get(name, 0), total) + for name, total in sorted(totals.items(), key=lambda kv: kv[1], reverse=True) + if name not in _WRAPPER_SPANS + ) + slowest = tuple((name, total) for name, _calls, total in module_stats[:3]) duration = (time.time() - self._start_time) if self._start_time else 0.0 return RunSummary( run_id=self.run_id or "", @@ -249,7 +300,9 @@ def _build_run_summary(self, status: str) -> RunSummary: errors=errors, average_scan_time=(sum(scan_times) / len(scan_times)) if scan_times else 0.0, maximum_scan_time=max(scan_times) if scan_times else 0.0, - slowest_stages=tuple(slowest), + slowest_stages=slowest, + module_stats=module_stats, + failures=int(m.counter_total("errors_total")), ) def _setup_logging(self): @@ -379,9 +432,17 @@ def start(self, max_runtime: int | None = None): self._root_span.__enter__() self._root_trace_id = self._obs.current().trace_id + # Live status line (spinner) — TTY only; self-disables when piped or quiet. + settings = self._obs_settings() + self._status = ConsoleStatus( + sys.stderr, + enabled=settings.console_logs and sys.stderr.isatty(), + clock=time.monotonic, + ) + # Structured logging (JSON file + quiet console) + capture warnings/errors. log_path = Path(self.output_dirs["logs"]) / f"pipeline_{radar}.log" - configure_logging(self._obs_settings(), log_path) + configure_logging(settings, log_path, console_status=self._status) self._history_handler = HistoryLogHandler() logging.getLogger().addHandler(self._history_handler) self._reporter = RunReporter() @@ -432,8 +493,10 @@ def _main_loop(self, mode: str): last_status_time = time.time() while True: - # 0. Honour external stop() call (e.g. from CLI after SIGTERM/SIGINT) - if self._stop_event: + # 0. Honour an external stop request or a prior stop() (e.g. from the CLI + # after SIGTERM/SIGINT). request_stop() lets the orchestrator's OWN thread + # run finalize+summary, instead of a daemon that the process kills on exit. + if self._stop_requested or self._stop_event: break # 1. Historical completion check (must run before downloader death check) @@ -465,12 +528,19 @@ def _main_loop(self, mode: str): logger.info("Max duration reached") break - # 4. Status logging (every 30s) + # 4. Status logging (every 30s, file only) + live status line (every tick) if time.time() - last_status_time > 30: self._log_status() last_status_time = time.time() - time.sleep(1) + if self._status is not None: + self._status.set(self.processor.current_activity() or "waiting for next scan") + self._status.tick() + + time.sleep(_STATUS_TICK_SECONDS) + + if self._status is not None: + self._status.clear() # leave the terminal clean before stop()/summary def _check_historical_complete(self) -> bool: """Check if historical mode is complete. Returns True to exit.""" @@ -531,6 +601,16 @@ def _drain_queue(self, q: queue.Queue, name: str, timeout: int = 300): ) break + def request_stop(self) -> None: + """Ask the monitoring loop to exit so the orchestrator's own thread finalizes. + + Called from the CLI on SIGINT/SIGTERM. Distinct from ``stop()`` so signalling + the loop never trips ``stop()``'s "already finalized" guard — finalize + + summary then run on the joined (non-daemon) orchestrator thread, not a daemon + the process kills on exit. + """ + self._stop_requested = True + def stop(self): """Stop the pipeline gracefully and finalize all results. diff --git a/src/adapt/runtime/processor.py b/src/adapt/runtime/processor.py index 8908235..17e8657 100644 --- a/src/adapt/runtime/processor.py +++ b/src/adapt/runtime/processor.py @@ -98,10 +98,14 @@ def __init__( self.repository = repository # Telemetry provider, injected by the orchestrator. Absent -> disabled (off), # so the rest of this class calls it unconditionally with no `if obs` branches. - self._obs = observability if observability is not None else disabled_observability() + self._obs: Observability = ( + observability if observability is not None else disabled_observability() + ) self._root_trace_id = root_trace_id # Console reporter (injected). Absent -> no per-scan progress line. self._reporter = reporter + # Observational "what am I doing now" for the live status line (display only). + self._current_scan_id: str | None = None self._stop_event = threading.Event() self.output_lock = threading.Lock() @@ -174,6 +178,11 @@ def stopped(self) -> bool: """True if stop() has been called or a ContractViolation forced stop.""" return self._stop_event.is_set() + def current_activity(self) -> str | None: + """Short 'what am I doing now' for the live status line; None when idle.""" + scan_id = self._current_scan_id + return f"processing {scan_id}" if scan_id else None + def run(self): """Main processor loop (runs in thread). @@ -184,6 +193,7 @@ def run(self): # HDF5 error stacks are thread-local: silence libhdf5's stderr dumps on THIS # worker thread, where all the NetCDF/HDF5 I/O happens. silence_hdf5_errors() + assert self.repository is not None # guaranteed by __init__ logger.info("Processor started, waiting for files...") with self._obs.bind( trace_id=self._root_trace_id, @@ -247,6 +257,7 @@ def process_file(self, filepath) -> bool: bool True if processed or deferred (waiting for pair), False on error. """ + assert self.repository is not None # guaranteed by __init__ queued_at = None if isinstance(filepath, dict): queued_at = filepath.get("queued_at") @@ -263,26 +274,31 @@ def process_file(self, filepath) -> bool: queue_wait_s = (time.time() - queued_at) if queued_at else None logger.info("Processing: %s", Path(filepath).name) - # Bind scan context and open the scan span; module spans nest under it and - # every log on this path carries scan_id. Disabled provider -> no-ops. - with self._obs.bind(scan_id=file_id): - with self._obs.span("scan") as scan_span: - ok = self._run_scan(filepath, file_id, queue_wait_s, scan_span) - # The scan span has closed; split this scan's drained spans into the module - # spans (one module_history batch) and the scan span itself (carries n_cells). - all_spans = self._obs.drain_spans() - modules = [s for s in all_spans if s.name != "scan"] - if modules: - self.repository.history.record_modules( - self.repository.run_id, file_id, modules, recorded_at=datetime.now(UTC) - ) - # One controlled console line per scan, built from the captured telemetry - # (stage timings + cell count) — never printed from inside a module. - if self._reporter is not None and modules: - scan_rec = next((s for s in all_spans if s.name == "scan"), None) - n_cells = int(scan_rec.metadata.get("n_cells", 0)) if scan_rec else 0 - self._reporter.scan(file_id, modules, n_cells) - return ok + # Publish the current activity for the live status line; cleared on every exit. + self._current_scan_id = file_id + try: + # Bind scan context and open the scan span; module spans nest under it and + # every log on this path carries scan_id. Disabled provider -> no-ops. + with self._obs.bind(scan_id=file_id): + with self._obs.span("scan") as scan_span: + ok = self._run_scan(filepath, file_id, queue_wait_s, scan_span) + # The scan span has closed; split this scan's drained spans into the module + # spans (one module_history batch) and the scan span itself (carries n_cells). + all_spans = self._obs.drain_spans() + modules = [s for s in all_spans if s.name != "scan"] + if modules: + self.repository.history.record_modules( + self.repository.run_id, file_id, modules, recorded_at=datetime.now(UTC) + ) + # One controlled console line per scan, built from the captured telemetry + # (stage timings + cell count) — never printed from inside a module. + if self._reporter is not None and modules: + scan_rec = next((s for s in all_spans if s.name == "scan"), None) + n_cells = int(scan_rec.metadata.get("n_cells", 0)) if scan_rec else 0 + self._reporter.scan(file_id, modules, n_cells) + return ok + finally: + self._current_scan_id = None def _run_scan(self, filepath, file_id, queue_wait_s, scan_span) -> bool: """Execute the scientific pipeline for one scan (inside the scan span).""" diff --git a/src/adapt/runtime/run_reporter.py b/src/adapt/runtime/run_reporter.py index e61af3e..4154385 100644 --- a/src/adapt/runtime/run_reporter.py +++ b/src/adapt/runtime/run_reporter.py @@ -14,9 +14,29 @@ from adapt.contracts.execution_history import RunStart, RunSummary from adapt.contracts.observability import SpanRecord -__all__ = ["RunReporter", "format_header", "format_summary", "format_duration", "format_scan"] +__all__ = [ + "RunReporter", + "format_header", + "format_summary", + "format_duration", + "format_seconds", + "format_scan", + "format_banner", +] _RULE = "─" * 60 +_LICENSE_URL = "https://arm-doe.github.io/Adapt/license.html" + + +def format_banner(version: str) -> str: + """The plain name/version/copyright block the CLI prints first, before anything else.""" + return "\n".join( + [ + f"ARM Adapt v{version}", + "Copyright © 2026, UChicago Argonne, LLC.", + f"See the LICENSE ({_LICENSE_URL}) for terms and disclaimer.", + ] + ) def format_duration(seconds: float) -> str: @@ -33,13 +53,24 @@ def format_duration(seconds: float) -> str: return "".join(parts) +def format_seconds(seconds: float) -> str: + """Compact duration that keeps sub-second resolution: ms under 1s (so fast stages + don't collapse to '0s'), tenths under a minute, then h/m/s. + """ + if seconds < 1.0: + return f"{seconds * 1000:.0f}ms" + if seconds < 60.0: + return f"{seconds:.1f}s" + return format_duration(seconds) + + def format_header(start: RunStart) -> str: """One compact block summarising the run at a glance.""" p = start.provenance commit = p.git_commit[:7] if p.git_commit else "—" return "\n".join( [ - f"── ADAPT run {start.run_id} {_RULE[: max(0, 48 - len(start.run_id))]}", + f"── Adapt run {start.run_id} {_RULE[: max(0, 48 - len(start.run_id))]}", f"site {start.site} · {start.instrument} · mode {start.mode}", f"pipeline {start.pipeline} v{start.pipeline_version} · commit {commit} " f"· python {p.python_version} · {p.platform}", @@ -51,18 +82,25 @@ def format_header(start: RunStart) -> str: def format_summary(summary: RunSummary) -> str: - """One compact block of end-of-run stats.""" - slowest = " · ".join(f"{name} {format_duration(secs)}" for name, secs in summary.slowest_stages) + """One compact block of end-of-run stats with a per-module timing table.""" + module_lines = [ + f" {name:<18} {calls:>4} · {format_seconds(total):>7} · " + f"{format_seconds(total / calls if calls else 0.0)}" + for name, calls, total in summary.module_stats + ] + if not module_lines: + module_lines = [" —"] return "\n".join( [ f"── run {summary.run_id} {summary.status.upper()} " f"{format_duration(summary.duration_seconds)} {_RULE[:20]}", f"files {summary.files_processed:,} · scans {summary.scans_processed:,} " f"· objects {summary.objects_detected:,} · warnings {summary.warnings} " - f"· errors {summary.errors}", + f"· errors {summary.errors} · failures {summary.failures}", f"scan time avg {summary.average_scan_time:.2f}s " f"· max {summary.maximum_scan_time:.2f}s", - f"slowest: {slowest}" if slowest else "slowest: —", + "per module (calls · total · avg):", + *module_lines, _RULE, ] ) @@ -74,9 +112,9 @@ def format_scan(scan_id: str, spans: list[SpanRecord], n_cells: int) -> str: Reads ``stage duration`` straight off the captured telemetry — so the console reports what ran and how long without any module emitting its own progress. """ - stages = " ".join(f"{s.name} {s.duration_s:.1f}s" for s in spans) + stages = " ".join(f"{s.name} {format_seconds(s.duration_s)}" for s in spans) total = sum(s.duration_s for s in spans) - return f"scan {scan_id} │ {stages} │ {n_cells} cells {total:.1f}s" + return f"scan {scan_id} │ {stages} │ {n_cells} cells {format_seconds(total)}" class RunReporter: diff --git a/tests/contracts/test_observability.py b/tests/contracts/test_observability.py index 43f2bc7..0b80206 100644 --- a/tests/contracts/test_observability.py +++ b/tests/contracts/test_observability.py @@ -58,8 +58,6 @@ def current(self): ... def drain_spans(self): ... - def install_logging(self, log_path): ... - class _NotImpl: def span(self, name, **ctx): ... diff --git a/tests/modules/xlma_stat/test_geo.py b/tests/modules/xlma_stat/test_geo.py index 4b19c3f..343df46 100644 --- a/tests/modules/xlma_stat/test_geo.py +++ b/tests/modules/xlma_stat/test_geo.py @@ -1,7 +1,7 @@ # Copyright © 2026, UChicago Argonne, LLC # See LICENSE for terms and disclaimer. -"""Flash lon/lat -> projected x/y on ADAPT's azimuthal-equidistant (radar) grid.""" +"""Flash lon/lat -> projected x/y on Adapt's azimuthal-equidistant (radar) grid.""" import numpy as np import pytest diff --git a/tests/runtime/test_console_status.py b/tests/runtime/test_console_status.py new file mode 100644 index 0000000..6f2fc06 --- /dev/null +++ b/tests/runtime/test_console_status.py @@ -0,0 +1,136 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""ConsoleStatus: a single transient TTY line that overwrites itself in place. + +Behaviour under test: on a TTY it draws ``\\r · \\x1b[K``, +advances the spinner each tick, resets the elapsed timer only when the text changes, +and clears with ``\\r\\x1b[K``. On a non-TTY (or when disabled) it is completely inert, +so redirected output and file logs never see control characters. +""" + +import io +import logging + +import pytest + +from adapt.runtime.console_status import ConsoleStatus +from adapt.runtime.logging_setup import StatusAwareStreamHandler + +pytestmark = pytest.mark.unit + + +class _FakeTTY(io.StringIO): + def isatty(self) -> bool: + return True + + +def _status(stream, now): + return ConsoleStatus(stream, enabled=True, clock=lambda: now[0]) + + +def test_tick_draws_overwriting_line_with_spinner_text_and_elapsed(): + out = _FakeTTY() + now = [10.0] + cs = _status(out, now) + + cs.set("processing KOHX_1") + now[0] = 13.0 + cs.tick() + + drawn = out.getvalue() + assert drawn.startswith("\r") # overwrites from line start + assert "processing KOHX_1" in drawn + assert "3.0s" in drawn # elapsed since the text was set + assert "\x1b[K" in drawn # erase-to-end so shorter text doesn't leave residue + + +def test_tick_advances_spinner_frame(): + out = _FakeTTY() + cs = _status(out, [0.0]) + cs.set("waiting for next scan") + cs.tick() + cs.tick() + # Two ticks -> two different leading spinner glyphs. + lines = [seg for seg in out.getvalue().split("\r") if seg] + assert lines[0][0] != lines[1][0] + + +def test_elapsed_resets_only_when_text_changes(): + out = _FakeTTY() + now = [10.0] + cs = _status(out, now) + + cs.set("a") + now[0] = 22.0 + cs.set("a") # same text -> timer NOT reset + cs.tick() + assert "12.0s" in out.getvalue().rsplit("\r", 1)[-1] + + now[0] = 30.0 + cs.set("b") # new text -> timer resets to now + now[0] = 31.0 + cs.tick() + assert "1.0s" in out.getvalue().rsplit("\r", 1)[-1] + + +def test_clear_erases_the_line_and_is_idempotent(): + out = _FakeTTY() + cs = _status(out, [0.0]) + cs.set("x") + cs.tick() + out.truncate(0) + out.seek(0) + + cs.clear() + assert out.getvalue() == "\r\x1b[K" + + out.truncate(0) + out.seek(0) + cs.clear() # nothing drawn now -> no-op + assert out.getvalue() == "" + + +def test_non_tty_stream_is_completely_inert(): + out = io.StringIO() # isatty() is False + cs = ConsoleStatus(out, enabled=True, clock=lambda: 0.0) + cs.set("x") + cs.tick() + cs.clear() + assert out.getvalue() == "" + + +def test_disabled_is_inert_even_on_a_tty(): + out = _FakeTTY() + cs = ConsoleStatus(out, enabled=False, clock=lambda: 0.0) + cs.set("x") + cs.tick() + assert out.getvalue() == "" + + +def test_status_aware_handler_erases_transient_line_before_the_record(): + out = _FakeTTY() + cs = _status(out, [0.0]) + cs.set("processing KOHX_1") + cs.tick() # a transient line is now drawn + out.truncate(0) + out.seek(0) + + handler = StatusAwareStreamHandler(cs) + handler.setFormatter(logging.Formatter("%(message)s")) + handler.emit(logging.LogRecord("adapt.run", logging.INFO, __file__, 1, "scan done", None, None)) + + written = out.getvalue() + assert written.startswith("\r\x1b[K") # erase first + assert "scan done" in written # then the permanent line + + +def test_status_aware_handler_on_non_tty_just_writes_the_record(): + out = io.StringIO() + cs = ConsoleStatus(out, enabled=True, clock=lambda: 0.0) + handler = StatusAwareStreamHandler(cs) + handler.setFormatter(logging.Formatter("%(message)s")) + handler.emit(logging.LogRecord("adapt.run", logging.INFO, __file__, 1, "hello", None, None)) + + assert "\x1b" not in out.getvalue() + assert "hello" in out.getvalue() diff --git a/tests/runtime/test_observability.py b/tests/runtime/test_observability.py index 6c242e0..df34a81 100644 --- a/tests/runtime/test_observability.py +++ b/tests/runtime/test_observability.py @@ -172,6 +172,16 @@ def test_histogram_totals_by_stage_label() -> None: assert totals == {"detection": 3.0, "tracking": 4.0} +def test_histogram_counts_by_stage_label() -> None: + """Per-label call counts (for averaging per-module timing in the run summary).""" + obs = _obs() + obs.metrics.observe("module_duration_seconds", 2.0, stage="detection") + obs.metrics.observe("module_duration_seconds", 1.0, stage="detection") + obs.metrics.observe("module_duration_seconds", 4.0, stage="tracking") + counts = obs.metrics.histogram_counts_by_label("module_duration_seconds", "stage") + assert counts == {"detection": 2, "tracking": 1} + + def test_concurrent_counter_increments_sum_correctly() -> None: obs = _obs() diff --git a/tests/runtime/test_orchestrator.py b/tests/runtime/test_orchestrator.py index dda2d56..39d52ff 100644 --- a/tests/runtime/test_orchestrator.py +++ b/tests/runtime/test_orchestrator.py @@ -24,8 +24,11 @@ def test_orchestrator_build_run_summary_aggregates_metrics_and_history( m.observe("scan_processing_time", 3.0) m.observe("module_duration_seconds", 5.0, stage="detection") m.observe("module_duration_seconds", 2.0, stage="tracking") + # Wrapper spans share the same histogram but are NOT modules — must be excluded. + m.observe("module_duration_seconds", 99.0, stage="pipeline") + m.observe("module_duration_seconds", 25.0, stage="scan") - summary = orch._build_run_summary("success") + summary = orch._build_run_summary("success", warnings=0, errors=0) assert summary.status == "success" assert summary.files_processed == 2 assert summary.objects_detected == 18431 @@ -34,6 +37,44 @@ def test_orchestrator_build_run_summary_aggregates_metrics_and_history( assert summary.maximum_scan_time == 3.0 assert summary.slowest_stages[0] == ("detection", 5.0) assert summary.duration_seconds >= 10.0 + # per-module aggregates: calls + total, sorted by total desc + assert summary.module_stats[0] == ("detection", 1, 5.0) + assert ("tracking", 1, 2.0) in summary.module_stats + # the pipeline/scan wrapper spans are excluded from the module table + names = {name for name, _calls, _total in summary.module_stats} + assert "pipeline" not in names and "scan" not in names + assert summary.slowest_stages[0] == ("detection", 5.0) # not the 99s pipeline span + + +def test_finalize_history_prints_summary_even_when_db_write_fails( + pipeline_config, test_repository, monkeypatch +): + """The console summary must be emitted on shutdown even if the history DB write + raises — it is printed before persistence and must not depend on it. + """ + orch = PipelineOrchestrator(pipeline_config) + orch._obs = orch._build_observability() + orch.repository = test_repository + orch.run_id = test_repository.run_id + orch._start_time = time.time() - 5 + orch._history_handler = None # no buffered warnings/errors + + seen = [] + + class _Rep: + def summary(self, s): + seen.append(s) + + orch._reporter = _Rep() + + def _boom(_summary): + raise RuntimeError("db locked") + + monkeypatch.setattr(test_repository.history, "finalize_run", _boom) + + orch._finalize_history() # must not raise + + assert len(seen) == 1 # summary still printed despite the DB failure def test_orchestrator_initialization(pipeline_config): @@ -149,3 +190,39 @@ def test_orchestrator_processor_config_accessible(pipeline_config): assert hasattr(orch.config, "processor") assert orch.config.processor.max_history >= 0 assert orch.config.processor.min_file_size > 0 + + +def test_enabled_module_names_includes_post_persistence_modules(pipeline_config): + """The run header must list phase-3 (post-persistence) modules like cell_volume_stats, + not only the in-pipeline modules — otherwise the user can't see they're configured. + """ + orch = PipelineOrchestrator(pipeline_config) + + class _M: + def __init__(self, name): + self.name = name + + class _Proc: + _pipeline_modules = [_M("ingest"), _M("detection")] + _post_modules = [_M("cell_volume_stats")] + + orch.processor = _Proc() + + names = orch._enabled_module_names() + + assert names[:2] == ("ingest", "detection") + assert "cell_volume_stats" in names # phase-3 module surfaced + + +def test_request_stop_breaks_main_loop(pipeline_config): + """request_stop() must make the monitoring loop exit so the orchestrator's own + (joined, non-daemon) thread runs stop()/finalize — instead of a daemon that the + process kills mid-summary on shutdown. + """ + orch = PipelineOrchestrator(pipeline_config) + orch.downloader = object() + orch.processor = object() + orch._start_time = time.time() + + orch.request_stop() + orch._main_loop("realtime") # returns immediately; would hang if the break is missing diff --git a/tests/runtime/test_processor_observability.py b/tests/runtime/test_processor_observability.py index e14c032..f271fc8 100644 --- a/tests/runtime/test_processor_observability.py +++ b/tests/runtime/test_processor_observability.py @@ -178,3 +178,47 @@ def _fake_single(context): scan_id, stage_names, _ = calls[0] assert scan_id == "file_7" assert stage_names # the per-scan line carries the executed stages + + +def test_processor_publishes_current_activity( + monkeypatch, pipeline_config, pipeline_output_dirs, test_repository +): + """The processor exposes a short 'what am I doing' string for the status line: + 'processing ' while a scan runs, None when idle. + """ + obs = _obs() + proc = RadarProcessor( + queue.Queue(), + pipeline_config, + pipeline_output_dirs, + repository=test_repository, + observability=obs, + ) + + assert proc.current_activity() is None # idle before any scan + + seen_during = [] + + def _fake_single(context): + seen_during.append(proc.current_activity()) # captured mid-scan + return { + "grid_ds": _fake_ds(), + "grid_ds_2d": _fake_ds(), + "segmented_ds": _fake_ds(), + "scan_time": datetime(2024, 5, 18, 12, 0, 0, tzinfo=UTC), + "num_cells": 0, + } + + fake_multi = { + "projected_ds": _fake_ds(), + "cell_stats": pd.DataFrame(), + "cell_adjacency": pd.DataFrame(), + } + monkeypatch.setattr(proc._executors[1], "run", _fake_single) + monkeypatch.setattr(proc._executors[2], "run", lambda ctx: fake_multi) + monkeypatch.setattr(proc._router, "persist", lambda modules, result, meta: None) + + proc.process_file("/fake/file_3") + + assert seen_during == ["processing file_3"] # set while the scan ran + assert proc.current_activity() is None # cleared afterwards diff --git a/tests/runtime/test_run_reporter.py b/tests/runtime/test_run_reporter.py index 5df18e1..e4b8188 100644 --- a/tests/runtime/test_run_reporter.py +++ b/tests/runtime/test_run_reporter.py @@ -46,6 +46,8 @@ def _summary(): average_scan_time=1.03, maximum_scan_time=3.92, slowest_stages=(("detection", 371.0), ("tracking", 182.0)), + module_stats=(("detection", 842, 371.0), ("tracking", 842, 182.0)), + failures=2, ) @@ -55,8 +57,22 @@ def test_format_duration_is_human_compact(): assert format_duration(3725.0) == "1h2m5s" +def test_format_seconds_keeps_subsecond_resolution(): + """Fast stages (ms) must not collapse to '0s' — keep ms under 1s.""" + from adapt.runtime.run_reporter import format_seconds + + assert format_seconds(0.012) == "12ms" + assert format_seconds(0.5) == "500ms" + assert format_seconds(0.0) == "0ms" + assert format_seconds(1.2) == "1.2s" + assert format_seconds(25.2) == "25.2s" + assert format_seconds(371.0) == "6m11s" + + def test_header_shows_identity_mode_version_commit_modules(): out = format_header(_start()) + assert "Adapt run" in out # product name is "Adapt", never "ADAPT" + assert "ADAPT" not in out assert "2026JUN28-0206-KDIX" in out assert "KDIX" in out assert "historical" in out @@ -75,6 +91,17 @@ def test_summary_shows_status_counts_and_slowest_stage(): assert "detection" in out +def test_summary_shows_per_module_table_and_failures(): + """The summary lists every module with calls, total, and avg, plus a failures count.""" + out = format_summary(_summary()) + assert "failures 2" in out + # per-module block: each module appears with its call count and a total duration + assert "detection" in out and "842" in out + assert "tracking" in out + # avg is total/calls: 371.0/842 ≈ 0.441s -> shown in ms, not collapsed to 0 + assert "441ms" in out + + def test_reporter_emits_console_tagged_records(): captured: list[logging.LogRecord] = [] @@ -121,8 +148,8 @@ def test_format_scan_shows_each_stage_with_timing_cells_and_total(): assert "KOHX_180851" in line assert "ingest 1.2s" in line - assert "segmentation 0.8s" in line - assert "detection 0.4s" in line + assert "segmentation 800ms" in line # sub-second shown in ms + assert "detection 400ms" in line assert "5 cells" in line assert "2.4s" in line # total of the stage durations @@ -145,3 +172,17 @@ def emit(self, record): assert len(captured) == 1 assert captured[0].console is True assert "KOHX_180851" in captured[0].getMessage() + + +def test_format_banner_is_plain_name_version_and_copyright(): + """The startup banner is plain text (printed first by the CLI, before logging is + configured) — no log prefix, no decorative rules. + """ + from adapt.runtime.run_reporter import format_banner + + out = format_banner("0.1.3") + + assert out.splitlines()[0] == "ARM Adapt v0.1.3" # first line, exact + assert "Copyright © 2026, UChicago Argonne, LLC" in out + assert "https://arm-doe.github.io/Adapt/license.html" in out + assert "─" not in out # no decorative rule lines diff --git a/tests/test_architecture.py b/tests/test_architecture.py index 8ac8a8c..52a9e3d 100644 --- a/tests/test_architecture.py +++ b/tests/test_architecture.py @@ -13,6 +13,7 @@ import ast import importlib import pkgutil +import re from pathlib import Path import pytest @@ -144,6 +145,20 @@ def test_scan_time_format_is_defined_in_exactly_one_place() -> None: ) +def test_adapt_name_is_never_all_caps() -> None: + """The product is 'Adapt', never 'ADAPT' — in output, prints, comments, or docstrings. + + Matches the standalone token only, so 'arm_adaptive', 'ADAPTIVE', 'ADAPTER' are fine. + """ + offenders = [ + f"{py_file.relative_to(_SRC_ADAPT)}:{i}" + for py_file in _SRC_ADAPT.rglob("*.py") + for i, line in enumerate(py_file.read_text().splitlines(), 1) + if re.search(r"\bADAPT\b", line) + ] + assert not offenders, "Use 'Adapt', not 'ADAPT':\n" + "\n".join(offenders) + + # ── Determinism: no wall clock or global RNG in scientific modules ───────────── # Identical inputs + config must produce identical outputs. A module that reads # the wall clock or numpy's global RNG breaks that silently. The acquisition From c220e8e765ef6bcf14893b4a3505d569d8c3b3f7 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Sun, 28 Jun 2026 22:28:36 -0500 Subject: [PATCH 7/8] FIX: nexradaws is not dependancy --- src/adapt/downloaders/__init__.py | 14 ++ src/adapt/downloaders/models.py | 54 +++++++ src/adapt/downloaders/nexrad.py | 124 ++++++++++++++++ src/adapt/downloaders/s3.py | 74 ++++++++++ tests/downloaders/test_models.py | 45 ++++++ tests/downloaders/test_nexrad.py | 236 ++++++++++++++++++++++++++++++ tests/downloaders/test_s3.py | 97 ++++++++++++ 7 files changed, 644 insertions(+) create mode 100644 src/adapt/downloaders/__init__.py create mode 100644 src/adapt/downloaders/models.py create mode 100644 src/adapt/downloaders/nexrad.py create mode 100644 src/adapt/downloaders/s3.py create mode 100644 tests/downloaders/test_models.py create mode 100644 tests/downloaders/test_nexrad.py create mode 100644 tests/downloaders/test_s3.py diff --git a/src/adapt/downloaders/__init__.py b/src/adapt/downloaders/__init__.py new file mode 100644 index 0000000..ac67f0d --- /dev/null +++ b/src/adapt/downloaders/__init__.py @@ -0,0 +1,14 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Native downloaders for public, unsigned S3 datasets. + +Currently provides NEXRAD Level-II via :class:`NexradS3`. The unsigned-S3 +mechanics in :mod:`adapt.downloaders.s3` are dataset-agnostic and form the +reuse seam for future datasets. +""" + +from .models import ArchiveScan, DownloadError, DownloadResult, LocalScan +from .nexrad import NexradS3 + +__all__ = ["ArchiveScan", "DownloadError", "DownloadResult", "LocalScan", "NexradS3"] diff --git a/src/adapt/downloaders/models.py b/src/adapt/downloaders/models.py new file mode 100644 index 0000000..2e1a04b --- /dev/null +++ b/src/adapt/downloaders/models.py @@ -0,0 +1,54 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Value objects for archived-scan discovery and download. + +Pure data — no S3, no I/O. ``ArchiveScan`` describes a remote object, +``LocalScan`` a downloaded one, and ``DownloadResult`` aggregates a batch. +""" + +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path + +__all__ = ["ArchiveScan", "DownloadError", "DownloadResult", "LocalScan"] + + +@dataclass(frozen=True) +class ArchiveScan: + """A remote scan object in the S3 bucket. + + ``size`` is the S3 ``Size`` captured at listing time so a download can be + verified without an extra ``head_object`` round-trip. + """ + + key: str + scan_time: datetime + size: int + + +@dataclass(frozen=True) +class LocalScan: + """A scan that has been downloaded to ``filepath``.""" + + filepath: Path + + +@dataclass +class DownloadResult: + """The outcome of a download batch: successes and per-file failures.""" + + success: list[LocalScan] = field(default_factory=list) + failed: list[ArchiveScan] = field(default_factory=list) + + def iter_success(self): + """Iterate the successfully downloaded scans.""" + return iter(self.success) + + +class DownloadError(Exception): + """A single object failed to download (carries the offending scan).""" + + def __init__(self, message: str, scan: ArchiveScan | None = None): + super().__init__(message) + self.scan = scan diff --git a/src/adapt/downloaders/nexrad.py b/src/adapt/downloaders/nexrad.py new file mode 100644 index 0000000..a6d7746 --- /dev/null +++ b/src/adapt/downloaders/nexrad.py @@ -0,0 +1,124 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Native NEXRAD Level-II archive access over unsigned S3. + +Replaces the third-party ``nexradaws`` library. ``NexradS3`` exposes the exact +surface the acquisition module consumes (``get_avail_scans_in_range`` / +``get_avail_radars`` / ``download``) so it is a drop-in ``conn`` backend. + +Bucket layout: ``unidata-nexrad-level2`` with keys ``YYYY/MM/DD/RADAR/``. +Filenames carry the scan time, e.g. ``KTLX20130520_180000_V06[.gz]``. +""" + +import re +from datetime import UTC, date, datetime, timedelta +from pathlib import Path + +from . import s3 +from .models import ArchiveScan, DownloadResult, LocalScan + +__all__ = ["NexradS3", "build_prefix", "parse_scan_time"] + +BUCKET = "unidata-nexrad-level2" + +# YYYYMMDD_HHMMSS embedded in every NEXRAD volume filename. +_TIMESTAMP_RE = re.compile(r"(\d{8})_(\d{6})") + + +def build_prefix(radar: str, day: date) -> str: + """Return the S3 key prefix for one radar on one UTC day.""" + return f"{day:%Y/%m/%d}/{radar}/" + + +def parse_scan_time(filename: str) -> datetime: + """Parse the UTC scan time from a NEXRAD volume filename.""" + match = _TIMESTAMP_RE.search(filename) + if match is None: + raise ValueError(f"no scan timestamp in filename: {filename!r}") + stamp = f"{match.group(1)}_{match.group(2)}" + return datetime.strptime(stamp, "%Y%m%d_%H%M%S").replace(tzinfo=UTC) + + +def _is_volume_file(key: str) -> bool: + """True for real volume scans; excludes ``_MDM`` and non-volume sidecars.""" + name = key.rsplit("/", 1)[-1] + if "_MDM" in name: + return False + if _TIMESTAMP_RE.search(name) is None: + return False + return name.endswith(".gz") or "_V0" in name + + +def _to_utc(value: datetime) -> datetime: + """Normalize a datetime to UTC, treating a naive value as already UTC.""" + if value.tzinfo is None: + return value.replace(tzinfo=UTC) + return value.astimezone(UTC) + + +class NexradS3: + """NEXRAD Level-II archive client over an anonymous S3 connection.""" + + def __init__(self, client=None): + """Use ``client`` if given (for testing), else one unsigned S3 client.""" + self._client = client or s3.client() + + def get_avail_scans_in_range( + self, start: datetime, end: datetime, radar: str + ) -> list[ArchiveScan]: + """Return volume scans for ``radar`` with scan time in ``[start, end]``. + + A future ``start`` raises ``ValueError``; a future ``end`` is clamped to + now (so the day loop never lists non-existent future days). Missing data + yields an empty list. Results are sorted by scan time. + """ + start = _to_utc(start) + end = _to_utc(end) + now = datetime.now(UTC) + if start > now: + raise ValueError(f"start time {start.isoformat()} is in the future") + end = min(end, now) + scans: list[ArchiveScan] = [] + day = start.date() + while day <= end.date(): + for entry in s3.list_objects(self._client, BUCKET, build_prefix(radar, day)): + key = entry["Key"] + if not _is_volume_file(key): + continue + scan_time = parse_scan_time(key.rsplit("/", 1)[-1]) + if start <= scan_time <= end: + scans.append(ArchiveScan(key=key, scan_time=scan_time, size=entry["Size"])) + day += timedelta(days=1) + scans.sort(key=lambda scan: scan.scan_time) + return scans + + def get_avail_radars(self, y: str, m: str, d: str) -> list[str]: + """Return the radar IDs with data for the given UTC year/month/day.""" + prefix = f"{y}/{m}/{d}/" + return [ + cp.rstrip("/").rsplit("/", 1)[-1] + for cp in s3.list_common_prefixes(self._client, BUCKET, prefix) + ] + + def download( + self, scans: list[ArchiveScan], target_dir: Path, keep_aws_folders: bool = False + ) -> DownloadResult: + """Download each scan into ``target_dir``, verifying byte size. + + Failures are collected in ``DownloadResult.failed`` so the caller can + re-submit them; the batch is not aborted. + """ + if keep_aws_folders: + raise NotImplementedError("keep_aws_folders=True is not supported") + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + result = DownloadResult() + for scan in scans: + dest = target_dir / Path(scan.key).name + try: + s3.download_object(self._client, BUCKET, scan.key, dest, scan.size) + result.success.append(LocalScan(filepath=dest)) + except Exception: + result.failed.append(scan) + return result diff --git a/src/adapt/downloaders/s3.py b/src/adapt/downloaders/s3.py new file mode 100644 index 0000000..fa3c9ae --- /dev/null +++ b/src/adapt/downloaders/s3.py @@ -0,0 +1,74 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Unsigned (anonymous) S3 I/O — the only module that imports boto3. + +These primitives know nothing about NEXRAD: they operate on plain bucket/key +strings. They are the reuse seam for any other public, unsigned S3 dataset +(e.g. GOES `noaa-goes16..19`, MRMS `noaa-mrms-pds`): a future dataset module +supplies its own bucket, key-prefix builder, filename parser and key filter, +and reuses ``client`` / ``list_objects`` / ``list_common_prefixes`` / +``download_object`` unchanged. No such dataset exists yet, so none is built. +""" + +import os +from collections.abc import Iterator +from pathlib import Path + +import boto3 +from botocore import UNSIGNED +from botocore.client import Config + +from .models import DownloadError + +__all__ = ["client", "download_object", "list_common_prefixes", "list_objects"] + + +def client(): + """Return one anonymous (unsigned) S3 client, safe to share across threads.""" + return boto3.client("s3", config=Config(signature_version=UNSIGNED)) + + +def _pages(s3_client, bucket: str, prefix: str, delimiter: str | None): + """Yield every page of a paginated ``list_objects_v2`` call.""" + kwargs = {"Bucket": bucket, "Prefix": prefix} + if delimiter is not None: + kwargs["Delimiter"] = delimiter + yield from s3_client.get_paginator("list_objects_v2").paginate(**kwargs) + + +def list_objects(s3_client, bucket: str, prefix: str) -> Iterator[dict]: + """Yield each object entry (``Key``, ``Size``, ``LastModified``) under ``prefix``.""" + for page in _pages(s3_client, bucket, prefix, delimiter=None): + yield from page.get("Contents", []) + + +def list_common_prefixes(s3_client, bucket: str, prefix: str) -> Iterator[str]: + """Yield each immediate sub-prefix (folder) under ``prefix``.""" + for page in _pages(s3_client, bucket, prefix, delimiter="/"): + for entry in page.get("CommonPrefixes", []): + yield entry["Prefix"] + + +def download_object(s3_client, bucket: str, key: str, dest_path: Path, expected_size: int) -> None: + """Download ``key`` to ``dest_path``, verifying byte size, atomically. + + Writes to ``dest_path.part``, asserts the on-disk size equals + ``expected_size``, then ``os.replace`` to the final path (atomic on POSIX). + Any failure removes the partial file and raises, so a partial download never + occupies the real path. + """ + part = dest_path.with_name(dest_path.name + ".part") + ok = False + try: + s3_client.download_file(bucket, key, str(part)) + actual = part.stat().st_size + if actual != expected_size: + raise DownloadError( + f"size mismatch for {key}: expected {expected_size} bytes, got {actual}" + ) + os.replace(part, dest_path) + ok = True + finally: + if not ok: + part.unlink(missing_ok=True) diff --git a/tests/downloaders/test_models.py b/tests/downloaders/test_models.py new file mode 100644 index 0000000..bd1f8b0 --- /dev/null +++ b/tests/downloaders/test_models.py @@ -0,0 +1,45 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Value-object behavior for the downloaders package.""" + +import dataclasses +from datetime import UTC, datetime +from pathlib import Path + +import pytest + +from adapt.downloaders.models import ArchiveScan, DownloadError, DownloadResult, LocalScan + +pytestmark = pytest.mark.unit + + +def test_archive_scan_is_frozen(): + scan = ArchiveScan(key="k", scan_time=datetime(2013, 5, 20, tzinfo=UTC), size=10) + assert scan.key == "k" + assert scan.size == 10 + with pytest.raises(dataclasses.FrozenInstanceError): + scan.size = 11 + + +def test_download_result_iter_success_yields_only_successes(): + good = LocalScan(filepath=Path("/tmp/good")) + bad = ArchiveScan(key="bad", scan_time=datetime(2013, 5, 20, tzinfo=UTC), size=1) + result = DownloadResult(success=[good], failed=[bad]) + + assert list(result.iter_success()) == [good] + assert result.failed == [bad] + + +def test_download_result_defaults_are_independent(): + a = DownloadResult() + b = DownloadResult() + a.success.append(LocalScan(filepath=Path("/tmp/x"))) + assert b.success == [] + + +def test_download_error_carries_scan(): + scan = ArchiveScan(key="k", scan_time=datetime(2013, 5, 20, tzinfo=UTC), size=1) + err = DownloadError("boom", scan) + assert err.scan is scan + assert str(err) == "boom" diff --git a/tests/downloaders/test_nexrad.py b/tests/downloaders/test_nexrad.py new file mode 100644 index 0000000..7c88677 --- /dev/null +++ b/tests/downloaders/test_nexrad.py @@ -0,0 +1,236 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""NEXRAD archive logic, mirroring the behaviors the nexradaws suite pinned. + +Source spec: nexradaws-master/tests/{test_nexradAwsInterface,test_nexradAwsFile}.py. +The drill-down (years/months/days) and pyart hooks are dropped; everything kept is +re-pinned here against a stubbed client — no network, no fixtures. The strongest +live pins (KTLX 2013-05-20 18-22 -> 53 scans, real download) are in test_integration. +""" + +from datetime import UTC, date, datetime, timedelta, timezone +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from adapt.downloaders.models import ArchiveScan +from adapt.downloaders.nexrad import BUCKET, NexradS3, _to_utc, build_prefix, parse_scan_time + +pytestmark = pytest.mark.unit + + +# ── prefix + filename parsing (mirrors test_prefix_build, test_scan_time) ────── + + +def test_build_prefix(): + # nexradaws test_prefix_build pins exactly this output. + assert build_prefix("KTLX", date(2016, 5, 30)) == "2016/05/30/KTLX/" + + +def test_parse_scan_time_matches_nexradaws_example(): + # nexradaws test_scan_time: KTLX20130531_000358_V06.gz -> 2013-05-31 00:03:58 UTC. + assert parse_scan_time("KTLX20130531_000358_V06.gz") == datetime( + 2013, 5, 31, 0, 3, 58, tzinfo=UTC + ) + + +def test_parse_scan_time_2017_era_without_gz(): + # Filename-era drift: 2017+ scans are bare _V06 (no .gz). + assert parse_scan_time("KTLX20170531_000412_V06") == datetime(2017, 5, 31, 0, 4, 12, tzinfo=UTC) + + +def test_parse_scan_time_rejects_unparseable(): + with pytest.raises(ValueError): + parse_scan_time("not-a-scan.txt") + + +# ── UTC normalization (mirrors test_formattimerange_localtime / _utc) ────────── + + +def test_to_utc_treats_naive_as_utc(): + naive = datetime(2013, 5, 20, 18, 0) + assert _to_utc(naive) == datetime(2013, 5, 20, 18, 0, tzinfo=UTC) + + +def test_to_utc_converts_aware_offset_to_utc(): + central = timezone(timedelta(hours=-5)) + aware = datetime(2013, 5, 20, 18, 0, tzinfo=central) + assert _to_utc(aware) == datetime(2013, 5, 20, 23, 0, tzinfo=UTC) + + +# ── range search (mirrors test_get_available_scans_in_range, _missing) ───────── + + +def _listing_client(contents): + """A client whose list_objects_v2 paginator returns one page of ``contents``.""" + client = MagicMock() + client.get_paginator.return_value.paginate.side_effect = lambda **kwargs: iter( + [{"Contents": contents}] + ) + return client + + +def test_get_avail_scans_in_range_filters_sorts_and_excludes_mdm(): + p = "2013/05/20/KTLX/" + contents = [ + {"Key": f"{p}KTLX20130520_190000_V06.gz", "Size": 12}, # in range + {"Key": f"{p}KTLX20130520_180000_V06.gz", "Size": 10}, # in range (== start) + {"Key": f"{p}KTLX20130520_170000_V06.gz", "Size": 11}, # before start + {"Key": f"{p}KTLX20130520_181500_V06_MDM", "Size": 1}, # MDM sidecar + {"Key": f"{p}KTLX20130520_index.html", "Size": 2}, # not a volume file + ] + store = NexradS3(client=_listing_client(contents)) + + scans = store.get_avail_scans_in_range( + datetime(2013, 5, 20, 18, 0, tzinfo=UTC), + datetime(2013, 5, 20, 19, 30, tzinfo=UTC), + "KTLX", + ) + + assert all(isinstance(s, ArchiveScan) for s in scans) + assert [s.key for s in scans] == [ + f"{p}KTLX20130520_180000_V06.gz", + f"{p}KTLX20130520_190000_V06.gz", + ] + assert [s.size for s in scans] == [10, 12] + assert scans[0].scan_time == datetime(2013, 5, 20, 18, 0, tzinfo=UTC) + + +def test_get_avail_scans_in_range_empty_when_no_data(): + # nexradaws test_get_available_scan_missing: missing data -> [] (never raises). + store = NexradS3(client=_listing_client([])) + scans = store.get_avail_scans_in_range( + datetime(2013, 5, 20, 18, 0, tzinfo=UTC), + datetime(2013, 5, 20, 19, 0, tzinfo=UTC), + "KTLX", + ) + assert scans == [] + + +# ── future-date guard (mirrors nexradawsinterface.py:192-197) ────────────────── + + +def test_future_start_raises(): + store = NexradS3(client=MagicMock()) + start = datetime.now(UTC) + timedelta(days=1) + with pytest.raises(ValueError): + store.get_avail_scans_in_range(start, start + timedelta(hours=1), "KTLX") + + +def test_future_end_is_clamped_so_future_days_are_not_listed(): + # A far-future end must not fire one S3 listing per non-existent future day. + client = _listing_client([]) + store = NexradS3(client=client) + store.get_avail_scans_in_range( + datetime.now(UTC) - timedelta(hours=1), + datetime.now(UTC) + timedelta(days=30), + "KTLX", + ) + assert client.get_paginator.return_value.paginate.call_count <= 2 + + +# ── radar listing (mirrors test_get_available_radars, _missing) ──────────────── + + +def _prefix_client(prefixes): + client = MagicMock() + client.get_paginator.return_value.paginate.side_effect = lambda **kwargs: iter( + [{"CommonPrefixes": [{"Prefix": p} for p in prefixes]}] + ) + return client + + +def test_get_avail_radars_parses_common_prefixes(): + store = NexradS3(client=_prefix_client(["2013/05/20/KTLX/", "2013/05/20/KOUN/"])) + assert store.get_avail_radars("2013", "05", "20") == ["KTLX", "KOUN"] + + +def test_get_avail_radars_empty_when_missing(): + # nexradaws test_get_available_radars_missing: absent day -> []. + store = NexradS3(client=_prefix_client([])) + assert store.get_avail_radars("1900", "05", "31") == [] + + +# ── download (mirrors test_download_single/_multiple, test_failed/_count) ─────── + + +def _download_client(sizes): + """A client whose download_file writes ``sizes[key]`` bytes to the path.""" + client = MagicMock() + client.download_file.side_effect = lambda bucket, key, path: Path(path).write_bytes( + b"x" * sizes[key] + ) + return client + + +def test_download_single_writes_file_and_reports_success(tmp_path): + key = "2013/05/31/KTLX/KTLX20130531_000358_V06.gz" + scan = ArchiveScan(key=key, scan_time=datetime(2013, 5, 31, 0, 3, 58, tzinfo=UTC), size=5) + store = NexradS3(client=_download_client({key: 5})) + + result = store.download([scan], tmp_path / "out") + + success = list(result.iter_success()) + assert len(success) == 1 + assert success[0].filepath == tmp_path / "out" / "KTLX20130531_000358_V06.gz" + assert success[0].filepath.read_bytes() == b"x" * 5 + assert result.failed == [] + + +def test_download_multiple_all_succeed(tmp_path): + p = "2013/05/31/KTLX/" + sizes = { + f"{p}KTLX20130531_000358_V06.gz": 3, + f"{p}KTLX20130531_000834_V06.gz": 4, + f"{p}KTLX20130531_001311_V06.gz": 5, + } + scans = [ + ArchiveScan(key=k, scan_time=datetime(2013, 5, 31, tzinfo=UTC), size=v) + for k, v in sizes.items() + ] + store = NexradS3(client=_download_client(sizes)) + + result = store.download(scans, tmp_path) + + assert len(list(result.iter_success())) == 3 + assert result.failed == [] + for key in sizes: + assert (tmp_path / Path(key).name).exists() + + +def test_download_bad_key_is_recorded_as_failure(tmp_path): + # nexradaws test_failed_count forces a failure by corrupting a key to 'blah/blah'. + good = "2013/05/31/KTLX/KTLX20130531_000358_V06.gz" + bad = "blah/blah" + client = MagicMock() + + def _dl(bucket, key, path): + if key == bad: + raise RuntimeError("NoSuchKey") + Path(path).write_bytes(b"x" * 4) + + client.download_file.side_effect = _dl + scans = [ + ArchiveScan(key=good, scan_time=datetime(2013, 5, 31, tzinfo=UTC), size=4), + ArchiveScan(key=bad, scan_time=datetime(2013, 5, 31, tzinfo=UTC), size=4), + ] + store = NexradS3(client=client) + + result = store.download(scans, tmp_path) + + assert len(list(result.iter_success())) == 1 + assert [s.key for s in result.failed] == [bad] + + +def test_download_rejects_keep_aws_folders(tmp_path): + # We deliberately do not implement the AWS folder structure (acquisition uses flat). + scan = ArchiveScan(key="k/scan.gz", scan_time=datetime(2013, 5, 20, tzinfo=UTC), size=1) + store = NexradS3(client=MagicMock()) + with pytest.raises(NotImplementedError): + store.download([scan], tmp_path, keep_aws_folders=True) + + +def test_bucket_is_unidata_mirror(): + assert BUCKET == "unidata-nexrad-level2" diff --git a/tests/downloaders/test_s3.py b/tests/downloaders/test_s3.py new file mode 100644 index 0000000..bfcc78e --- /dev/null +++ b/tests/downloaders/test_s3.py @@ -0,0 +1,97 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Unsigned-S3 primitives: pagination and verified atomic download. + +All tests use a stubbed client (unittest.mock) — no network, no fixtures. +""" + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from adapt.downloaders import s3 +from adapt.downloaders.models import DownloadError + +pytestmark = pytest.mark.unit + + +def _paginating_client(pages): + """A client whose list_objects_v2 paginator yields ``pages`` (fresh each call).""" + client = MagicMock() + client.get_paginator.return_value.paginate.side_effect = lambda **kwargs: iter(pages) + return client + + +def test_list_objects_spans_multiple_pages(): + pages = [ + {"Contents": [{"Key": "a", "Size": 1}, {"Key": "b", "Size": 2}]}, + {"Contents": [{"Key": "c", "Size": 3}]}, + ] + client = _paginating_client(pages) + + entries = list(s3.list_objects(client, "bucket", "prefix/")) + + assert [e["Key"] for e in entries] == ["a", "b", "c"] + client.get_paginator.assert_called_once_with("list_objects_v2") + client.get_paginator.return_value.paginate.assert_called_once_with( + Bucket="bucket", Prefix="prefix/" + ) + + +def test_list_objects_handles_empty_page(): + client = _paginating_client([{}]) + assert list(s3.list_objects(client, "bucket", "prefix/")) == [] + + +def test_list_common_prefixes_uses_delimiter(): + pages = [{"CommonPrefixes": [{"Prefix": "2013/05/20/KTLX/"}, {"Prefix": "2013/05/20/KOUN/"}]}] + client = _paginating_client(pages) + + prefixes = list(s3.list_common_prefixes(client, "bucket", "2013/05/20/")) + + assert prefixes == ["2013/05/20/KTLX/", "2013/05/20/KOUN/"] + client.get_paginator.return_value.paginate.assert_called_once_with( + Bucket="bucket", Prefix="2013/05/20/", Delimiter="/" + ) + + +def _writing_client(content: bytes): + """A client whose download_file writes ``content`` to the requested path.""" + client = MagicMock() + client.download_file.side_effect = lambda bucket, key, path: Path(path).write_bytes(content) + return client + + +def test_download_object_verifies_size_and_renames_atomically(tmp_path): + dest = tmp_path / "KTLX20130520_180000_V06.gz" + client = _writing_client(b"x" * 100) + + s3.download_object(client, "bucket", "key", dest, 100) + + assert dest.read_bytes() == b"x" * 100 + assert not dest.with_name(dest.name + ".part").exists() + + +def test_download_object_size_mismatch_raises_and_cleans_up(tmp_path): + dest = tmp_path / "scan.gz" + client = _writing_client(b"x" * 50) + + with pytest.raises(DownloadError): + s3.download_object(client, "bucket", "key", dest, 100) + + assert not dest.exists() + assert not dest.with_name(dest.name + ".part").exists() + + +def test_download_object_transient_failure_cleans_up_part(tmp_path): + dest = tmp_path / "scan.gz" + client = MagicMock() + client.download_file.side_effect = RuntimeError("network drop") + + with pytest.raises(RuntimeError): + s3.download_object(client, "bucket", "key", dest, 100) + + assert not dest.exists() + assert not dest.with_name(dest.name + ".part").exists() From 94842582793f138242f5009112185c6587e34995 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Sun, 28 Jun 2026 22:30:22 -0500 Subject: [PATCH 8/8] FIX: Added native downloader --- .importlinter | 28 +++++++++++++++++++ environment.yml | 2 +- pyproject.toml | 2 +- requirements.txt | 2 +- src/adapt/__init__.py | 6 ++-- src/adapt/modules/acquisition/module.py | 20 ++++++------- src/adapt/modules/ingest/module.py | 2 +- .../acquisition/test_downloader_quiet.py | 14 +++++----- tests/runtime/test_thirdparty_quiet.py | 13 +++++---- tests/test_architecture.py | 3 +- 10 files changed, 60 insertions(+), 32 deletions(-) diff --git a/.importlinter b/.importlinter index b54ed12..5b47b63 100644 --- a/.importlinter +++ b/.importlinter @@ -173,3 +173,31 @@ forbidden_modules = adapt.visualization adapt.cli adapt.extensions + +# ========================================================== +# 8. Downloaders are a self-contained third-party I/O leaf +# ========================================================== +# +# adapt.downloaders isolates all boto3/S3 calls. It is a leaf: +# it depends only on third-party libraries and stdlib, never on +# any adapt internal, so it can be reused without dragging in +# the rest of the package and never creates an import cycle. + +[importlinter:contract:downloaders_are_leaf] +name = Downloaders package imports no adapt internals +type = forbidden +source_modules = + adapt.downloaders +forbidden_modules = + adapt.modules + adapt.runtime + adapt.persistence + adapt.configuration + adapt.execution + adapt.api + adapt.gui + adapt.visualization + adapt.cli + adapt.extensions + adapt.contracts + adapt.utils diff --git a/environment.yml b/environment.yml index f7c5cd1..aa3bd29 100644 --- a/environment.yml +++ b/environment.yml @@ -28,9 +28,9 @@ dependencies: - contextily - pyarrow - duckdb + - boto3 - pydata-sphinx-theme - pydantic - pip: - - nexradaws - sphinx-autodoc-typehints - myst-parser>=2.0 diff --git a/pyproject.toml b/pyproject.toml index 6584b5c..4a08e18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "pydantic>=2.0", "arm_pyart", "opencv-python", - "nexradaws", + "boto3", "pyarrow", "duckdb", ] diff --git a/requirements.txt b/requirements.txt index 9726363..7e3ffad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ scipy # NEXRAD & Radar arm_pyart -nexradaws +boto3 # anonymous S3 access for the native NEXRAD downloader # Image processing opencv-python diff --git a/src/adapt/__init__.py b/src/adapt/__init__.py index 8e19fe7..c7fe3f0 100644 --- a/src/adapt/__init__.py +++ b/src/adapt/__init__.py @@ -15,9 +15,9 @@ # Quiet third-party import-time chatter before any submodule (and its transitive # deps) load. Py-ART prints a citation banner on import unless PYART_QUIET is set, -# and it is pulled in transitively by nexradaws via the acquisition source — earlier -# than any Adapt module could set this. The package root is the one place guaranteed -# to run first. setdefault preserves a user-provided override. +# and the ingest module imports pyart at its own import time — earlier than any +# Adapt module could set this. The package root is the one place guaranteed to run +# first. setdefault preserves a user-provided override. _os.environ.setdefault("PYART_QUIET", "1") import importlib.metadata as _importlib_metadata diff --git a/src/adapt/modules/acquisition/module.py b/src/adapt/modules/acquisition/module.py index 093bfbd..56ee3c3 100644 --- a/src/adapt/modules/acquisition/module.py +++ b/src/adapt/modules/acquisition/module.py @@ -15,7 +15,7 @@ from datetime import UTC, datetime, timedelta from pathlib import Path -from nexradaws import NexradAwsInterface +from adapt.downloaders import NexradS3 __all__ = ["AwsNexradDownloader"] @@ -34,8 +34,8 @@ class AwsNexradDownloader(threading.Thread): research studies. **AWS S3 Bucket:** Files stored at - `s3://noaa-nexrad-level2/{YYYY}/{MM}/{DD}/{radar}/` - Example: `s3://noaa-nexrad-level2/2025/03/05/KDIX/KDIX20250305_000310_V06` + `s3://unidata-nexrad-level2/{YYYY}/{MM}/{DD}/{radar}/` + Example: `s3://unidata-nexrad-level2/2025/03/05/KDIX/KDIX20250305_000310_V06` **Deduplication:** Maintains set of known files to avoid re-downloading. Safe to restart mid-execution. @@ -100,7 +100,7 @@ def __init__( file_tracker : FileProcessingTracker, optional Optional file processing tracker to record download completion. - conn : nexradaws.NexradAwsInterface, optional + conn : adapt.downloaders.NexradS3, optional AWS S3 connection object. If None, creates new connection. Allows injection for testing. @@ -114,9 +114,8 @@ def __init__( Notes ----- - Requires AWS credentials configured via environment variables or - ~/.aws/credentials. The S3 bucket is public and requires no auth, - but credentials can speed up downloads (higher rate limits). + The S3 bucket is public; downloads use anonymous (unsigned) requests + and require no AWS credentials. """ super().__init__(daemon=True) @@ -135,7 +134,7 @@ def __init__( self.file_tracker = file_tracker self.result_queue = result_queue - self.conn = conn or NexradAwsInterface() + self.conn = conn or NexradS3() # injectable time helpers for testing self._clock = clock or (lambda: datetime.now(UTC)) self._sleep = sleeper or time.sleep @@ -551,9 +550,8 @@ def _download_scan(self, scan, local_path: Path) -> bool: temp_dir = base_dir / "_temp" temp_dir.mkdir(exist_ok=True) - # nexradaws prints "Downloaded ..." / "n out of m files downloaded..." - # straight to stdout with no quiet option. Contain it to this call (we log - # a controlled "Downloaded: " below); logging uses stderr, so this + # Contain any stdout the conn's download() may emit (we log our own + # controlled "Downloaded: " below). Logging uses stderr, so this # narrow stdout redirect never swallows our own output. with contextlib.redirect_stdout(io.StringIO()): results = self.conn.download([scan], temp_dir, keep_aws_folders=False) diff --git a/src/adapt/modules/ingest/module.py b/src/adapt/modules/ingest/module.py index b1aa976..d706306 100644 --- a/src/adapt/modules/ingest/module.py +++ b/src/adapt/modules/ingest/module.py @@ -25,7 +25,7 @@ import xarray as xr # NB: the Py-ART citation banner is suppressed at the package root (adapt/__init__ -# sets PYART_QUIET) because nexradaws imports pyart before this module ever loads. +# sets PYART_QUIET) because pyart is imported at this module's import time. __all__ = ["RadarDataLoader"] diff --git a/tests/modules/acquisition/test_downloader_quiet.py b/tests/modules/acquisition/test_downloader_quiet.py index 7554c4d..af6385f 100644 --- a/tests/modules/acquisition/test_downloader_quiet.py +++ b/tests/modules/acquisition/test_downloader_quiet.py @@ -1,12 +1,12 @@ # Copyright © 2026, UChicago Argonne, LLC # See LICENSE for terms and disclaimer. -"""The nexradaws download call must not leak its own print() chatter to the console. +"""The conn's download call must not leak any print() chatter to the console. -nexradaws prints "Downloaded " and " out of files downloaded..." straight -to stdout with no quiet option. The acquisition module already logs a controlled -"Downloaded: " line, so the library's duplicate prints are pure clutter and must -be contained at the one call site (no supported quiet flag exists). +The acquisition module already logs a controlled "Downloaded: " line, so any +stdout a download backend emits is pure clutter and must be contained at the one call +site. (The retired ``nexradaws`` backend printed such chatter unconditionally; the +guard remains so no backend can leak to the console.) """ from datetime import UTC, datetime @@ -18,10 +18,10 @@ pytestmark = pytest.mark.unit -def test_download_scan_suppresses_nexradaws_stdout(tmp_path, fake_scan, make_config, capsys): +def test_download_scan_suppresses_conn_stdout(tmp_path, fake_scan, make_config, capsys): class PrintingConn: def download(self, files, basepath, keep_aws_folders=False): - print("Downloaded KOHX_TEST") # nexradaws chatter + print("Downloaded KOHX_TEST") # backend chatter print("1 out of 1 files downloaded...0 errors") class _Results: diff --git a/tests/runtime/test_thirdparty_quiet.py b/tests/runtime/test_thirdparty_quiet.py index 810006a..6c0bf86 100644 --- a/tests/runtime/test_thirdparty_quiet.py +++ b/tests/runtime/test_thirdparty_quiet.py @@ -3,11 +3,11 @@ """Third-party libraries must not splatter their own banners to the Adapt console. -Py-ART prints a citation block on import; it is pulled in transitively by -``nexradaws`` (acquisition -> sources -> runtime), which imports pyart at its own -import time — *before* any Adapt module could set PYART_QUIET. So the suppression -must live in the package root (adapt/__init__), which runs before every submodule. -Verified in clean subprocesses importing the runtime path that triggers nexradaws. +Py-ART prints a citation block on import; the ingest module imports pyart at its own +import time (pulled in via the runtime import chain) — *before* any Adapt module could +set PYART_QUIET. So the suppression must live in the package root (adapt/__init__), +which runs before every submodule. Verified in clean subprocesses importing the +runtime path that triggers the pyart import. """ import os @@ -30,7 +30,8 @@ def _import_clean(statement: str) -> subprocess.CompletedProcess: def test_importing_runtime_does_not_print_pyart_citation(): - # adapt.runtime pulls nexradaws -> pyart; the banner must still be suppressed. + # adapt.runtime pulls in pyart (via the ingest import chain); the banner must + # still be suppressed. result = _import_clean("import adapt.runtime.processor") assert result.returncode == 0, result.stderr assert "Py-ART" not in result.stdout diff --git a/tests/test_architecture.py b/tests/test_architecture.py index 52a9e3d..f5f5074 100644 --- a/tests/test_architecture.py +++ b/tests/test_architecture.py @@ -221,7 +221,8 @@ def test_module_does_not_read_wall_clock_or_global_rng(pkg: str) -> None: "tkinter": ("consumers/",), "cv2": ("modules/projection/",), "pyart": ("modules/ingest/",), - "nexradaws": ("modules/acquisition/",), + "boto3": ("downloaders/",), + "botocore": ("downloaders/",), "networkx": ("modules/tracking/",), "duckdb": ("api/",), "pyxlma": ("modules/xlma_stat/",),