def __post_init__(self) -> None:
    """Normalize array-like fields to NumPy arrays and validate their pairing.

    ``weight_indexes``/``coefficients`` are always coerced (int / float) and
    checked with ``_validate_parallel_arrays``.  The optional ``estimate_*``
    and ``denominator_*`` pairs are coerced and checked the same way, but
    only when at least one member of the pair was supplied.

    Raises:
        ValueError: if only one array of an optional pair is given, or if
            ``_validate_parallel_arrays`` rejects a pair.
    """
    # Values are written back through object.__setattr__, the usual way to
    # assign inside __post_init__ when normal attribute assignment is
    # blocked (as on a frozen dataclass).
    store = object.__setattr__
    label = "TargetReweightingConstraint"

    primary_indexes = np.asarray(self.weight_indexes, dtype=int)
    primary_coefficients = np.asarray(self.coefficients, dtype=float)
    _validate_parallel_arrays(
        primary_indexes,
        primary_coefficients,
        name=label,
    )
    store(self, "weight_indexes", primary_indexes)
    store(self, "coefficients", primary_coefficients)
    store(self, "target", float(self.target))

    # Both optional pairs follow the same rule: all-or-nothing, then the
    # same coercion and shape validation as the primary pair.
    for prefix in ("estimate", "denominator"):
        index_field = f"{prefix}_weight_indexes"
        coefficient_field = f"{prefix}_coefficients"
        raw_indexes = getattr(self, index_field)
        raw_coefficients = getattr(self, coefficient_field)
        if raw_indexes is None and raw_coefficients is None:
            continue
        if raw_indexes is None or raw_coefficients is None:
            raise ValueError(
                f"{label} {prefix} arrays must be provided together"
            )
        paired_indexes = np.asarray(raw_indexes, dtype=int)
        paired_coefficients = np.asarray(raw_coefficients, dtype=float)
        _validate_parallel_arrays(
            paired_indexes,
            paired_coefficients,
            name=f"{label} {prefix}",
        )
        store(self, index_field, paired_indexes)
        store(self, coefficient_field, paired_coefficients)

    if self.display_target is not None:
        store(self, "display_target", float(self.display_target))
frame.columns + feature + for feature in target.required_features + if feature not in frame.columns ] if missing_features: skipped.append( @@ -105,16 +181,34 @@ def compile_target_reweighting_constraints( if not active.any(): skipped.append((target.name, "zero_support")) continue - grouped = ( - pd.DataFrame( - { - "weight_index": aligned_weight_indexes[active], - "coefficient": coefficients.loc[active], - } - ) - .groupby("weight_index", dropna=False)["coefficient"] - .sum() + grouped = _group_weight_coefficients( + aligned_weight_indexes[active], + coefficients.loc[active], ) + metadata = dict(target.metadata) + if target.source is not None: + metadata.setdefault("source", target.source) + metadata.setdefault("period", str(target.period)) + metadata.setdefault("aggregation", target.aggregation.value) + estimate_indexes: np.ndarray | None = None + estimate_coefficients: np.ndarray | None = None + denominator_indexes: np.ndarray | None = None + denominator_coefficients: np.ndarray | None = None + if target.aggregation is TargetAggregation.MEAN and target.measure is not None: + support = mask.astype(bool) + measure_values = _numeric_series(frame[target.measure]).fillna(0.0) + numerator = _group_weight_coefficients( + aligned_weight_indexes[support], + measure_values.loc[support], + ) + denominator = _group_weight_coefficients( + aligned_weight_indexes[support], + np.ones(int(support.sum()), dtype=float), + ) + estimate_indexes = numerator.index.to_numpy(dtype=int) + estimate_coefficients = numerator.to_numpy(dtype=float) + denominator_indexes = denominator.index.to_numpy(dtype=int) + denominator_coefficients = denominator.to_numpy(dtype=float) constraints.append( TargetReweightingConstraint( name=target.name, @@ -122,7 +216,12 @@ def compile_target_reweighting_constraints( weight_indexes=grouped.index.to_numpy(dtype=int), coefficients=grouped.to_numpy(dtype=float), target=_constraint_target_value(target), - metadata=dict(target.metadata), + metadata=metadata, + 
estimate_weight_indexes=estimate_indexes, + estimate_coefficients=estimate_coefficients, + denominator_weight_indexes=denominator_indexes, + denominator_coefficients=denominator_coefficients, + display_target=float(target.value), ) ) @@ -148,12 +247,23 @@ def compile_entity_table_bundle_target_constraints( def reweight_to_target_constraints( initial_weights: pd.Series | np.ndarray, *, - constraints: list[TargetReweightingConstraint] | tuple[TargetReweightingConstraint, ...], + constraints: list[TargetReweightingConstraint] + | tuple[TargetReweightingConstraint, ...], max_iter: int = 8, tol: float = 1e-4, factor_bounds: tuple[float, float] = (0.5, 2.0), + telemetry_writer: TelemetryWriter | None = None, + run_id: str | None = None, + calibration_id: str = "target_reweighting", ) -> tuple[np.ndarray, TargetReweightingDiagnostics]: """Apply multiplicative updates to match compiled linear target constraints.""" + if ( + telemetry_writer is not None + and getattr(telemetry_writer, "enabled", True) + and not run_id + ): + raise ValueError("run_id is required when telemetry_writer is provided") + weights = np.asarray(initial_weights, dtype=float).copy() lower_factor, upper_factor = factor_bounds converged = False @@ -176,7 +286,9 @@ def reweight_to_target_constraints( max_change = 0.0 skipped_nonpositive_positive_target = False for constraint in compiled: - current = float(np.dot(weights[constraint.weight_indexes], constraint.coefficients)) + current = float( + np.dot(weights[constraint.weight_indexes], constraint.coefficients) + ) target_value = float(constraint.target) if target_value == 0.0: current_abs = abs(current) @@ -193,10 +305,20 @@ def reweight_to_target_constraints( if current <= 0.0: skipped_nonpositive_positive_target = True continue - factor = float(np.clip(target_value / current, lower_factor, upper_factor)) + factor = float( + np.clip(target_value / current, lower_factor, upper_factor) + ) weights[constraint.weight_indexes] *= factor max_change = 
max(max_change, abs(factor - 1.0)) iterations = iteration + 1 + _emit_reweighting_epoch_telemetry( + telemetry_writer, + run_id=run_id, + calibration_id=calibration_id, + epoch=iterations, + constraints=compiled, + weights=weights, + ) if max_change < tol: converged = True break @@ -204,7 +326,9 @@ def reweight_to_target_constraints( if skipped_nonpositive_positive_target: converged = False - errors = [constraint_abs_relative_error(constraint, weights) for constraint in compiled] + errors = [ + constraint_abs_relative_error(constraint, weights) for constraint in compiled + ] diagnostics = TargetReweightingDiagnostics( target_count=len(compiled), constraint_count=len(compiled), @@ -213,6 +337,14 @@ def reweight_to_target_constraints( mean_abs_relative_error=float(np.mean(errors)) if errors else 0.0, max_abs_relative_error=float(np.max(errors)) if errors else 0.0, ) + _emit_reweighting_target_telemetry( + telemetry_writer, + run_id=run_id, + calibration_id=calibration_id, + epoch_or_final="final", + constraints=compiled, + weights=weights, + ) return weights, diagnostics @@ -223,6 +355,9 @@ def reweight_entity_table_bundle_targets( max_iter: int = 8, tol: float = 1e-4, factor_bounds: tuple[float, float] = (0.5, 2.0), + telemetry_writer: TelemetryWriter | None = None, + run_id: str | None = None, + calibration_id: str = "target_reweighting", ) -> EntityTableBundleReweightingResult: """Compile and reweight a shared entity-table bundle in one step.""" compilation = compile_entity_table_bundle_target_constraints( @@ -235,6 +370,9 @@ def reweight_entity_table_bundle_targets( max_iter=max_iter, tol=tol, factor_bounds=factor_bounds, + telemetry_writer=telemetry_writer, + run_id=run_id, + calibration_id=calibration_id, ) return EntityTableBundleReweightingResult( bundle=bundle.with_updated_weights(weights), @@ -272,11 +410,15 @@ def compile_sparse_target_constraints( def calibrate_sparse_target_weights( initial_weights: pd.Series | np.ndarray, *, - constraints: 
list[TargetReweightingConstraint] | tuple[TargetReweightingConstraint, ...], + constraints: list[TargetReweightingConstraint] + | tuple[TargetReweightingConstraint, ...], target_count: int | None = None, max_iter: int = 8, tol: float = 1e-4, factor_bounds: tuple[float, float] = (0.5, 2.0), + telemetry_writer: TelemetryWriter | None = None, + run_id: str | None = None, + calibration_id: str = "sparse_target_calibration", ) -> tuple[np.ndarray, TargetReweightingDiagnostics]: """Compatibility wrapper around target reweighting.""" weights, diagnostics = reweight_to_target_constraints( @@ -285,6 +427,9 @@ def calibrate_sparse_target_weights( max_iter=max_iter, tol=tol, factor_bounds=factor_bounds, + telemetry_writer=telemetry_writer, + run_id=run_id, + calibration_id=calibration_id, ) if target_count is None: return weights, diagnostics @@ -303,7 +448,9 @@ def constraint_abs_relative_error( weights: np.ndarray, ) -> float: """Compute absolute relative error for one compiled constraint.""" - estimate = float(np.dot(weights[constraint.weight_indexes], constraint.coefficients)) + estimate = float( + np.dot(weights[constraint.weight_indexes], constraint.coefficients) + ) return abs(relative_error_ratio(estimate, constraint.target)) @@ -315,6 +462,148 @@ def sparse_constraint_abs_rel_error( return constraint_abs_relative_error(constraint, weights) +def _emit_reweighting_epoch_telemetry( + telemetry_writer: TelemetryWriter | None, + *, + run_id: str | None, + calibration_id: str, + epoch: int, + constraints: tuple[TargetReweightingConstraint, ...], + weights: np.ndarray, +) -> None: + if telemetry_writer is None or run_id is None: + return + errors = [ + constraint_abs_relative_error(constraint, weights) for constraint in constraints + ] + data_loss = float(np.mean(errors)) if errors else 0.0 + telemetry_writer.emit( + CalibrationEpochEvent( + run_id=run_id, + calibration_id=calibration_id, + epoch=epoch, + objective=data_loss, + data_loss=data_loss, + l0_penalty=0.0, + 
l2_penalty=0.0, + nonzero_weights=int(np.count_nonzero(weights > 0.0)), + ess=effective_sample_size(weights), + ) + ) + + +def _emit_reweighting_target_telemetry( + telemetry_writer: TelemetryWriter | None, + *, + run_id: str | None, + calibration_id: str, + epoch_or_final: int | str, + constraints: tuple[TargetReweightingConstraint, ...], + weights: np.ndarray, +) -> None: + if telemetry_writer is None or run_id is None: + return + events = [] + for constraint in constraints: + target_value = _constraint_telemetry_target_value(constraint) + estimate = _constraint_telemetry_estimate(constraint, weights) + relative_error = relative_error_ratio(estimate, target_value) + events.append( + CalibrationTargetEvent( + run_id=run_id, + calibration_id=calibration_id, + epoch_or_final=epoch_or_final, + target_name=constraint.name, + family=_metadata_scalar(constraint.metadata, "family"), + split=_metadata_scalar(constraint.metadata, "split"), + source=_metadata_scalar(constraint.metadata, "source"), + geography=_metadata_scalar(constraint.metadata, "geography"), + target_value=target_value, + estimate=estimate, + relative_error=float(relative_error), + weighted_term=float(abs(relative_error)), + in_loss_function=True, + support_status=_metadata_scalar( + constraint.metadata, + "support_status", + default="included", + ), + ) + ) + telemetry_writer.emit_many(events) + + +def _constraint_telemetry_target_value( + constraint: TargetReweightingConstraint, +) -> float: + if constraint.display_target is not None: + return float(constraint.display_target) + return float(constraint.target) + + +def _constraint_telemetry_estimate( + constraint: TargetReweightingConstraint, + weights: np.ndarray, +) -> float: + if ( + constraint.estimate_weight_indexes is not None + and constraint.estimate_coefficients is not None + ): + numerator = float( + np.dot( + weights[constraint.estimate_weight_indexes], + constraint.estimate_coefficients, + ) + ) + if ( + constraint.denominator_weight_indexes 
is not None + and constraint.denominator_coefficients is not None + ): + denominator = float( + np.dot( + weights[constraint.denominator_weight_indexes], + constraint.denominator_coefficients, + ) + ) + if denominator == 0.0: + return 0.0 + return numerator / denominator + return numerator + return float(np.dot(weights[constraint.weight_indexes], constraint.coefficients)) + + +def _metadata_scalar( + metadata: Mapping[str, Any], + key: str, + *, + default: str | None = None, +) -> str | None: + value = metadata.get(key, default) + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, bool | int | float): + return str(value) + return default + + +def _group_weight_coefficients( + weight_indexes: np.ndarray, + coefficients: pd.Series | np.ndarray, +) -> pd.Series: + return ( + pd.DataFrame( + { + "weight_index": weight_indexes, + "coefficient": np.asarray(coefficients, dtype=float), + } + ) + .groupby("weight_index", dropna=False)["coefficient"] + .sum() + ) + + def _coerce_weight_indexes( values: pd.Series | np.ndarray, expected_length: int, diff --git a/src/microplex/telemetry/__init__.py b/src/microplex/telemetry/__init__.py new file mode 100644 index 0000000..61ee20d --- /dev/null +++ b/src/microplex/telemetry/__init__.py @@ -0,0 +1,606 @@ +"""Run and calibration telemetry primitives for Microplex. + +The module is intentionally country-agnostic. Country packages can enrich +target metadata before creating these events, while core owns the event shapes, +privacy guardrails, and local / remote writer plumbing. 
+""" + +from __future__ import annotations + +import json +import math +import os +from collections.abc import Iterable, Mapping +from dataclasses import asdict, dataclass, is_dataclass +from datetime import UTC, datetime +from enum import Enum +from pathlib import Path +from typing import Any, ClassVar, Protocol + +import httpx +import numpy as np +import pandas as pd + +TELEMETRY_SCHEMA_VERSION = "microplex.telemetry.v1" + + +JsonRecord = dict[str, Any] + +_SUPABASE_TABLES = { + "run": "runs", + "stage": "run_stages", + "calibration_epoch": "calibration_epochs", + "calibration_target": "calibration_targets", + "artifact": "artifacts", +} + +_SUPABASE_TABLE_COLUMNS = { + "run": ( + "run_id", + "build_id", + "engine", + "period", + "created_at", + "code_ref", + "config_hash", + "incognito", + "status", + ), + "stage": ( + "run_id", + "stage", + "status", + "started_at", + "completed_at", + "elapsed_seconds", + "rss_mb", + "notes", + ), + "calibration_epoch": ( + "run_id", + "calibration_id", + "epoch", + "objective", + "data_loss", + "l0_penalty", + "l2_penalty", + "nonzero_weights", + "ess", + "timestamp", + ), + "calibration_target": ( + "run_id", + "calibration_id", + "epoch_or_final", + "target_name", + "family", + "split", + "source", + "geography", + "target_value", + "estimate", + "relative_error", + "weighted_term", + "in_loss_function", + "support_status", + ), + "artifact": ( + "run_id", + "artifact_kind", + "path_or_uri", + "sha256", + "size_bytes", + "created_at", + ), +} + + +class TelemetryEvent(Protocol): + """Serializable Microplex telemetry event.""" + + event_type: ClassVar[str] + run_id: str + + def to_record(self) -> JsonRecord: + """Return a JSON-safe telemetry record.""" + + +@dataclass(frozen=True) +class RunEvent: + """Run lifecycle metadata.""" + + run_id: str + status: str + build_id: str | None = None + engine: str | None = None + period: int | str | None = None + created_at: str | None = None + code_ref: str | None = None + config_hash: 
@dataclass(frozen=True)
class CalibrationEpochEvent:
    """Aggregate metrics from one calibration epoch or iteration.

    The event carries only run-level aggregates (no per-row data).
    ``to_record`` delegates to ``_event_record``, which adds ``event_type``,
    fills ``emitted_at`` with the current UTC time when it is ``None``, and
    copies ``emitted_at`` into ``timestamp`` when ``timestamp`` is ``None``.
    """

    # Identifier of the run this epoch belongs to.
    run_id: str
    # Identifier of the calibration procedure within the run.
    calibration_id: str
    # Epoch / iteration number (the target-reweighting loop emits a 1-based
    # iteration count).
    epoch: int
    # Objective and data-loss values; the target-reweighting emitter writes
    # the mean absolute relative error into both.
    objective: float | None = None
    data_loss: float | None = None
    # Penalty terms; the target-reweighting emitter always writes 0.0 here.
    # NOTE(review): presumably L0 / L2 regularization terms for other
    # calibrators -- confirm against those callers.
    l0_penalty: float | None = None
    l2_penalty: float | None = None
    # Number of strictly positive weights, as populated by the
    # target-reweighting emitter.
    nonzero_weights: int | None = None
    # Effective sample size of the weights (see effective_sample_size).
    ess: float | None = None
    # Event timestamp; defaults to emitted_at when left as None.
    timestamp: str | None = None
    # Emission time; defaults to the current UTC time when left as None.
    emitted_at: str | None = None

    # Routing key: selects the "calibration_epochs" Supabase table and the
    # "calibration_epochs.jsonl" local file.
    event_type: ClassVar[str] = "calibration_epoch"

    def to_record(self) -> JsonRecord:
        """Return this event as a JSON-safe record via ``_event_record``."""
        return _event_record(self)
class LocalTelemetryWriter:
    """Append-only JSONL telemetry writer for local runs and CI artifacts.

    Constructing a writer creates ``output_dir`` (including parents) and
    (re)writes ``manifest.json`` there.  Every emitted event is appended to
    ``events.jsonl`` and to a per-event-type file chosen by
    ``_event_file_name``.
    """

    enabled = True

    def __init__(
        self,
        output_dir: str | Path,
        *,
        incognito: bool = False,
        remote_upload_enabled: bool = False,
    ) -> None:
        """Prepare the output directory and write the manifest.

        Args:
            output_dir: Directory that receives the manifest and JSONL files.
            incognito: Marks the run as private; also forces
                ``remote_upload_enabled`` to ``False`` in the manifest.
            remote_upload_enabled: Whether remote upload was requested.
        """
        self.output_dir = Path(output_dir)
        self.incognito = bool(incognito)
        self.remote_upload_enabled = bool(remote_upload_enabled and not incognito)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self._write_manifest()

    def emit(self, event: TelemetryEvent | Mapping[str, Any]) -> None:
        """Normalize one event and append it to the combined and typed files."""
        record = normalize_telemetry_event(event)
        # Run records of an incognito writer are always flagged, regardless
        # of what the event itself carried.
        if self.incognito and record.get("event_type") == "run":
            record["incognito"] = True
        _append_jsonl(self.output_dir / "events.jsonl", record)
        typed_path = self.output_dir / _event_file_name(record["event_type"])
        _append_jsonl(typed_path, record)

    def emit_many(self, events: Iterable[TelemetryEvent | Mapping[str, Any]]) -> None:
        """Emit each event in order."""
        for event in events:
            self.emit(event)

    def _write_manifest(self) -> None:
        """Write ``manifest.json`` describing this telemetry directory."""
        manifest = {
            "schema_version": TELEMETRY_SCHEMA_VERSION,
            "created_at": utc_now(),
            "incognito": self.incognito,
            "remote_upload_enabled": self.remote_upload_enabled,
        }
        # Explicit UTF-8 keeps the manifest consistent with the JSONL files
        # instead of depending on the platform's locale encoding.
        (self.output_dir / "manifest.json").write_text(
            json.dumps(manifest, indent=2, sort_keys=True) + "\n",
            encoding="utf-8",
        )
) + return cls( + url, + key, + table=table + or os.environ.get("MICROPLEX_TELEMETRY_EVENT_TABLE") + or os.environ.get("MICROPLEX_TELEMETRY_TABLE"), + table_prefix=table_prefix + if table_prefix is not None + else os.environ.get("MICROPLEX_TELEMETRY_TABLE_PREFIX", ""), + client=client, + ) + + def emit(self, event: TelemetryEvent | Mapping[str, Any]) -> None: + record = normalize_telemetry_event(event) + self._post_rows(self._table_for(record), self._row_for(record)) + + def emit_many(self, events: Iterable[TelemetryEvent | Mapping[str, Any]]) -> None: + records = [normalize_telemetry_event(event) for event in events] + if not records: + return + if self.table is not None: + self._post_rows( + self.table, + [ + { + "event_type": record["event_type"], + "run_id": record.get("run_id"), + "emitted_at": record["emitted_at"], + "payload": record, + } + for record in records + ], + ) + return + + rows_by_table: dict[str, list[JsonRecord]] = {} + for record in records: + rows_by_table.setdefault(self._table_for(record), []).append( + _typed_supabase_row(record) + ) + for table, rows in rows_by_table.items(): + self._post_rows(table, rows) + + def _row_for(self, record: JsonRecord) -> JsonRecord: + if self.table is None: + return _typed_supabase_row(record) + return { + "event_type": record["event_type"], + "run_id": record.get("run_id"), + "emitted_at": record["emitted_at"], + "payload": record, + } + + def _table_for(self, record: JsonRecord) -> str: + if self.table is not None: + return self.table + event_type = str(record["event_type"]) + table = self.table_map.get(event_type, "events") + return f"{self.table_prefix}{table}" + + def _post_rows(self, table: str, rows: JsonRecord | list[JsonRecord]) -> None: + response = self._client.post( + f"{self.supabase_url}/rest/v1/{table}", + headers=self._headers, + json=rows, + ) + if response.status_code >= 400: + raise RuntimeError( + f"Supabase telemetry write failed with HTTP {response.status_code}" + ) + + def close(self) 
class CompositeTelemetryWriter:
    """Fan-out writer for local + remote telemetry sinks.

    ``enabled`` is true when at least one child writer is enabled (children
    without an ``enabled`` attribute count as enabled).
    """

    def __init__(self, writers: Iterable[TelemetryWriter]) -> None:
        """Store the child writers and derive the combined ``enabled`` flag."""
        self.writers = tuple(writers)
        self.enabled = any(getattr(writer, "enabled", True) for writer in self.writers)

    def emit(self, event: TelemetryEvent | Mapping[str, Any]) -> None:
        """Forward one event to every child writer, in order."""
        for writer in self.writers:
            writer.emit(event)

    def emit_many(self, events: Iterable[TelemetryEvent | Mapping[str, Any]]) -> None:
        """Forward a batch of events to every child writer."""
        # Materialize once so a one-shot iterable (e.g. a generator) is not
        # exhausted by the first child.
        buffered = tuple(events)
        for writer in self.writers:
            writer.emit_many(buffered)

    def close(self) -> None:
        """Close every child writer that exposes a ``close()`` method.

        All children are attempted even if one fails; the first failure is
        re-raised afterwards so resources held by later children are still
        released.
        """
        first_error: Exception | None = None
        for writer in self.writers:
            closer = getattr(writer, "close", None)
            if not callable(closer):
                continue
            try:
                closer()
            except Exception as error:  # keep closing the remaining sinks
                if first_error is None:
                    first_error = error
        if first_error is not None:
            raise first_error
def effective_sample_size(weights: np.ndarray | pd.Series | list[float]) -> float:
    """Kish effective sample size for a vector of non-negative weights.

    Computes ``sum(w) ** 2 / sum(w ** 2)``.  The ratio is invariant to a
    common scale factor, so the weights are divided by their largest
    magnitude first; this avoids overflow for very large weights and
    underflow-to-zero for very small ones.

    Args:
        weights: Array-like of weights.

    Returns:
        The effective sample size; ``0.0`` for empty or all-zero input and
        ``nan`` when the input contains a non-finite weight.
    """
    values = np.asarray(weights, dtype=float)
    if values.size == 0:
        return 0.0
    peak = float(np.max(np.abs(values)))
    if not math.isfinite(peak):
        # NaN or infinite weights make the statistic undefined.
        return math.nan
    if peak == 0.0:
        return 0.0
    scaled = values / peak
    # After scaling the largest entry has magnitude 1, so the denominator is
    # at least 1 and the division is always safe.
    denominator = float(np.sum(scaled**2))
    numerator = float(np.sum(scaled)) ** 2
    return numerator / denominator
payload["event_type"] = event.event_type + payload.setdefault("emitted_at", None) + if payload["emitted_at"] is None: + payload["emitted_at"] = utc_now() + if "timestamp" in payload and payload["timestamp"] is None: + payload["timestamp"] = payload["emitted_at"] + if isinstance(event, RunEvent) and payload.get("created_at") is None: + payload["created_at"] = payload["emitted_at"] + return _json_safe_record(payload) + + +def _json_safe_record(record: Mapping[str, Any]) -> JsonRecord: + return {key: _json_safe_value(value, key) for key, value in record.items()} + + +def _json_safe_value(value: Any, path: str) -> Any: + if isinstance(value, pd.DataFrame | pd.Series | pd.Index): + raise TypeError(f"Telemetry payload {path!r} contains row-level pandas data") + if isinstance(value, np.ndarray): + if value.ndim == 0: + return _json_safe_value(value.item(), path) + raise TypeError(f"Telemetry payload {path!r} contains row-level array data") + if isinstance(value, np.generic): + return _json_safe_value(value.item(), path) + if isinstance(value, Path): + return str(value) + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, Enum): + return value.value + if is_dataclass(value) and not isinstance(value, type): + raise TypeError(f"Telemetry payload {path!r} contains dataclass record data") + if isinstance(value, Mapping): + return { + str(key): _json_safe_value(nested_value, f"{path}.{key}") + for key, nested_value in value.items() + } + if isinstance(value, list | tuple): + raise TypeError(f"Telemetry payload {path!r} contains sequence data") + if isinstance(value, float): + return value if math.isfinite(value) else None + if isinstance(value, str | int | bool) or value is None: + return value + raise TypeError( + f"Telemetry payload {path!r} has unsupported value type {type(value).__name__}" + ) + + +def _append_jsonl(path: Path, record: JsonRecord) -> None: + with path.open("a", encoding="utf-8") as output: + output.write(json.dumps(record, 
sort_keys=True) + "\n") + + +def _event_file_name(event_type: str) -> str: + return { + "run": "runs.jsonl", + "stage": "run_stages.jsonl", + "calibration_epoch": "calibration_epochs.jsonl", + "calibration_target": "calibration_targets.jsonl", + "artifact": "artifacts.jsonl", + }.get(event_type, "custom_events.jsonl") + + +def _typed_supabase_row(record: JsonRecord) -> JsonRecord: + columns = _SUPABASE_TABLE_COLUMNS.get(str(record["event_type"])) + if columns is None: + return { + "event_type": record["event_type"], + "run_id": record.get("run_id"), + "emitted_at": record["emitted_at"], + "payload": record, + } + return {column: record[column] for column in columns if column in record} + + +__all__ = [ + "ArtifactEvent", + "CalibrationEpochEvent", + "CalibrationTargetEvent", + "CompositeTelemetryWriter", + "LocalTelemetryWriter", + "NullTelemetryWriter", + "RunEvent", + "StageEvent", + "SupabaseTelemetryWriter", + "TELEMETRY_SCHEMA_VERSION", + "TelemetryEvent", + "TelemetryWriter", + "build_telemetry_writer", + "effective_sample_size", + "normalize_telemetry_event", + "utc_now", +] diff --git a/tests/targets/test_reweighting.py b/tests/targets/test_reweighting.py index 09a7f49..3e71c8c 100644 --- a/tests/targets/test_reweighting.py +++ b/tests/targets/test_reweighting.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json + import numpy as np import pandas as pd @@ -16,6 +18,7 @@ reweight_entity_table_bundle_targets, reweight_to_target_constraints, ) +from microplex.telemetry import LocalTelemetryWriter, build_telemetry_writer def test_compile_target_reweighting_constraints_groups_to_shared_weight_vector(): @@ -28,7 +31,9 @@ def test_compile_target_reweighting_constraints_groups_to_shared_weight_vector() } ) household = pd.DataFrame({"household_id": [10, 20]}) - household_index = pd.Series(np.arange(len(household)), index=household["household_id"]) + household_index = pd.Series( + np.arange(len(household)), index=household["household_id"] + ) targets = 
[ TargetSpec( @@ -62,7 +67,9 @@ def test_compile_target_reweighting_constraints_groups_to_shared_weight_vector() }, entity_weight_indexes={ EntityType.PERSON: person["person_household_id"].map(household_index), - EntityType.HOUSEHOLD: household_index.reindex(household["household_id"]).to_numpy(), + EntityType.HOUSEHOLD: household_index.reindex( + household["household_id"] + ).to_numpy(), }, ) @@ -84,7 +91,9 @@ def test_reweight_to_target_constraints_hits_simple_targets(): } ) household = pd.DataFrame({"household_id": [10, 20]}) - household_index = pd.Series(np.arange(len(household)), index=household["household_id"]) + household_index = pd.Series( + np.arange(len(household)), index=household["household_id"] + ) targets = [ TargetSpec( name="age_band_count", @@ -119,6 +128,132 @@ def test_reweight_to_target_constraints_hits_simple_targets(): assert diagnostics.mean_abs_relative_error == 0.0 +def test_reweight_to_target_constraints_emits_calibration_telemetry(tmp_path): + person = pd.DataFrame( + { + "person_household_id": [10, 10, 20], + "age": [5, 8, 30], + "local_authority_code": ["A", "A", "B"], + } + ) + household_index = pd.Series([0, 1], index=[10, 20]) + targets = [ + TargetSpec( + name="age_band_count", + entity=EntityType.PERSON, + value=4.0, + period=2024, + aggregation="count", + source="unit-test", + metadata={"family": "demographics", "geography": "A"}, + filters=( + TargetFilter("local_authority_code", FilterOperator.EQ, "A"), + TargetFilter("age", FilterOperator.GTE, 0), + TargetFilter("age", FilterOperator.LT, 10), + ), + ) + ] + compilation = compile_target_reweighting_constraints( + targets=targets, + entity_frames={EntityType.PERSON: person}, + entity_weight_indexes={ + EntityType.PERSON: person["person_household_id"].map(household_index), + }, + ) + writer = LocalTelemetryWriter(tmp_path / "telemetry") + + reweight_to_target_constraints( + np.array([1.0, 1.0]), + constraints=compilation.constraints, + max_iter=1, + tol=1e-6, + 
telemetry_writer=writer, + run_id="run-1", + calibration_id="cal-1", + ) + + epoch_event = json.loads( + (tmp_path / "telemetry" / "calibration_epochs.jsonl").read_text() + ) + target_event = json.loads( + (tmp_path / "telemetry" / "calibration_targets.jsonl").read_text() + ) + + assert epoch_event["run_id"] == "run-1" + assert epoch_event["calibration_id"] == "cal-1" + assert epoch_event["epoch"] == 1 + assert epoch_event["data_loss"] == 0.0 + assert epoch_event["nonzero_weights"] == 2 + assert target_event["target_name"] == "age_band_count" + assert target_event["family"] == "demographics" + assert target_event["source"] == "unit-test" + assert target_event["geography"] == "A" + assert target_event["estimate"] == 4.0 + + +def test_reweight_to_target_constraints_allows_disabled_telemetry_without_run_id(): + constraint = compile_target_reweighting_constraints( + targets=[ + TargetSpec( + name="income_sum", + entity=EntityType.PERSON, + value=4.0, + period=2024, + measure="income", + aggregation="sum", + ) + ], + entity_frames={EntityType.PERSON: pd.DataFrame({"income": [2.0]})}, + entity_weight_indexes={EntityType.PERSON: np.array([0])}, + ).constraints[0] + + weights, diagnostics = reweight_to_target_constraints( + np.array([1.0]), + constraints=(constraint,), + max_iter=1, + telemetry_writer=build_telemetry_writer(output_dir=None, upload=False), + ) + + assert weights.tolist() == [2.0] + assert diagnostics.mean_abs_relative_error == 0.0 + + +def test_reweight_to_target_constraints_reports_mean_target_telemetry(tmp_path): + person = pd.DataFrame({"income": [0.0, 1.0]}) + compilation = compile_target_reweighting_constraints( + targets=[ + TargetSpec( + name="mean_income", + entity=EntityType.PERSON, + value=0.5, + period=2024, + measure="income", + aggregation="mean", + ) + ], + entity_frames={EntityType.PERSON: person}, + entity_weight_indexes={EntityType.PERSON: np.array([0, 1])}, + ) + writer = LocalTelemetryWriter(tmp_path / "telemetry") + + 
reweight_to_target_constraints( + np.array([1.0, 1.0]), + constraints=compilation.constraints, + max_iter=1, + telemetry_writer=writer, + run_id="run-1", + calibration_id="cal-1", + ) + + target_event = json.loads( + (tmp_path / "telemetry" / "calibration_targets.jsonl").read_text() + ) + assert target_event["target_name"] == "mean_income" + assert target_event["target_value"] == 0.5 + assert target_event["estimate"] == 0.5 + assert target_event["relative_error"] == 0.0 + + def test_reweight_to_target_constraints_shrinks_mean_residual_toward_zero(): person = pd.DataFrame({"income": [0.0, 1.2]}) compilation = compile_target_reweighting_constraints( @@ -244,7 +379,10 @@ def test_entity_table_bundle_maps_weight_indexes_and_syncs_dependent_weights(): updated = bundle.with_updated_weights(np.array([2.0, 1.0])) - assert updated.table_for(EntityType.HOUSEHOLD)["household_weight"].tolist() == [2.0, 1.0] + assert updated.table_for(EntityType.HOUSEHOLD)["household_weight"].tolist() == [ + 2.0, + 1.0, + ] assert updated.table_for(EntityType.PERSON)["weight"].tolist() == [2.0, 2.0, 1.0] @@ -297,6 +435,12 @@ def test_reweight_entity_table_bundle_targets_updates_bundle_in_one_step(): ], ) - assert result.bundle.table_for(EntityType.HOUSEHOLD)["household_weight"].tolist() == [2.0, 1.0] - assert result.bundle.table_for(EntityType.PERSON)["weight"].tolist() == [2.0, 2.0, 1.0] + assert result.bundle.table_for(EntityType.HOUSEHOLD)[ + "household_weight" + ].tolist() == [2.0, 1.0] + assert result.bundle.table_for(EntityType.PERSON)["weight"].tolist() == [ + 2.0, + 2.0, + 1.0, + ] assert result.compilation.skipped_targets == () diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py new file mode 100644 index 0000000..7292e6a --- /dev/null +++ b/tests/test_telemetry.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass + +import httpx +import pandas as pd +import pytest + +from microplex.telemetry import ( + 
CalibrationEpochEvent, + LocalTelemetryWriter, + RunEvent, + StageEvent, + SupabaseTelemetryWriter, + build_telemetry_writer, + normalize_telemetry_event, +) + + +@dataclass(frozen=True) +class RowPayload: + person_id: int + income: float + + +def _read_jsonl(path): + return [json.loads(line) for line in path.read_text().splitlines()] + + +def test_local_telemetry_writer_appends_events_and_manifest(tmp_path): + writer = LocalTelemetryWriter(tmp_path / "telemetry", incognito=True) + + writer.emit(RunEvent(run_id="run-1", status="started", engine="microplex")) + writer.emit( + StageEvent( + run_id="run-1", + stage="calibration", + status="completed", + elapsed_seconds=1.5, + rss_mb=128.0, + ) + ) + + manifest = json.loads((tmp_path / "telemetry" / "manifest.json").read_text()) + events = _read_jsonl(tmp_path / "telemetry" / "events.jsonl") + run_events = _read_jsonl(tmp_path / "telemetry" / "runs.jsonl") + stage_events = _read_jsonl(tmp_path / "telemetry" / "run_stages.jsonl") + + assert manifest["incognito"] is True + assert manifest["remote_upload_enabled"] is False + assert [event["event_type"] for event in events] == ["run", "stage"] + assert run_events[0]["run_id"] == "run-1" + assert stage_events[0]["stage"] == "calibration" + + +def test_local_telemetry_writer_routes_unknown_event_types_to_safe_file(tmp_path): + writer = LocalTelemetryWriter(tmp_path / "telemetry") + + writer.emit({"event_type": "../escaped", "run_id": "run-1"}) + + assert not (tmp_path / "escapeds.jsonl").exists() + assert not (tmp_path / "telemetry" / ".." 
/ "escapeds.jsonl").exists() + custom_events = _read_jsonl(tmp_path / "telemetry" / "custom_events.jsonl") + assert custom_events[0]["event_type"] == "../escaped" + + +def test_build_telemetry_writer_incognito_disables_remote_upload(tmp_path): + writer = build_telemetry_writer( + tmp_path / "telemetry", + upload=True, + incognito=True, + ) + + writer.emit(RunEvent(run_id="run-1", status="started")) + + manifest = json.loads((tmp_path / "telemetry" / "manifest.json").read_text()) + run_event = _read_jsonl(tmp_path / "telemetry" / "runs.jsonl")[0] + assert manifest["incognito"] is True + assert manifest["remote_upload_enabled"] is False + assert run_event["incognito"] is True + assert (tmp_path / "telemetry" / "events.jsonl").exists() + + +def test_supabase_telemetry_writer_posts_append_only_event(): + requests = [] + + def handler(request): + requests.append(request) + return httpx.Response(201) + + client = httpx.Client(transport=httpx.MockTransport(handler)) + writer = SupabaseTelemetryWriter( + "https://example.supabase.co", + "secret-key", + table="telemetry_events", + client=client, + ) + + writer.emit( + CalibrationEpochEvent( + run_id="run-1", + calibration_id="cal-1", + epoch=7, + objective=0.12, + data_loss=0.12, + nonzero_weights=42, + ess=35.5, + ) + ) + + assert len(requests) == 1 + request = requests[0] + assert str(request.url) == "https://example.supabase.co/rest/v1/telemetry_events" + assert request.headers["apikey"] == "secret-key" + body = json.loads(request.content) + assert body["event_type"] == "calibration_epoch" + assert body["run_id"] == "run-1" + assert body["payload"]["epoch"] == 7 + + +def test_supabase_telemetry_writer_posts_typed_table_by_default(): + requests = [] + + def handler(request): + requests.append(request) + return httpx.Response(201) + + client = httpx.Client(transport=httpx.MockTransport(handler)) + writer = SupabaseTelemetryWriter( + "https://example.supabase.co", + "secret-key", + client=client, + ) + + writer.emit( + 
CalibrationEpochEvent( + run_id="run-1", + calibration_id="cal-1", + epoch=8, + objective=0.08, + ) + ) + + assert len(requests) == 1 + assert str(requests[0].url) == ( + "https://example.supabase.co/rest/v1/calibration_epochs" + ) + body = json.loads(requests[0].content) + assert body["epoch"] == 8 + assert body["timestamp"] is not None + assert body["run_id"] == "run-1" + assert "payload" not in body + assert "event_type" not in body + + +def test_supabase_telemetry_writer_routes_unknown_event_types_to_events_table(): + requests = [] + + def handler(request): + requests.append(request) + return httpx.Response(201) + + client = httpx.Client(transport=httpx.MockTransport(handler)) + writer = SupabaseTelemetryWriter( + "https://example.supabase.co", + "secret-key", + client=client, + ) + + writer.emit({"event_type": "../escaped", "run_id": "run-1"}) + + assert len(requests) == 1 + assert str(requests[0].url) == "https://example.supabase.co/rest/v1/events" + body = json.loads(requests[0].content) + assert body["event_type"] == "../escaped" + assert body["payload"]["event_type"] == "../escaped" + + +def test_telemetry_rejects_row_level_payloads(): + with pytest.raises(TypeError, match="row-level pandas data"): + normalize_telemetry_event( + { + "event_type": "bad", + "run_id": "run-1", + "rows": pd.DataFrame({"person_id": [1, 2]}), + } + ) + + with pytest.raises(TypeError, match="sequence data"): + normalize_telemetry_event( + { + "event_type": "bad", + "run_id": "run-1", + "rows": [{"person_id": 1}, {"person_id": 2}], + } + ) + + with pytest.raises(TypeError, match="sequence data"): + normalize_telemetry_event( + { + "event_type": "bad", + "run_id": "run-1", + "rows": [RowPayload(person_id=1, income=2.0)], + } + ) + + with pytest.raises(TypeError, match="sequence data"): + normalize_telemetry_event( + { + "event_type": "bad", + "run_id": "run-1", + "rows": [1, 2, 3], + } + ) + + with pytest.raises(TypeError, match="dataclass record data"): + 
normalize_telemetry_event( + { + "event_type": "bad", + "run_id": "run-1", + "row": RowPayload(person_id=1, income=2.0), + } + )