diff --git a/CHANGELOG.md b/CHANGELOG.md index aade667..1ae0459 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ Format based on [Keep a Changelog](https://keepachangelog.com/). ### Added +- **LATTE-inspired text-based backward slicing (Phase 2C+.1)** (`src/aiedge/code_slicing.py`, `src/aiedge/taint_propagation.py`, `tests/test_code_slicing.py`, `docs/code_slicing_contract.md`). First-cut implementation of the LATTE (Liu et al., TOSEM 2025) prompt-slicing idea: when `AIEDGE_LATTE_SLICING=1` is set, `_build_taint_prompt()` replaces the full function body with a sink-rooted backward slice. The slice walks bottom-up from the sink call, keeping earlier lines whose identifiers overlap the tracked variables-of-interest (minus a conservative noise set of C keywords / literals / common macros). The slice is a strict subset of the original body with source order preserved; the sink line and the defining lines of its arguments are always retained. Public API: `find_sink_line`, `extract_backward_slice`, `extract_slice_around_sink`, `maybe_slice`, `slice_compression_ratio`, `latte_slicing_enabled`. Default-off keeps existing LLM prompts byte-identical. _(32 new tests in `tests/test_code_slicing.py`.)_ +- **LARA-style URI / CGI / config-key source identification (Phase 2C+.2)** (`enhanced_source.py`, `tests/test_uri_source_extraction.py`). `EnhancedSourceStage` now widens source identification beyond C-level input APIs by recognising attacker-influenced strings, taking inspiration from the LARA paper (USENIX Sec 2024). Three new pattern sets totalling 50 entries cover URI prefixes (`/cgi-bin/`, `/api/`, `/upnp/`, `/admin/`, `/goform/`, ...), CGI environment variables (`QUERY_STRING`, `REQUEST_METHOD`, `HTTP_*`, ...), and NVRAM / sysconf config keys (`http_passwd`, `wpa_psk`, `cloud_token`, `firmware_url`, ...). 
New helper `_extract_uri_key_sources(bin_path, symbols, ascii_strings=None)` produces `(pattern, kind)` tuples that are wrapped per-binary into source dicts with `confidence=0.40` (SYMBOL_COOCCURRENCE cap, since string presence alone does not prove reachability) and `method="lara_pattern"`. Symbol-based URI matching is intentionally skipped to avoid noise; the optional `ascii_strings` parameter is the path for string-literal evidence (to be wired through inventory data in a follow-up). _(13 new tests in `tests/test_uri_source_extraction.py`.)_ +- **Sink coverage expansion (Phase 2C+.3)** (`taint_propagation.py`, `tests/test_taint_propagation.py`). `_SINK_SYMBOLS` grows from 29 to 51 symbols, covering the CWE classes that the firmware corpus actually exercises: CWE-78 cmd injection (now incl. `wordexp`, `posix_spawn`, `posix_spawnp`), CWE-22 path traversal (`fopen`, `open`, `openat`, `freopen`, `chdir`), CWE-426 search path (`dlsym`, `dlmopen`), CWE-732 perms (`chmod`/`fchmod`/`chown`/`fchown`/`lchown`), CWE-377 insecure tmp (`mktemp`, `tmpnam`, `tempnam`, `tmpfile`), CWE-250/269 privilege (`chroot`, `setuid`, `seteuid`, `setgid`, `setegid`), and CWE-454 env injection (`putenv`, `setenv`, `unsetenv`). `_FORMAT_STRING_SINKS` grows from 6 to 15 with size-bounded (`vsnprintf`), file-descriptor (`dprintf`/`vdprintf`), and wide-char (`swprintf`, `vswprintf`, `wprintf`, `vwprintf`, `fwprintf`, `vfwprintf`) variants. `_is_format_string_variable()` is strengthened to flag struct field access, array subscripts, function-call results, C-style casts, parenthesised ternaries, and pointer dereferences as variable first-arguments, not just bare identifiers. _(20 new tests in `tests/test_taint_propagation.py`.)_ +- **Finding diversity gate (Phase 2C+.5)** (`quality_policy.py`, `release_gate.sh`, `tests/test_finding_diversity_gate.py`, `docs/finding_diversity_gate.md`). 
Detects degenerate pair-eval coverage where every pair-side row maps to the same `finding_id` — the structural failure surfaced by the 2026-04-19 reviewer eval lane analysis (local-7 baseline `finding_diversity_index = 1.0`, all 14 rows on `aiedge.findings.web.exec_sink_overlap`). New helpers `compute_pair_eval_diversity_index()`, `load_pair_eval_finding_ids()`, `evaluate_pair_eval_diversity_gate()` produce a `QUALITY_GATE_DIVERSITY_MISS` violation when `max_share(finding_id) >= AIEDGE_PAIR_DIVERSITY_MAX` (default 0.5). `release_gate.sh` wires this in as the opt-in `PAIR_EVAL_DIVERSITY` sub-gate via `--pair-eval-findings`. _(12 new tests in `tests/test_finding_diversity_gate.py`.)_ +- **Pair-eval timeout diagnostic** (`scripts/run_pair_eval.py`). When a pair-side run hits the wall-clock timeout, `_dump_timeout_diagnostic()` writes `/timeout_diagnostic.json` capturing the last 200 stderr / 50 stdout lines, a best-effort run_dir guess, and the most recent stage's name/status. Closes the visibility gap that left the dedicated reviewer rerun lanes (`pair-eval-dedicated-local7-claude-6h`, `codex-6h`) stuck at `run_index rows = 0` without actionable signal. - **FDA Section 524B compatibility mapping (Phase 3'.1 step B-2)** (`docs/compliance_mapping/fda_section_524b.md`). Maps SCOUT outputs to the four §524B(b) statutory obligations (postmarket vulnerability monitoring plan, secure design/develop/maintain processes, postmarket updates/patches, SBOM) and to the September 2023 FDA premarket cybersecurity guidance content elements (security objectives, threat modelling, security risk management, cybersecurity testing, architecture views, SBOM, vulnerability management, labelling, postmarket plan). Coverage is documented per element with explicit "out of scope" callouts for sponsor-side QMS deliverables. Disclaimer reuses the directory-wide "compatible with" wording rule. - **ISO/SAE 21434 compatibility mapping (Phase 3'.1 step B-3)** (`docs/compliance_mapping/iso_21434.md`). 
Maps SCOUT outputs to ISO/SAE 21434:2021 work products across clauses 8 (continual cybersecurity activities), 9 (concept), 10 (product development), 11 (cybersecurity validation), 13 (operations and maintenance), and 15 (TARA methods). Identifies which work products are tool-friendly (WP-08-01..04, WP-10-04, WP-10-05, WP-13-02) versus manufacturer-side narratives (WP-09-02, WP-10-01, WP-10-02, etc.). - **UN R155 compatibility mapping (Phase 3'.1 step B-3)** (`docs/compliance_mapping/un_r155.md`). Maps SCOUT outputs to UN R155 §7.2 (CSMS) and §7.3 (vehicle-type approval) requirements, plus per-threat guidance for the 15 most-relevant Annex 5 threat categories (manipulation, replay, malware insertion, network-design vulnerabilities, etc.). Co-published with the ISO/SAE 21434 mapping per the standard / regulation pairing. diff --git a/docs/code_slicing_contract.md b/docs/code_slicing_contract.md new file mode 100644 index 0000000..23cedab --- /dev/null +++ b/docs/code_slicing_contract.md @@ -0,0 +1,111 @@ +# LATTE Code Slicing Contract + +> Phase 2C+.1 (Pivot 2026-04-19) — text-based backward slicing that the taint +> propagation stage uses to compress LLM prompts when +> `AIEDGE_LATTE_SLICING=1` is set. + +## Why this exists + +LATTE (Liu et al., "LATTE: LLM-Powered Static Binary Taint Analysis", +TOSEM 2025) reported that feeding the LLM the **sink-rooted backward +slice** instead of the full decompiled function body improved new-bug +discovery and reduced token usage. SCOUT's first-cut implementation +takes the same idea but stays conservative: it operates on plain text, +does not require a Ghidra-grade SSA backend, and is opt-in so the +existing prompt behaviour stays byte-identical when the env var is +unset. + +The slicing is **over-approximate**: it keeps every earlier line whose +identifier set overlaps the already-tracked variables-of-interest. 
That +means the slice is a subset of the original body (ordering +preserved) but it may retain irrelevant lines that happen to mention a +tainted variable name in passing. In exchange, within the `max_lines` cap it +keeps every earlier line that names a tracked variable, so the LLM rarely +has to reason about a variable whose definition disappeared. A dependency +that never mentions a tracked name (for example a write through an alias) +can still be missed, which is one reason the feature is opt-in. + +## Public API + +Source: `src/aiedge/code_slicing.py`. + +| Function | Purpose | +|---|---| +| `latte_slicing_enabled()` | Returns `True` when `AIEDGE_LATTE_SLICING` is set to `1`/`true`/`yes`/`on` (case-insensitive). | +| `find_sink_line(body, sink_sym)` | 0-based line index of the first `sink_sym(` call, or `None`. | +| `extract_backward_slice(body, sink_line_idx, max_lines=30)` | Backward-walks from `sink_line_idx`, keeps lines whose identifiers overlap the tracked set. Returns a string of the retained lines in source order. | +| `extract_slice_around_sink(body, sink_sym, max_lines=30)` | Convenience: `find_sink_line` then `extract_backward_slice`. Returns `None` when the sink is absent. | +| `maybe_slice(body, sink_sym, max_lines=30)` | Recommended entry point for call sites: when the env gate is off it returns the body unchanged; when on it returns the slice (falling back to the full body if the sink is not found). Never returns `None`. | +| `slice_compression_ratio(original, sliced)` | Telemetry helper: ratio of kept lines to original lines. | + +## Env gate + +``` +AIEDGE_LATTE_SLICING=1 # enable slicing (any of 1/true/yes/on) +``` + +Default (unset) means `maybe_slice` returns the input body verbatim, so +leaving the env var unset keeps every LLM prompt byte-identical. + +## Algorithm (first-cut) + +``` +1. Locate the sink line (first occurrence of `(`). +2. Initial variables-of-interest = identifiers on the sink line + (minus the noise set: C keywords, literals, common macros). +3. For each earlier line (bottom-up): + a. 
If its identifier set intersects the variables-of-interest, + include it and union its identifiers into the interest set. + b. If the line has no usable identifier (blank, comment-only), + include it so the LLM keeps structural context. + c. Stop at `max_lines` or the function start. +4. Emit retained lines in source order. +``` + +Noise identifiers (`_NOISE_IDENTIFIERS`) are kept minimal on purpose: we +filter only what is guaranteed not to carry data (`if`, `int`, `NULL`, +`true`, ...). Vendor-specific tokens are *not* filtered because they +often *are* the relevant variables in router firmware decompilation. + +## Over-approximation behaviour + +Because the algorithm tracks identifiers and not their scopes, a slice +may include lines that merely reference a same-named variable elsewhere +in the function. This is acceptable for prompt compression but analysts +who need an exact data-flow trace should still consult the Ghidra +P-code SSA path (`pcode_taint.py`). + +## Call site + +The only caller today is `_build_taint_prompt()` in +`src/aiedge/taint_propagation.py`: + +```python +body_raw = fb.get("body", "") +body_sliced = maybe_slice(body_raw, sink_symbol) +body = _truncate_text(body_sliced, max_chars=2000) +``` + +When `AIEDGE_LATTE_SLICING` is unset the call returns `body_raw` +unchanged and the subsequent `_truncate_text` path is byte-identical to +pre-2C+.1 behaviour. + +## Phase 2D entry interaction + +Phase 2D.1 (reasoning_trail + MCP loop validation) depends on the LLM +actually producing useful verdicts across diverse findings. Slicing is +the main lever we have today to let the LLM see *more* findings within +the same token budget — so even if Phase 2D.1 does not require slicing, +leaving it disabled in production runs means the analyst cycles through +a smaller effective corpus. Operators planning a Phase 2D.1 walkthrough +should enable `AIEDGE_LATTE_SLICING=1` for the run. 
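As a rough illustration of the first-cut algorithm in this contract, the sketch below re-implements the backward walk on a toy C body. It is not the code in `src/aiedge/code_slicing.py`: `toy_backward_slice`, `idents`, the tiny `NOISE` set and the sample `BODY` are all hypothetical names chosen for the example, and the real noise set is larger.

```python
import re

IDENT = re.compile(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b")
# Assumption: a tiny stand-in for the real noise set.
NOISE = {"if", "int", "char", "return", "NULL", "sizeof"}


def idents(line: str) -> set[str]:
    """Identifiers on a line, minus the noise tokens."""
    return {tok for tok in IDENT.findall(line) if tok not in NOISE}


def toy_backward_slice(body: str, sink_sym: str, max_lines: int = 30) -> str:
    lines = body.splitlines()
    call = re.compile(r"\b" + re.escape(sink_sym) + r"\s*\(")
    sink_idx = next((i for i, ln in enumerate(lines) if call.search(ln)), None)
    if sink_idx is None:
        return body  # sink not found: fall back to the full body
    interest = idents(lines[sink_idx])
    kept = [sink_idx]
    # Walk bottom-up from the line just above the sink.
    for i in range(sink_idx - 1, -1, -1):
        if len(kept) >= max_lines:
            break
        found = idents(lines[i])
        # Keep lines with no usable identifier, or lines that overlap the
        # tracked set; in the second case their identifiers become tracked too.
        if not found or found & interest:
            kept.append(i)
            interest |= found
    return "\n".join(lines[i] for i in sorted(kept))


BODY = "\n".join(
    [
        "char cmd[64];",
        "int counter = 0;",
        'char *host = getenv("QUERY_STRING");',
        "counter = counter + 1;",
        'snprintf(cmd, sizeof(cmd), "ping %s", host);',
        "system(cmd);",
    ]
)

sliced = toy_backward_slice(BODY, "system")
print(sliced)
```

In this toy body the two `counter` lines never mention a tracked identifier, so they drop out while the `cmd` and `host` definitions stay. Note that words inside string literals (`ping`, `s`) are also picked up by the identifier regex, which is one source of the over-approximation described above.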
+ +## Related artifacts + +- `src/aiedge/code_slicing.py` — implementation +- `src/aiedge/taint_propagation.py` — call site in `_build_taint_prompt` +- `tests/test_code_slicing.py` — unit tests (32 cases) that pin: + - sink-line location and word-boundary behaviour + - slice invariants (subset, source order, sink kept, defining lines + pulled in) + - `max_lines` cap and degenerate inputs + - env-gate parsing and byte-identical default-off + - compression-ratio telemetry diff --git a/docs/finding_diversity_gate.md b/docs/finding_diversity_gate.md new file mode 100644 index 0000000..1814d3a --- /dev/null +++ b/docs/finding_diversity_gate.md @@ -0,0 +1,137 @@ +# Finding Diversity Gate + +> Phase 2C+.5 (Pivot 2026-04-19) — pair-eval lane gate that detects degenerate +> evidence-tier coverage by measuring finding-id share concentration. + +## Why this gate exists + +The 2026-04-19 reviewer eval lane analysis surfaced a structural failure that +neither precision/recall nor confidence caps caught: **every pair-side row in the +local-7 lane mapped to the same `finding_id`** (`aiedge.findings.web.exec_sink_overlap`, +`evidence_tier=symbol_only`). The pair-level recall and FP rate looked plausible +(0.142857 each) yet the underlying tier-ROC was *degenerate* — there was nothing +to discriminate between vulnerable and patched runs because the detection layer +collapsed onto a single finding. + +The diversity gate quantifies this collapse and blocks releases that ship it. + +## Definition + +``` +finding_diversity_index = max_count(finding_id) / total_rows +``` + +- `1.0` — degenerate (every row mapped to a single `finding_id`) +- `1/N` — fully diverse (every row a distinct `finding_id`) +- `0.0` — empty input (callers decide whether to treat as violation) + +The index is a **maximum-share** metric, not entropy. It is robust to long-tail +distributions and surfaces the dominant finding bucket directly. 
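The max-share definition above can be sketched in a few lines. This is a minimal illustration, not the body of `compute_pair_eval_diversity_index`: the function name `diversity_index` is hypothetical, and dropping empty ids inside the function is an assumption made to keep the example self-contained.

```python
from collections import Counter


def diversity_index(finding_ids: list[str]) -> float:
    """Max-share index: rows on the most common finding_id over total rows."""
    # Assumption: empty ids are ignored here rather than by a separate loader.
    ids = [fid for fid in finding_ids if fid]
    if not ids:
        return 0.0  # empty input; the caller decides whether this is a violation
    top_count = Counter(ids).most_common(1)[0][1]
    return top_count / len(ids)


# The local-7 baseline shape: 14 rows, all on one finding_id.
degenerate = ["aiedge.findings.web.exec_sink_overlap"] * 14
print(diversity_index(degenerate))            # 1.0
print(diversity_index(["a", "b", "c", "d"]))  # 0.25, i.e. 1/N for N = 4
```

Because the index only looks at the largest bucket, adding many rare findings barely moves it while the dominant finding keeps most of the rows.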
+ +## Threshold + +| Env variable | Default | Direction | +|---|---|---| +| `AIEDGE_PAIR_DIVERSITY_MAX` | `0.5` | gate fails when index `>=` threshold | + +The default `0.5` was chosen as a first-cut: any single `finding_id` accounting +for 50%+ of pair rows is treated as a degenerate signal. Once the corpus grows +past 10 pairs the threshold should be re-evaluated against representative runs +(see Phase 2C+.4 vendor-extraction expansion). + +## Inputs + +The gate consumes the pair-eval findings CSV produced by +`scripts/run_pair_eval.py`. Schema (relevant columns): + +| Column | Use | +|---|---| +| `finding_id` | counted into the share distribution | +| `ground_truth` | optional filter via `load_pair_eval_finding_ids(only_ground_truth=...)` | + +Empty `finding_id` rows are skipped silently. Missing CSV raises +`QUALITY_GATE_INVALID_PAIR_EVAL`. + +## Output schema + +```json +{ + "schema_version": 1, + "verdict": "pass" | "fail", + "passed": true | false, + "findings_source": "", + "policy": { + "finding_diversity_max": 0.5, + "finding_diversity_max_env": "AIEDGE_PAIR_DIVERSITY_MAX" + }, + "measured": { + "finding_diversity_index": 0.0..1.0, + "sample_size": + }, + "errors": [ + { + "error_token": "QUALITY_GATE_DIVERSITY_MISS", + "metric": "finding_diversity_index", + "source_field": "pair_eval_findings.finding_id", + "actual": 1.0, + "threshold": 0.5, + "operator": "<", + "sample_size": 14, + "message": "..." + } + ] +} +``` + +## Wiring into `release_gate.sh` + +The unified release gate wires this in as the `PAIR_EVAL_DIVERSITY` sub-gate. It +is **opt-in** via `--pair-eval-findings`: + +```bash +scripts/release_gate.sh \ + --run-dir aiedge-runs/ \ + --pair-eval-findings benchmark-results/pair-eval/pair_eval_findings.csv +``` + +When the flag is omitted the gate is skipped with an `INFO` line so existing +release flows continue working unchanged. 
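The threshold rule and the shape of the result can be sketched as follows. This is a hedged illustration, not `evaluate_pair_eval_diversity_gate` itself: `sketch_diversity_gate` is a hypothetical name, only a subset of the documented fields is emitted, and falling back to `0.5` when the variable is unset is an assumption based on the table above.

```python
import os

ENV_NAME = "AIEDGE_PAIR_DIVERSITY_MAX"


def sketch_diversity_gate(index, sample_size, threshold=None):
    """Hypothetical stand-in for the gate: fail when index >= threshold."""
    if threshold is None:
        # Assumption: an unset variable falls back to the documented default.
        threshold = float(os.environ.get(ENV_NAME, "0.5"))
    passed = index < threshold
    errors = []
    if not passed:
        errors.append(
            {
                "error_token": "QUALITY_GATE_DIVERSITY_MISS",
                "metric": "finding_diversity_index",
                "actual": index,
                "threshold": threshold,
                "operator": "<",  # the condition a passing run must satisfy
                "sample_size": sample_size,
            }
        )
    return {
        "verdict": "pass" if passed else "fail",
        "passed": passed,
        "measured": {
            "finding_diversity_index": index,
            "sample_size": sample_size,
        },
        "errors": errors,
    }


baseline = sketch_diversity_gate(1.0, 14, threshold=0.5)
print(baseline["verdict"])  # fail
```

The comparison is inclusive on the failing side: an index of exactly `0.5` (for example two findings splitting the rows evenly) still fails under the default threshold.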
+ +## Current baseline (2026-04-19) + +Running the gate against the trusted summary-reuse local-7 lane: + +``` +sample_size = 14 (7 pairs × 2 sides) +finding_diversity_index = 1.0 (degenerate — single finding for all rows) +verdict = fail +``` + +This matches the Pivot 2026-04-19 [diagnosis](../docs/status.md): Phase 2D entry +is gated until detection coverage produces at least two distinct findings across +the pair lane. The gate makes that requirement enforceable instead of advisory. + +## Phase 2D entry exit-gate hook + +The diversity gate is one of the five Phase 2D entry exit-gate thresholds +defined in [`docs/status.md`](status.md): + +| Gate | Threshold | Tooling | +|---|---|---| +| Detection recall | `≥ 0.40` | `pair_eval_summary.json` | +| Tier discriminability | `≥ 2 nonzero TP tiers` | `pair_eval_findings.csv` | +| **Finding diversity** | **`< 0.5`** | **this gate** | +| Dedicated rerun | `≥ 1 driver success` | `pair-eval-dedicated-*` lanes | +| Corpus size | `≥ 10 pairs` | `benchmarks/pair-eval/pairs.json` | + +The other four are tracked in their own places; this gate only owns the +diversity threshold. 
+ +## Related artifacts + +- `src/aiedge/quality_policy.py` — `compute_pair_eval_diversity_index`, + `load_pair_eval_finding_ids`, `evaluate_pair_eval_diversity_gate` +- `scripts/run_pair_eval.py` — adds `timeout_diagnostic.json` for dedicated + rerun timeout investigations (companion 2C+.5 work) +- `scripts/release_gate.sh` — `PAIR_EVAL_DIVERSITY` sub-gate +- `tests/test_finding_diversity_gate.py` — unit + baseline tests diff --git a/scripts/release_gate.sh b/scripts/release_gate.sh index 4a922fd..4f95d2a 100755 --- a/scripts/release_gate.sh +++ b/scripts/release_gate.sh @@ -10,12 +10,13 @@ CORPUS_MANIFEST="benchmarks/corpus/manifest.json" METRICS_OUT="" QUALITY_OUT="" LLM_FIXTURE="" +PAIR_EVAL_FINDINGS="" FAILED=0 usage() { cat <<'EOF' -Usage: scripts/release_gate.sh --run-dir [--manifest ] [--metrics-out ] [--quality-out ] [--llm-fixture ] +Usage: scripts/release_gate.sh --run-dir [--manifest ] [--metrics-out ] [--quality-out ] [--llm-fixture ] [--pair-eval-findings ] Unified release governance gate (single entrypoint). 
@@ -25,6 +26,7 @@ Sub-gates: - QUALITY_METRICS: aiedge quality-metrics - QUALITY_POLICY: aiedge release-quality-gate - EXPLOIT_TIER_POLICY: schema tier checks plus exploit_policy artifact checks when present + - PAIR_EVAL_DIVERSITY: finding-diversity gate over pair_eval_findings.csv (skipped when --pair-eval-findings absent) - TAMPER_SUITE: pytest tests/test_tamper_suite.py EOF } @@ -97,6 +99,10 @@ while [[ $# -gt 0 ]]; do LLM_FIXTURE="$2" shift 2 ;; + --pair-eval-findings) + PAIR_EVAL_FINDINGS="$2" + shift 2 + ;; -h|--help) usage exit 0 @@ -203,6 +209,66 @@ else fi rm -f "$EXPLOIT_CHECK_OUTPUT" +if [[ -n "$PAIR_EVAL_FINDINGS" ]]; then + PAIR_EVAL_OUTPUT="$(mktemp)" + set +e + PYTHONPATH="$PYTHONPATH" python3 - <<'PY' "$PAIR_EVAL_FINDINGS" "$RUN_DIR" >"$PAIR_EVAL_OUTPUT" 2>&1 +import json +import sys +from pathlib import Path + +from aiedge.quality_policy import ( + QualityGateError, + evaluate_pair_eval_diversity_gate, + load_pair_eval_finding_ids, +) + +csv_path = Path(sys.argv[1]).resolve() +run_dir = Path(sys.argv[2]).resolve() +out_path = run_dir / "pair_eval_diversity_gate.json" +try: + finding_ids = load_pair_eval_finding_ids(csv_path) +except QualityGateError as exc: + print(f"{exc.token}: {exc}") + raise SystemExit(1) from exc + +result = evaluate_pair_eval_diversity_gate( + finding_ids=finding_ids, + findings_source=str(csv_path), +) +out_path.write_text( + json.dumps(result, indent=2, sort_keys=True) + "\n", encoding="utf-8" +) +if not result["passed"]: + for err in result["errors"]: + print(err.get("message") or err.get("error_token")) + raise SystemExit(1) +measured = result["measured"] +print( + "diversity_index=" + + str(measured["finding_diversity_index"]) + + " sample_size=" + + str(measured["sample_size"]) +) +PY + PAIR_EVAL_RC=$? 
+ set -e + if [[ "$PAIR_EVAL_RC" -ne 0 ]]; then + gate_fail "PAIR_EVAL_DIVERSITY" "diversity gate violated" + while IFS= read -r line; do + [[ -n "$line" ]] && echo "[GATE][LOG][PAIR_EVAL_DIVERSITY] $line" + done <"$PAIR_EVAL_OUTPUT" + else + gate_pass "PAIR_EVAL_DIVERSITY" "diversity gate passed" + while IFS= read -r line; do + [[ -n "$line" ]] && gate_info "PAIR_EVAL_DIVERSITY" "$line" + done <"$PAIR_EVAL_OUTPUT" + fi + rm -f "$PAIR_EVAL_OUTPUT" +else + gate_info "PAIR_EVAL_DIVERSITY" "skipped (no --pair-eval-findings)" +fi + if [[ "${AIEDGE_SKIP_TAMPER_TESTS:-0}" == "1" ]]; then gate_info "TAMPER_SUITE" "skipped by AIEDGE_SKIP_TAMPER_TESTS=1" else diff --git a/scripts/run_pair_eval.py b/scripts/run_pair_eval.py index 5c1b698..c847b2d 100755 --- a/scripts/run_pair_eval.py +++ b/scripts/run_pair_eval.py @@ -14,7 +14,109 @@ from aiedge.pair_eval import PairSpec, load_pairs_manifest -def _run_one(pair: PairSpec, side: str, firmware_path: str, results_root: Path, time_budget_s: int, driver: str) -> dict[str, Any]: +def _write_json(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(payload, indent=2, ensure_ascii=False) + "\n", encoding="utf-8" + ) + + +def _status_rank(status: str) -> int: + return {"success": 4, "partial": 3, "fatal": 2, "error": 1}.get(status or "", 0) + + +def _wall_timeout(time_budget_s: int) -> int: + return max(300, int(time_budget_s) + 900) + + +def _tail_lines(path: Path, n: int) -> list[str]: + try: + text = path.read_text(encoding="utf-8", errors="replace") + except OSError: + return [] + lines = text.splitlines() + return lines[-n:] if len(lines) > n else lines + + +def _guess_run_dir_from_stdout(stdout_tail: list[str]) -> str: + for line in reversed(stdout_tail): + candidate = line.strip() + if "aiedge-runs/" not in candidate: + continue + for tok in reversed(candidate.split()): + if "aiedge-runs/" in tok: + return tok.strip().rstrip(",.;:") + return "" + + +def 
_last_stage_info(run_dir_guess: str) -> tuple[str, str]: + if not run_dir_guess: + return "", "" + stages_dir = Path(run_dir_guess) / "stages" + if not stages_dir.is_dir(): + return "", "" + try: + stage_dirs = sorted( + (p for p in stages_dir.iterdir() if p.is_dir()), + key=lambda p: p.stat().st_mtime, + ) + except OSError: + return "", "" + if not stage_dirs: + return "", "" + last_dir = stage_dirs[-1] + last_name = last_dir.name + stage_json = last_dir / "stage.json" + if not stage_json.is_file(): + return last_name, "" + try: + payload_any = json.loads(stage_json.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return last_name, "" + if not isinstance(payload_any, dict): + return last_name, "" + status_val = payload_any.get("status") + if isinstance(status_val, str): + return last_name, status_val + return last_name, "" + + +def _dump_timeout_diagnostic( + *, + side_root: Path, + pair: PairSpec, + side: str, + stdout_path: Path, + stderr_path: Path, + wall_timeout_s: int, +) -> None: + stdout_tail = _tail_lines(stdout_path, 50) + stderr_tail = _tail_lines(stderr_path, 200) + run_dir_guess = _guess_run_dir_from_stdout(stdout_tail) + last_stage, last_stage_status = _last_stage_info(run_dir_guess) + diagnostic: dict[str, Any] = { + "pair_id": pair.pair_id, + "side": side, + "wall_timeout_s": wall_timeout_s, + "stdout_tail_count": len(stdout_tail), + "stderr_tail_count": len(stderr_tail), + "stdout_tail": stdout_tail, + "stderr_tail": stderr_tail, + "run_dir_guess": run_dir_guess, + "last_stage": last_stage, + "last_stage_status": last_stage_status, + } + _write_json(side_root / "timeout_diagnostic.json", diagnostic) + + +def _run_one( + pair: PairSpec, + side: str, + firmware_path: str, + results_root: Path, + time_budget_s: int, + driver: str, +) -> dict[str, Any]: side_root = results_root / "runs" / pair.pair_id / side side_root.mkdir(parents=True, exist_ok=True) env = os.environ.copy() @@ -29,10 +131,63 @@ def _run_one(pair: PairSpec, 
side: str, firmware_path: str, results_root: Path, "--time-budget-s", str(time_budget_s), ] + _write_json( + side_root / "started.json", + { + "pair_id": pair.pair_id, + "side": side, + "driver": driver, + "firmware_path": firmware_path, + "cmd": cmd, + "started_at": time.time(), + "wall_timeout_s": _wall_timeout(time_budget_s), + }, + ) + stdout_path = side_root / "stdout.txt" + stderr_path = side_root / "stderr.txt" start = time.time() - proc = subprocess.run(cmd, cwd=Path.cwd(), env=env, text=True, capture_output=True) + status = "fatal" + run_dir = "" + returncode = 20 + timed_out = False + try: + with stdout_path.open("wb") as fh_out, stderr_path.open("wb") as fh_err: + proc = subprocess.run( + cmd, + cwd=Path.cwd(), + env=env, + stdout=fh_out, + stderr=fh_err, + timeout=_wall_timeout(time_budget_s), + check=False, + ) + returncode = int(proc.returncode) + status = ( + "success" + if returncode == 0 + else ("partial" if returncode == 10 else "fatal") + ) + except subprocess.TimeoutExpired: + timed_out = True + returncode = 124 + status = "fatal" + try: + _dump_timeout_diagnostic( + side_root=side_root, + pair=pair, + side=side, + stdout_path=stdout_path, + stderr_path=stderr_path, + wall_timeout_s=_wall_timeout(time_budget_s), + ) + except Exception: + pass duration_s = round(time.time() - start, 3) - stdout_lines = [line.strip() for line in proc.stdout.splitlines() if line.strip()] + try: + stdout_text = stdout_path.read_text(encoding="utf-8", errors="replace") + except Exception: + stdout_text = "" + stdout_lines = [line.strip() for line in stdout_text.splitlines() if line.strip()] run_dir = stdout_lines[-1] if stdout_lines else "" result = { "pair_id": pair.pair_id, @@ -41,13 +196,15 @@ def _run_one(pair: PairSpec, side: str, firmware_path: str, results_root: Path, "cve_id": pair.cve_id, "side": side, "firmware_path": firmware_path, - "returncode": proc.returncode, + "returncode": returncode, "duration_s": duration_s, "run_dir": run_dir, - "stdout": 
proc.stdout, - "stderr": proc.stderr, - "status": "success" if proc.returncode == 0 else ("partial" if proc.returncode == 10 else "fatal"), + "status": status, + "timed_out": timed_out, + "driver": driver, + "wall_timeout_s": _wall_timeout(time_budget_s), } + _write_json(side_root / "last_run.json", result) if run_dir: link = side_root / "latest" if link.exists() or link.is_symlink(): @@ -56,15 +213,12 @@ def _run_one(pair: PairSpec, side: str, firmware_path: str, results_root: Path, link.symlink_to(Path(run_dir).resolve()) except Exception: pass - (side_root / "last_run.json").write_text(json.dumps(result, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") return result -def _status_rank(status: str) -> int: - return {"success": 4, "partial": 3, "fatal": 2, "error": 1}.get(status or "", 0) - - -def _build_rows_from_summaries(pairs: list[PairSpec], summary_paths: list[Path], results_root: Path) -> list[dict[str, Any]]: +def _build_rows_from_summaries( + pairs: list[PairSpec], summary_paths: list[Path], results_root: Path +) -> list[dict[str, Any]]: candidates: dict[tuple[str, str], tuple[tuple[int, int], dict[str, Any]]] = {} for idx, summary_path in enumerate(summary_paths, start=1): with summary_path.open(encoding="utf-8") as handle: @@ -77,7 +231,10 @@ def _build_rows_from_summaries(pairs: list[PairSpec], summary_paths: list[Path], out: list[dict[str, Any]] = [] for pair in pairs: - for side, side_spec in (("vulnerable", pair.vulnerable), ("patched", pair.patched)): + for side, side_spec in ( + ("vulnerable", pair.vulnerable), + ("patched", pair.patched), + ): firmware_name = Path(side_spec.firmware_path).name row = candidates.get((pair.vendor, firmware_name), (None, {}))[1] run_dir = row.get("run_dir") or "" @@ -101,24 +258,38 @@ def _build_rows_from_summaries(pairs: list[PairSpec], summary_paths: list[Path], "returncode": int(row.get("exit_code") or 0) if row else 0, "duration_s": float(row.get("duration_s") or 0) if row else 0, "run_dir": run_dir, - 
"stdout": "", - "stderr": "", "status": row.get("status") or "missing", "source_summary": str(summary_paths[0]) if row else "", + "driver": "summary-reuse", + "timed_out": False, } - (side_root / "last_run.json").write_text(json.dumps(record, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") + _write_json(side_root / "last_run.json", record) out.append(record) return out +def _write_run_index( + results_root: Path, *, driver: str, time_budget_s: int, rows: list[dict[str, Any]] +) -> None: + ordered = sorted(rows, key=lambda r: (r["pair_id"], r["side"])) + _write_json( + results_root / "run_index.json", + { + "driver": driver, + "time_budget_s": time_budget_s, + "rows": ordered, + }, + ) + + def main() -> int: - parser = argparse.ArgumentParser(description="Run the M0 pair-eval corpus with Codex-full pipeline.") + parser = argparse.ArgumentParser(description="Run the pair-eval corpus.") parser.add_argument("--pairs", default="benchmarks/pair-eval/pairs.json") parser.add_argument("--results-dir", default="benchmark-results/pair-eval") parser.add_argument("--driver", default="codex") parser.add_argument("--parallel", type=int, default=2) parser.add_argument("--time-budget-s", type=int, default=3600) - parser.add_argument("--source-summary", nargs='*', default=[]) + parser.add_argument("--source-summary", nargs="*", default=[]) args = parser.parse_args() results_root = Path(args.results_dir).resolve() @@ -126,25 +297,72 @@ def main() -> int: pairs = load_pairs_manifest(Path(args.pairs).resolve()) if args.source_summary: - rows = _build_rows_from_summaries(pairs, [Path(p).resolve() for p in args.source_summary], results_root) - else: - tasks: list[tuple[PairSpec, str, str]] = [] - for pair in pairs: - tasks.append((pair, "vulnerable", pair.vulnerable.firmware_path)) - tasks.append((pair, "patched", pair.patched.firmware_path)) - - rows = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, args.parallel)) as pool: - futs = [pool.submit(_run_one, 
pair, side, firmware, results_root, args.time_budget_s, args.driver) for pair, side, firmware in tasks] - for fut in concurrent.futures.as_completed(futs): - row = fut.result() - rows.append(row) - print(json.dumps({k: row[k] for k in ['pair_id','side','status','returncode','run_dir']}, ensure_ascii=False), flush=True) - - rows.sort(key=lambda r: (r['pair_id'], r['side'])) - (results_root / 'run_index.json').write_text(json.dumps({'driver': ('summary-reuse' if args.source_summary else args.driver), 'time_budget_s': args.time_budget_s, 'rows': rows}, indent=2, ensure_ascii=False) + '\n', encoding='utf-8') + rows = _build_rows_from_summaries( + pairs, [Path(p).resolve() for p in args.source_summary], results_root + ) + _write_run_index( + results_root, + driver="summary-reuse", + time_budget_s=args.time_budget_s, + rows=rows, + ) + return 0 + + tasks: list[tuple[PairSpec, str, str]] = [] + for pair in pairs: + tasks.append((pair, "vulnerable", pair.vulnerable.firmware_path)) + tasks.append((pair, "patched", pair.patched.firmware_path)) + + rows: list[dict[str, Any]] = [] + _write_run_index( + results_root, driver=args.driver, time_budget_s=args.time_budget_s, rows=rows + ) + with concurrent.futures.ThreadPoolExecutor( + max_workers=max(1, args.parallel) + ) as pool: + futs = [ + pool.submit( + _run_one, + pair, + side, + firmware, + results_root, + args.time_budget_s, + args.driver, + ) + for pair, side, firmware in tasks + ] + for fut in concurrent.futures.as_completed(futs): + row = fut.result() + rows.append(row) + _write_run_index( + results_root, + driver=args.driver, + time_budget_s=args.time_budget_s, + rows=rows, + ) + print( + json.dumps( + { + k: row[k] + for k in [ + "pair_id", + "side", + "status", + "returncode", + "run_dir", + "timed_out", + ] + }, + ensure_ascii=False, + ), + flush=True, + ) + _write_run_index( + results_root, driver=args.driver, time_budget_s=args.time_budget_s, rows=rows + ) return 0 -if __name__ == '__main__': +if __name__ == 
"__main__": raise SystemExit(main()) diff --git a/src/aiedge/code_slicing.py b/src/aiedge/code_slicing.py new file mode 100644 index 0000000..2d7e2ca --- /dev/null +++ b/src/aiedge/code_slicing.py @@ -0,0 +1,237 @@ +"""LATTE-inspired text-based backward slicing for taint LLM prompts. + +The full LATTE technique (Liu et al., TOSEM 2025) builds a Code Slicing +Prompt Sequence on top of an actual program slice computed from +inter-procedural data-flow analysis. SCOUT's first-cut implementation is +deliberately simpler: + +* it operates on Ghidra-decompiled function bodies as plain text; +* it walks bottom-up from the line that contains the sink call; +* it keeps any earlier line that mentions an identifier already known to + influence the slice; +* it stops at ``max_lines`` or the function start. + +The resulting slice is a subset of the function body's lines, ordered as +in the source. Lines without identifiers (blank lines, lone braces) are kept so +the LLM still sees structural cues. The slicing is only a heuristic +*approximation* of true backward dataflow, but it is much cheaper +than rebuilding a Ghidra-grade SSA / use-def graph and it already buys +the two properties that LATTE relies on for prompt quality: + +1. **Token compression** -- LLM context is dominated by the sink path + instead of the entire function; +2. **Locality** -- variables defined in the same function are visible to + the LLM, so it can reason about taint provenance without losing the + declaration site. + +Future revisions can replace ``extract_backward_slice`` with a Ghidra +P-code SSA backend without changing the public API or the call sites in +``taint_propagation.py``. + +The slicing is **opt-in** at the call site via ``AIEDGE_LATTE_SLICING=1`` +because it can occasionally cut a load-bearing line +that the regex heuristic does not recognise as relevant. Default-off +keeps the existing prompt behaviour byte-identical. 
+""" + +from __future__ import annotations + +import os +import re + +# Identifier extraction. C identifiers are [a-zA-Z_][a-zA-Z0-9_]*. +_IDENT_PAT: re.Pattern[str] = re.compile(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b") + +# Reserved identifiers we do *not* want to inflate the variable-of-interest +# set. These are C keywords or extremely common standard-library tokens whose +# presence on a line should not, by itself, pull every previous line into the +# slice. The list is intentionally conservative; specialised vendor tokens are +# not filtered because they often *are* the relevant variables. +_NOISE_IDENTIFIERS: frozenset[str] = frozenset( + { + # C keywords / type qualifiers + "if", + "else", + "for", + "while", + "do", + "switch", + "case", + "default", + "break", + "continue", + "return", + "goto", + "sizeof", + "void", + "int", + "long", + "short", + "char", + "float", + "double", + "unsigned", + "signed", + "const", + "volatile", + "static", + "extern", + "inline", + "auto", + "register", + "struct", + "union", + "enum", + "typedef", + # Common literals / boolean tokens + "true", + "false", + "NULL", + "null", + "nullptr", + "TRUE", + "FALSE", + # Frequently encountered macros that are not data variables + "abs", + "min", + "max", + "MIN", + "MAX", + } +) + + +def _line_identifiers(line: str) -> set[str]: + """Return the set of C-style identifiers that appear in ``line``, + excluding the ``_NOISE_IDENTIFIERS`` set.""" + return {tok for tok in _IDENT_PAT.findall(line) if tok not in _NOISE_IDENTIFIERS} + + +def latte_slicing_enabled() -> bool: + """Return ``True`` when ``AIEDGE_LATTE_SLICING`` is set to a truthy value. + + Truthy = ``"1"``, ``"true"``, ``"yes"``, ``"on"`` (case-insensitive). + Anything else, including unset, returns ``False``. Centralising the + parse keeps call sites in ``taint_propagation`` short. 
+ """ + raw = os.environ.get("AIEDGE_LATTE_SLICING", "") + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def find_sink_line(function_body: str, sink_sym: str) -> int | None: + """Return the 0-based line index of the first call to ``sink_sym`` in + ``function_body``. Matches ``sink_sym(`` (optional whitespace) at a word + boundary so ``open(`` matches but ``fopen(`` does not when ``sink_sym`` + is ``"open"``. Returns ``None`` when no call is found. + """ + if not function_body or not sink_sym: + return None + pat = re.compile(r"\b" + re.escape(sink_sym) + r"\s*\(") + for idx, line in enumerate(function_body.splitlines()): + if pat.search(line): + return idx + return None + + +def extract_backward_slice( + function_body: str, + sink_line_idx: int, + *, + max_lines: int = 30, +) -> str: + """Return a backward slice ending at ``sink_line_idx``. + + Algorithm: start from the sink line, collect its non-noise identifiers + as the initial variable-of-interest set, then walk upward. For each + earlier line, if its identifier set intersects the variable-of-interest + set we include the line and union its identifiers into the interest + set (data dependency may flow further back). Iteration stops when we + accumulate ``max_lines`` lines or reach the function start. + + Lines are emitted in source order. When ``sink_line_idx`` is out of + range the function returns ``function_body`` unchanged so callers can + treat the slice as a *safe substitute* for the full body. + """ + if not function_body: + return function_body + lines = function_body.splitlines() + if sink_line_idx < 0 or sink_line_idx >= len(lines): + return function_body + if max_lines <= 0: + return function_body + + sink_line = lines[sink_line_idx] + vars_of_interest: set[str] = _line_identifiers(sink_line) + # If the sink line itself has no usable identifier (rare), keep at + # least the sink token so the slice is non-empty. 
+ if not vars_of_interest: + vars_of_interest = set(_IDENT_PAT.findall(sink_line)) + + included: list[int] = [sink_line_idx] + for i in range(sink_line_idx - 1, -1, -1): + if len(included) >= max_lines: + break + line = lines[i] + line_ids = _line_identifiers(line) + # Keep every line that has no non-noise identifiers (blank lines, lone + # braces, keyword-only statements) so the LLM still sees block structure. + if not line_ids: + included.append(i) + continue + if line_ids & vars_of_interest: + included.append(i) + vars_of_interest |= line_ids + + included.sort() + return "\n".join(lines[i] for i in included) + + +def extract_slice_around_sink( + function_body: str, + sink_sym: str, + *, + max_lines: int = 30, +) -> str | None: + """Convenience wrapper: locate ``sink_sym`` then backward-slice. + + Returns ``None`` when ``sink_sym`` is not called in ``function_body``, + so the caller can decide whether to skip the prompt entirely or fall + back to the full body. + """ + idx = find_sink_line(function_body, sink_sym) + if idx is None: + return None + return extract_backward_slice(function_body, idx, max_lines=max_lines) + + +def maybe_slice( + function_body: str, + sink_sym: str, + *, + max_lines: int = 30, +) -> str: + """Return a slice when ``AIEDGE_LATTE_SLICING`` is enabled, otherwise + return ``function_body`` unchanged. This is the recommended entry + point for ``taint_propagation`` since it keeps the env-gate decision + in one place and never returns ``None``. + """ + if not latte_slicing_enabled(): + return function_body + sliced = extract_slice_around_sink(function_body, sink_sym, max_lines=max_lines) + return sliced if sliced is not None else function_body + + +def slice_compression_ratio(original: str, sliced: str) -> float: + """Return the fraction of original lines preserved in ``sliced``. 
+ + Useful for telemetry: a value < 0.4 indicates aggressive compression + (good for token cost) while a value approaching 1.0 means slicing + barely helped (the function is mostly on the sink path). Returns + ``1.0`` when the original is empty so callers do not need a special + case. + """ + orig_lines = original.splitlines() + if not orig_lines: + return 1.0 + sliced_lines = sliced.splitlines() if sliced else [] + return round(len(sliced_lines) / len(orig_lines), 6) diff --git a/src/aiedge/enhanced_source.py b/src/aiedge/enhanced_source.py index fb6439c..fa5709c 100644 --- a/src/aiedge/enhanced_source.py +++ b/src/aiedge/enhanced_source.py @@ -18,47 +18,51 @@ _SCHEMA_VERSION = "enhanced-source-v1" -INPUT_APIS: frozenset[str] = frozenset({ - "recv", - "recvfrom", - "recvmsg", - "read", - "fread", - "fgets", - "gets", - "getenv", - "scanf", - "sscanf", - "fscanf", - "websGetVar", - "httpGetEnv", - "nvram_get", - "acosNvramConfig_get", - "json_object_get_string", - "cJSON_GetObjectItem", - "cJSON_Parse", - "json_tokener_parse", - "xmlParseMemory", - "getParameter", - "wp_getVar", -}) - -SINK_APIS: frozenset[str] = frozenset({ - "system", - "popen", - "execve", - "execv", - "execl", - "execlp", - "strcpy", - "strcat", - "sprintf", - "vsprintf", - "gets", - "doSystemCmd", - "twsystem", - "doSystem", -}) +INPUT_APIS: frozenset[str] = frozenset( + { + "recv", + "recvfrom", + "recvmsg", + "read", + "fread", + "fgets", + "gets", + "getenv", + "scanf", + "sscanf", + "fscanf", + "websGetVar", + "httpGetEnv", + "nvram_get", + "acosNvramConfig_get", + "json_object_get_string", + "cJSON_GetObjectItem", + "cJSON_Parse", + "json_tokener_parse", + "xmlParseMemory", + "getParameter", + "wp_getVar", + } +) + +SINK_APIS: frozenset[str] = frozenset( + { + "system", + "popen", + "execve", + "execv", + "execl", + "execlp", + "strcpy", + "strcat", + "sprintf", + "vsprintf", + "gets", + "doSystemCmd", + "twsystem", + "doSystem", + } +) # Lowercase lookup set for case-insensitive matching 
_INPUT_APIS_LOWER: frozenset[str] = frozenset(api.lower() for api in INPUT_APIS) @@ -68,19 +72,215 @@ _API_CANONICAL: dict[str, str] = {api.lower(): api for api in INPUT_APIS | SINK_APIS} # --- Web server auto-detection --- -_WEB_SERVER_NAMES: frozenset[str] = frozenset({ - "httpd", "lighttpd", "uhttpd", "mini_httpd", "boa", - "goahead", "thttpd", "nginx", "busybox_httpd", "micro_httpd", - "cgibin", "prog.cgi", "soapcgi", -}) - -_WEB_LISTENER_SYMS: frozenset[str] = frozenset({ - "listen", "accept", "bind", "socket", -}) - -_EXEC_SINK_SYMS: frozenset[str] = frozenset({ - "system", "popen", "execve", "execv", "execl", -}) +_WEB_SERVER_NAMES: frozenset[str] = frozenset( + { + "httpd", + "lighttpd", + "uhttpd", + "mini_httpd", + "boa", + "goahead", + "thttpd", + "nginx", + "busybox_httpd", + "micro_httpd", + "cgibin", + "prog.cgi", + "soapcgi", + } +) + + +# --- Phase 2C+.2 (LARA-inspired URI / CGI / config-key semantic sources) --- +# +# LARA (USENIX Security 2024) widens source identification beyond C-level +# input APIs by recognising URI / HTTP-variable / config-key strings as +# attacker-influenced data origins. We carry the same idea into SCOUT but +# stay conservative: confidence is capped below the dynstr API path +# (0.40 vs 0.60) because string presence alone does not prove reachability. +# +# These are matched case-insensitively as substrings or against full tokens. 
+ +_URI_SOURCE_PATTERNS: frozenset[str] = frozenset( + { + # CGI gateway prefixes (router admin UIs) + "/cgi-bin/", + "/cgi/", + "/goform/", + "/apply.cgi", + "/upgrade.cgi", + "/system.cgi", + "/ipformget.cgi", + "/ipformset.cgi", + # REST / SOAP / JSON-RPC API prefixes + "/api/", + "/webapi/", + "/json-rpc/", + "/jsonrpc/", + "/rest/", + "/soap/", + # Common UPnP / TR-069 / device-management endpoints + "/upnp/", + "/control/", + "/tr069/", + "/cwmp/", + # OEM web UI roots + "/web/", + "/admin/", + "/setup.cgi", + } +) + +_CGI_VAR_PATTERNS: frozenset[str] = frozenset( + { + # Common CGI environment variables (RFC 3875, plus de facto REQUEST_URI) + "HTTP_USER_AGENT", + "HTTP_REFERER", + "HTTP_COOKIE", + "HTTP_HOST", + "HTTP_AUTHORIZATION", + "QUERY_STRING", + "REQUEST_METHOD", + "REQUEST_URI", + "PATH_INFO", + "PATH_TRANSLATED", + "REMOTE_ADDR", + "REMOTE_USER", + "CONTENT_LENGTH", + "CONTENT_TYPE", + # Other request meta-variables frequently read by router CGIs + "HTTP_X_FORWARDED_FOR", + "SCRIPT_NAME", + "SERVER_NAME", + } +) + +_CONFIG_KEY_PATTERNS: frozenset[str] = frozenset( + { + # Authentication / credential keys (router NVRAM + sysconf conventions) + "http_passwd", + "http_username", + "admin_passwd", + "admin_password", + "web_admin_token", + "web_passwd", + "auth_token", + "session_id", + "session_key", + # Device / connectivity keys frequently controlled remotely + "wan_ipaddr", + "lan_ipaddr", + "wifi_psk", + "wifi_password", + "wpa_psk", + "ssid", + "ddns_username", + "ddns_password", + # Cloud / OTA / pairing keys (modern IoT vendors) + "cloud_token", + "device_token", + "registration_code", + "pairing_key", + "firmware_url", + } +) + + +def _extract_uri_key_sources( + bin_path: str, + symbols: set[str], + ascii_strings: set[str] | None = None, +) -> list[tuple[str, str]]: + """Return ``(pattern, kind)`` tuples for any LARA-style URI / CGI / config + matches surfaced by the binary's symbol table or extracted ASCII strings. 
+ + *kind* is one of ``"uri_endpoint"``, ``"cgi_variable"``, ``"config_key"``. + + Matching policy: + - URI prefixes (`/cgi-bin/`, `/api/`, ...): case-insensitive substring + against ``bin_path`` *and* against any provided ``ascii_strings``. + Symbol names are not searched because dynamic-linker symbols rarely + embed a literal URI; substring matches there are noisy. + - CGI environment names (`QUERY_STRING`, ...): exact lower-case match + against either ``symbols`` or ``ascii_strings``. + - NVRAM / sysconf config keys (`http_passwd`, ...): case-insensitive + substring match against ``bin_path``, ``symbols``, and + ``ascii_strings`` (these short identifiers often appear inside + wrapper symbol names like ``get_http_passwd_value``). + + ``ascii_strings`` is optional and defaults to an empty set so the helper + stays cheap when no extracted-string data is available. + """ + sym_lower_set = {s.lower() for s in symbols} if symbols else set() + bin_lower = bin_path.lower() if bin_path else "" + str_lower_set = {s.lower() for s in ascii_strings} if ascii_strings else set() + if not sym_lower_set and not bin_lower and not str_lower_set: + return [] + + matches: list[tuple[str, str]] = [] + seen: set[tuple[str, str]] = set() + + def _record(pattern: str, kind: str) -> None: + key = (pattern, kind) + if key in seen: + return + seen.add(key) + matches.append(key) + + for pattern in _URI_SOURCE_PATTERNS: + needle = pattern.lower() + if needle in bin_lower: + _record(pattern, "uri_endpoint") + continue + for s_lower in str_lower_set: + if needle in s_lower: + _record(pattern, "uri_endpoint") + break + + for var in _CGI_VAR_PATTERNS: + var_lower = var.lower() + if var_lower in sym_lower_set or var_lower in str_lower_set: + _record(var, "cgi_variable") + + for key in _CONFIG_KEY_PATTERNS: + needle = key.lower() + if needle in bin_lower: + _record(key, "config_key") + continue + matched = False + for sym_lower in sym_lower_set: + if needle in sym_lower: + _record(key, 
"config_key") + matched = True + break + if matched: + continue + for s_lower in str_lower_set: + if needle in s_lower: + _record(key, "config_key") + break + + return matches + + +_WEB_LISTENER_SYMS: frozenset[str] = frozenset( + { + "listen", + "accept", + "bind", + "socket", + } +) + +_EXEC_SINK_SYMS: frozenset[str] = frozenset( + { + "system", + "popen", + "execve", + "execv", + "execl", + } +) def _classify_web_server( @@ -163,9 +363,7 @@ def run(self, ctx: StageContext) -> StageOutcome: inv_obj = cast(dict[str, object], inv_data) # --- Load binary_analysis.json for .dynstr data --- - binary_analysis_path = ( - run_dir / "stages" / "inventory" / "binary_analysis.json" - ) + binary_analysis_path = run_dir / "stages" / "inventory" / "binary_analysis.json" ba_data = _load_json_file(binary_analysis_path) ba_hits: list[object] = [] if isinstance(ba_data, dict): @@ -173,9 +371,7 @@ def run(self, ctx: StageContext) -> StageOutcome: if isinstance(hits_any, list): ba_hits = cast(list[object], hits_any) elif ba_data is None: - limitations.append( - "binary_analysis.json missing; .dynstr scan unavailable" - ) + limitations.append("binary_analysis.json missing; .dynstr scan unavailable") # --- Scan binary analysis hits for INPUT and SINK APIs --- for bin_any in ba_hits: @@ -191,7 +387,12 @@ def run(self, ctx: StageContext) -> StageOutcome: # Collect ALL symbols from matched_symbols + symbol_details symbols: set[str] = set() - for key in ("matched_symbols", "dynstr_imports", "risky_symbols", "imports"): + for key in ( + "matched_symbols", + "dynstr_imports", + "risky_symbols", + "imports", + ): syms_any = bin_obj.get(key) if isinstance(syms_any, list): for sym_any in cast(list[object], syms_any): @@ -244,13 +445,9 @@ def run(self, ctx: StageContext) -> StageOutcome: # Web server classification — boost confidence for HTTP binaries ipc_any = bin_obj.get("ipc_indicators") ipc_dict = ( - cast(dict[str, object], ipc_any) - if isinstance(ipc_any, dict) - else None - ) - is_web, 
conf_boost = _classify_web_server( - bin_path, symbols, ipc_dict + cast(dict[str, object], ipc_any) if isinstance(ipc_any, dict) else None ) + is_web, conf_boost = _classify_web_server(bin_path, symbols, ipc_dict) if is_web: confidence = min(0.90, confidence + conf_boost) @@ -266,23 +463,48 @@ def run(self, ctx: StageContext) -> StageOutcome: # Record each input API as a source; if none, use sink APIs api_list = matched_input if matched_input else matched_sink for api in api_list: - sources.append({ - "address": "0x0", - "api": api, - "binary": bin_path, - "confidence": _clamp01(confidence), - "method": "enhanced_static", - "matched_input_apis": cast( - list[JsonValue], cast(list[object], sorted(set(matched_input))) - ), - "matched_sink_apis": cast( - list[JsonValue], cast(list[object], sorted(set(matched_sink))) - ), - "arch": arch, - "hardening": cast(dict[str, JsonValue], hardening), - "source_type": source_type, - "web_server": is_web, - }) + sources.append( + { + "address": "0x0", + "api": api, + "binary": bin_path, + "confidence": _clamp01(confidence), + "method": "enhanced_static", + "matched_input_apis": cast( + list[JsonValue], + cast(list[object], sorted(set(matched_input))), + ), + "matched_sink_apis": cast( + list[JsonValue], + cast(list[object], sorted(set(matched_sink))), + ), + "arch": arch, + "hardening": cast(dict[str, JsonValue], hardening), + "source_type": source_type, + "web_server": is_web, + } + ) + + # --- Phase 2C+.2: LARA-style URI / CGI / config-key sources --- + # Confidence stays at the SYMBOL_COOCCURRENCE cap (0.40) because + # string presence alone does not prove reachability; downstream + # taint propagation can promote individual matches. 
+ for pattern, kind in _extract_uri_key_sources(bin_path, symbols): + sources.append( + { + "address": "0x0", + "api": pattern, + "binary": bin_path, + "confidence": _clamp01(0.40), + "method": "lara_pattern", + "matched_input_apis": cast(list[JsonValue], []), + "matched_sink_apis": cast(list[JsonValue], []), + "arch": arch, + "hardening": cast(dict[str, JsonValue], hardening), + "source_type": kind, + "web_server": is_web, + } + ) # --- Fallback: read source_sink_graph.json for additional sources --- ssg_path = run_dir / "stages" / "surfaces" / "source_sink_graph.json" @@ -311,7 +533,9 @@ def run(self, ctx: StageContext) -> StageOutcome: src_type = "" if isinstance(source_any, dict): - src_type = str(cast(dict[str, object], source_any).get("type", "")) + src_type = str( + cast(dict[str, object], source_any).get("type", "") + ) conf_any = p_obj.get("confidence") ssg_conf = ( _clamp01(float(conf_any)) @@ -320,18 +544,20 @@ def run(self, ctx: StageContext) -> StageOutcome: ) for sym in sink_syms: - sources.append({ - "address": "0x0", - "api": sym, - "binary": sink_bin, - "confidence": _clamp01(min(ssg_conf, 0.55)), - "method": "source_sink_graph", - "source_type": src_type, - "matched_input_apis": cast(list[JsonValue], []), - "matched_sink_apis": cast( - list[JsonValue], cast(list[object], sink_syms) - ), - }) + sources.append( + { + "address": "0x0", + "api": sym, + "binary": sink_bin, + "confidence": _clamp01(min(ssg_conf, 0.55)), + "method": "source_sink_graph", + "source_type": src_type, + "matched_input_apis": cast(list[JsonValue], []), + "matched_sink_apis": cast( + list[JsonValue], cast(list[object], sink_syms) + ), + } + ) # --- Also scan inventory service_candidates for input API references --- candidates_any = inv_obj.get("service_candidates") @@ -355,19 +581,26 @@ def run(self, ctx: StageContext) -> StageOutcome: if not isinstance(sym_any, str): continue sym_lower = sym_any.lower().strip() - if sym_lower in _INPUT_APIS_LOWER or sym_lower in 
_SINK_APIS_LOWER: + if ( + sym_lower in _INPUT_APIS_LOWER + or sym_lower in _SINK_APIS_LOWER + ): canonical = _API_CANONICAL.get(sym_lower, sym_any) path_any = ev.get("path") bin_path_str = ( - str(path_any) if isinstance(path_any, str) else cand_name + str(path_any) + if isinstance(path_any, str) + else cand_name + ) + sources.append( + { + "address": "0x0", + "api": canonical, + "binary": bin_path_str, + "confidence": _clamp01(0.50), + "method": "service_candidate", + } ) - sources.append({ - "address": "0x0", - "api": canonical, - "binary": bin_path_str, - "confidence": _clamp01(0.50), - "method": "service_candidate", - }) # --- Deduplicate sources --- seen: set[tuple[str, str, str]] = set() @@ -392,25 +625,20 @@ def run(self, ctx: StageContext) -> StageOutcome: "schema_version": _SCHEMA_VERSION, "status": status, "total_sources": len(unique_sources), - "sources": cast( - list[JsonValue], cast(list[object], unique_sources) - ), + "sources": cast(list[JsonValue], cast(list[object], unique_sources)), "limitations": cast( list[JsonValue], cast(list[object], sorted(set(limitations))) ), } out_json.write_text( - json.dumps(payload, indent=2, sort_keys=True, ensure_ascii=True) - + "\n", + json.dumps(payload, indent=2, sort_keys=True, ensure_ascii=True) + "\n", encoding="utf-8", ) details: dict[str, JsonValue] = { "total_sources": len(unique_sources), "unique_apis": len({cast(str, s["api"]) for s in unique_sources}), - "unique_binaries": len( - {cast(str, s["binary"]) for s in unique_sources} - ), + "unique_binaries": len({cast(str, s["binary"]) for s in unique_sources}), } return StageOutcome( status=status, diff --git a/src/aiedge/quality_policy.py b/src/aiedge/quality_policy.py index 0a5afed..39dcd7a 100644 --- a/src/aiedge/quality_policy.py +++ b/src/aiedge/quality_policy.py @@ -1,7 +1,10 @@ from __future__ import annotations +import csv import json import os +from collections import Counter +from collections.abc import Sequence from pathlib import Path from typing 
import cast @@ -24,6 +27,8 @@ def _threshold_float(env_name: str, default: float) -> float: QUALITY_GATE_LLM_REQUIRED = "QUALITY_GATE_LLM_REQUIRED" QUALITY_GATE_LLM_INVALID = "QUALITY_GATE_LLM_INVALID" QUALITY_GATE_LLM_VERDICT_MISS = "QUALITY_GATE_LLM_VERDICT_MISS" +QUALITY_GATE_DIVERSITY_MISS = "QUALITY_GATE_DIVERSITY_MISS" +QUALITY_GATE_INVALID_PAIR_EVAL = "QUALITY_GATE_INVALID_PAIR_EVAL" class QualityGateError(ValueError): @@ -316,3 +321,107 @@ def format_quality_gate(payload: dict[str, object]) -> str: def write_quality_gate(path: Path, payload: dict[str, object]) -> None: _ = path.write_text(format_quality_gate(payload), encoding="utf-8") + + +def compute_pair_eval_diversity_index(finding_ids: Sequence[str]) -> float: + """Return the max-share diversity index (row share of the most common finding_id). + + 1.0 = degenerate (every row mapped to a single finding_id). + 1/N = fully diverse (each of the N rows is a distinct finding). + 0.0 = empty input (caller decides whether this is a gate violation). + """ + if not finding_ids: + return 0.0 + counter = Counter(finding_ids) + return _rounded(max(counter.values()) / len(finding_ids)) + + +def load_pair_eval_finding_ids( + csv_path: Path, + *, + only_ground_truth: frozenset[str] | None = None, +) -> list[str]: + """Load the finding_id column from pair_eval_findings.csv. + + only_ground_truth: if provided, restrict to rows whose ground_truth value is in + the set (e.g. ``frozenset({"tp", "fp"})``). Default: all non-empty finding_id rows. 
+ """ + finding_ids: list[str] = [] + try: + with csv_path.open(encoding="utf-8") as fh: + reader = csv.DictReader(fh) + for row in reader: + finding_id = (row.get("finding_id") or "").strip() + if not finding_id: + continue + if only_ground_truth is not None: + gt = (row.get("ground_truth") or "").strip().lower() + if gt not in only_ground_truth: + continue + finding_ids.append(finding_id) + except FileNotFoundError as e: + raise QualityGateError( + QUALITY_GATE_INVALID_PAIR_EVAL, + f"pair-eval findings CSV not found: {csv_path}", + ) from e + except (OSError, csv.Error) as e: + raise QualityGateError( + QUALITY_GATE_INVALID_PAIR_EVAL, + f"pair-eval findings CSV could not be read: {e}", + ) from e + return finding_ids + + +def evaluate_pair_eval_diversity_gate( + *, + finding_ids: Sequence[str], + findings_source: str, +) -> dict[str, object]: + """Evaluate the finding-diversity gate for a pair-eval lane. + + Threshold env: ``AIEDGE_PAIR_DIVERSITY_MAX`` (default 0.5). The gate fails when + the diversity index is **>=** the threshold (since 1.0 indicates degenerate + single-finding mapping). Empty input returns a pass with sample_size=0 — callers + that require a non-empty sample should check ``measured.sample_size`` themselves. 
+ """ + threshold = _threshold_float("AIEDGE_PAIR_DIVERSITY_MAX", 0.5) + diversity_index = compute_pair_eval_diversity_index(finding_ids) + sample_size = len(finding_ids) + + policy = { + "finding_diversity_max": _rounded(threshold), + "finding_diversity_max_env": "AIEDGE_PAIR_DIVERSITY_MAX", + } + + errors: list[dict[str, object]] = [] + if sample_size > 0 and diversity_index >= threshold: + errors.append( + { + "error_token": QUALITY_GATE_DIVERSITY_MISS, + "metric": "finding_diversity_index", + "source_field": "pair_eval_findings.finding_id", + "actual": diversity_index, + "threshold": _rounded(threshold), + "operator": "<", + "sample_size": sample_size, + "message": ( + f"finding diversity violation: index={diversity_index} " + f">= threshold={_rounded(threshold)} " + f"(degenerate when 1.0; sample_size={sample_size})" + ), + } + ) + + passed = not errors + return { + "schema_version": QUALITY_GATE_SCHEMA_VERSION, + "verdict": "pass" if passed else "fail", + "passed": passed, + "findings_source": findings_source, + "policy": policy, + "measured": { + "finding_diversity_index": diversity_index, + "sample_size": sample_size, + }, + "errors": errors, + } diff --git a/src/aiedge/taint_propagation.py b/src/aiedge/taint_propagation.py index 1eaddcb..9bceeca 100644 --- a/src/aiedge/taint_propagation.py +++ b/src/aiedge/taint_propagation.py @@ -16,6 +16,7 @@ from typing import cast from ._typing_helpers import safe_float, safe_int +from .code_slicing import maybe_slice from .confidence_caps import ( PCODE_VERIFIED_CAP, STATIC_CODE_VERIFIED_CAP, @@ -46,7 +47,7 @@ _SINK_SYMBOLS: frozenset[str] = frozenset( { - # -- Command injection -- + # -- CWE-78 command / process injection -- "system", "popen", "execve", @@ -56,7 +57,10 @@ "execlp", "execle", "execv", - # -- Buffer overflow (string) -- + "wordexp", + "posix_spawn", + "posix_spawnp", + # -- CWE-120/121 buffer overflow (string) -- "strcpy", "sprintf", "strcat", @@ -64,23 +68,55 @@ "strncat", "gets", "vsprintf", - # -- 
Buffer overflow (memory) -- + # -- CWE-120 buffer overflow (memory) -- "memcpy", "memmove", - # -- Format string -- + # -- CWE-134 format string -- "printf", "fprintf", "syslog", "vprintf", "vfprintf", "snprintf", - # -- Dangerous input parsing -- + "vsnprintf", + "dprintf", + "vdprintf", + # -- CWE-20 input parsing -- "scanf", "sscanf", "fscanf", - # -- Dynamic loading / path traversal -- - "dlopen", + # -- CWE-22 / CWE-73 path traversal -- + "fopen", + "open", + "openat", + "freopen", + "chdir", "realpath", + # -- CWE-426 untrusted search path / dynamic loading -- + "dlopen", + "dlsym", + "dlmopen", + # -- CWE-732 incorrect permission assignment -- + "chmod", + "fchmod", + "chown", + "fchown", + "lchown", + # -- CWE-377 insecure temporary file -- + "mktemp", + "tmpnam", + "tempnam", + "tmpfile", + # -- CWE-250 / CWE-269 privilege management -- + "chroot", + "setuid", + "seteuid", + "setgid", + "setegid", + # -- CWE-454 environment injection -- + "putenv", + "setenv", + "unsetenv", } ) @@ -92,6 +128,15 @@ "vprintf", "vfprintf", "snprintf", + "vsnprintf", + "dprintf", + "vdprintf", + "swprintf", + "vswprintf", + "wprintf", + "vwprintf", + "fwprintf", + "vfwprintf", } ) @@ -195,12 +240,23 @@ def _is_format_string_variable( sink_sym: str, decompiled_body: str, ) -> bool: - """Return True if sink_sym is called with a variable (non-literal) format string.""" + """Return True if sink_sym is called with a variable (non-literal) format string. + + Recognised variable forms (anything whose first argument is *not* a string + literal): bare identifiers (``printf(buf)``), function-call results + (``printf(get_str())``), struct field access (``printf(obj->field)`` / + ``printf(obj.field)``), array subscripts (``printf(arr[i])``), C-style + casts (``printf((char *) buf)``), parenthesised expressions including + ternaries (``printf((cond ? a : b))``). + """ if sink_sym not in _FORMAT_STRING_SINKS: return False - # Pattern: printf(variable...) vs printf("literal"...) 
+ # Match the sink call with a first argument whose first non-whitespace + # character is anything other than a double-quote (string literal). Any + # non-literal first argument — identifier, function call, ``(`` for cast or + # ternary, ``*``/``&`` for pointer operations — is treated as variable. variable_fmt_pat = re.compile( - r"\b" + re.escape(sink_sym) + r"\s*\(\s*[a-zA-Z_]", + r"\b" + re.escape(sink_sym) + r'\s*\(\s*[^"\s\)]', ) return bool(variable_fmt_pat.search(decompiled_body)) @@ -250,7 +306,12 @@ def _build_taint_prompt( code_blocks = "" for fb in function_bodies: fname = fb.get("name", "unknown") - body = _truncate_text(fb.get("body", ""), max_chars=2000) + # Phase 2C+.1 (LATTE): when AIEDGE_LATTE_SLICING=1, replace the full + # body with a backward slice rooted at the sink call. Default-off so + # behaviour stays byte-identical when the env var is unset. + body_raw = fb.get("body", "") + body_sliced = maybe_slice(body_raw, sink_symbol) + body = _truncate_text(body_sliced, max_chars=2000) code_blocks += f"\n### {fname}\n```c\n{body}\n```\n" return ( diff --git a/tests/test_code_slicing.py b/tests/test_code_slicing.py new file mode 100644 index 0000000..3986d38 --- /dev/null +++ b/tests/test_code_slicing.py @@ -0,0 +1,238 @@ +"""Phase 2C+.1 — LATTE-inspired text-based backward slicing tests. + +Locks the public surface of ``aiedge.code_slicing`` and the env-gated +``maybe_slice`` entry point that ``taint_propagation`` calls. The tests +intentionally exercise behavioural invariants (slice is a subset of the +function body, line order is preserved, opt-out is byte-identical, ...) +rather than the exact set of lines kept, so future swaps to a +Ghidra-grade backend do not require rewriting the suite. 
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from aiedge.code_slicing import (
+    extract_backward_slice,
+    extract_slice_around_sink,
+    find_sink_line,
+    latte_slicing_enabled,
+    maybe_slice,
+    slice_compression_ratio,
+)
+
+# ---------------------------------------------------------------------------
+# Sample function bodies (Ghidra-decompile-flavoured)
+# ---------------------------------------------------------------------------
+
+
+_SIMPLE_BODY = """\
+void handle_request(char *user_input, int len) {
+    char buf[64];
+    int rc;
+    char *prefix = "/cmd: ";
+    rc = check_auth(user_input);
+    if (rc != 0) {
+        return;
+    }
+    sprintf(buf, "%s%s", prefix, user_input);
+    log_info("about to exec %s", buf);
+    system(buf);
+}
+"""
+
+
+_NO_SINK_BODY = """\
+void counter(int n) {
+    for (int i = 0; i < n; i++) {
+        printf("tick\\n");
+    }
+}
+"""
+
+
+# ---------------------------------------------------------------------------
+# find_sink_line
+# ---------------------------------------------------------------------------
+
+
+def test_find_sink_line_returns_first_match() -> None:
+    idx = find_sink_line(_SIMPLE_BODY, "system")
+    assert idx is not None
+    line = _SIMPLE_BODY.splitlines()[idx]
+    assert "system(buf)" in line
+
+
+def test_find_sink_line_respects_word_boundary() -> None:
+    """``open`` should not match ``fopen``."""
+    body = " rc = fopen(path, mode);\n open(path, O_RDONLY);\n"
+    idx = find_sink_line(body, "open")
+    assert idx == 1  # the bare open() call, not fopen
+
+
+def test_find_sink_line_returns_none_when_absent() -> None:
+    assert find_sink_line(_NO_SINK_BODY, "system") is None
+    assert find_sink_line("", "system") is None
+    assert find_sink_line(_SIMPLE_BODY, "") is None
+
+
+# ---------------------------------------------------------------------------
+# extract_backward_slice -- behaviour invariants
+# ---------------------------------------------------------------------------
+
+
+def test_slice_includes_sink_line() -> None:
+    sink_idx = find_sink_line(_SIMPLE_BODY, "system")
+    assert sink_idx is not None
+    sliced = extract_backward_slice(_SIMPLE_BODY, sink_idx)
+    assert "system(buf)" in sliced
+
+
+def test_slice_preserves_source_order() -> None:
+    sink_idx = find_sink_line(_SIMPLE_BODY, "system")
+    assert sink_idx is not None
+    sliced = extract_backward_slice(_SIMPLE_BODY, sink_idx)
+    sliced_lines = sliced.splitlines()
+    body_lines = _SIMPLE_BODY.splitlines()
+    line_to_first_index: dict[str, int] = {}
+    for i, line in enumerate(body_lines):
+        line_to_first_index.setdefault(line, i)
+    indices = [line_to_first_index[line] for line in sliced_lines]
+    assert indices == sorted(indices)
+
+
+def test_slice_is_subset_of_original_lines() -> None:
+    sink_idx = find_sink_line(_SIMPLE_BODY, "system")
+    assert sink_idx is not None
+    sliced_lines = set(extract_backward_slice(_SIMPLE_BODY, sink_idx).splitlines())
+    body_lines = set(_SIMPLE_BODY.splitlines())
+    assert sliced_lines <= body_lines
+
+
+def test_slice_pulls_in_definition_of_sink_argument() -> None:
+    """The line that *defines* ``buf`` (the sink argument) must be kept."""
+    sink_idx = find_sink_line(_SIMPLE_BODY, "system")
+    assert sink_idx is not None
+    sliced = extract_backward_slice(_SIMPLE_BODY, sink_idx)
+    assert "char buf[64];" in sliced
+
+
+def test_slice_pulls_in_definition_chain_back_to_user_input() -> None:
+    """``buf`` is filled by ``sprintf`` from ``user_input`` and ``prefix``;
+    those defining lines must appear in the slice so the LLM can reason
+    about the taint chain."""
+    sink_idx = find_sink_line(_SIMPLE_BODY, "system")
+    assert sink_idx is not None
+    sliced = extract_backward_slice(_SIMPLE_BODY, sink_idx)
+    assert "sprintf(buf" in sliced  # the assignment
+    assert "user_input" in sliced  # taint source visible
+
+
+def test_slice_respects_max_lines_cap() -> None:
+    sink_idx = find_sink_line(_SIMPLE_BODY, "system")
+    assert sink_idx is not None
+    sliced = extract_backward_slice(_SIMPLE_BODY, sink_idx, max_lines=2)
+    assert len(sliced.splitlines()) <= 2
+
+
+def test_slice_returns_full_body_when_index_out_of_range() -> None:
+    body = "int main(void) { return 0; }\n"
+    assert extract_backward_slice(body, 999) == body
+    assert extract_backward_slice(body, -1) == body
+
+
+def test_slice_returns_full_body_when_max_lines_nonpositive() -> None:
+    """``max_lines <= 0`` is treated as a no-op so callers cannot accidentally
+    blank the prompt."""
+    sink_idx = find_sink_line(_SIMPLE_BODY, "system")
+    assert sink_idx is not None
+    assert extract_backward_slice(_SIMPLE_BODY, sink_idx, max_lines=0) == _SIMPLE_BODY
+
+
+def test_slice_handles_empty_body() -> None:
+    assert extract_backward_slice("", 0) == ""
+
+
+# ---------------------------------------------------------------------------
+# extract_slice_around_sink convenience wrapper
+# ---------------------------------------------------------------------------
+
+
+def test_extract_slice_around_sink_returns_none_when_sink_absent() -> None:
+    assert extract_slice_around_sink(_NO_SINK_BODY, "system") is None
+
+
+def test_extract_slice_around_sink_combines_locator_and_slicer() -> None:
+    sliced = extract_slice_around_sink(_SIMPLE_BODY, "system")
+    assert sliced is not None
+    assert "system(buf)" in sliced
+    assert "char buf[64];" in sliced
+
+
+# ---------------------------------------------------------------------------
+# maybe_slice + env gate
+# ---------------------------------------------------------------------------
+
+
+def test_latte_slicing_enabled_default_off(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.delenv("AIEDGE_LATTE_SLICING", raising=False)
+    assert latte_slicing_enabled() is False
+
+
+@pytest.mark.parametrize("value", ["1", "true", "TRUE", "yes", "On"])
+def test_latte_slicing_enabled_truthy_values(
+    monkeypatch: pytest.MonkeyPatch, value: str
+) -> None:
+    monkeypatch.setenv("AIEDGE_LATTE_SLICING", value)
+    assert latte_slicing_enabled() is True
+
+
+@pytest.mark.parametrize("value", ["", "0", "false", "no", "off", "garbage"])
+def test_latte_slicing_enabled_falsy_values(
+    monkeypatch: pytest.MonkeyPatch, value: str
+) -> None:
+    monkeypatch.setenv("AIEDGE_LATTE_SLICING", value)
+    assert latte_slicing_enabled() is False
+
+
+def test_maybe_slice_is_byte_identical_when_disabled(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    monkeypatch.delenv("AIEDGE_LATTE_SLICING", raising=False)
+    assert maybe_slice(_SIMPLE_BODY, "system") == _SIMPLE_BODY
+
+
+def test_maybe_slice_compresses_when_enabled(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("AIEDGE_LATTE_SLICING", "1")
+    sliced = maybe_slice(_SIMPLE_BODY, "system")
+    assert sliced != _SIMPLE_BODY
+    assert len(sliced.splitlines()) < len(_SIMPLE_BODY.splitlines())
+
+
+def test_maybe_slice_falls_back_to_full_body_when_sink_absent(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("AIEDGE_LATTE_SLICING", "1")
+    assert maybe_slice(_NO_SINK_BODY, "system") == _NO_SINK_BODY
+
+
+# ---------------------------------------------------------------------------
+# slice_compression_ratio
+# ---------------------------------------------------------------------------
+
+
+def test_compression_ratio_full_body_is_one() -> None:
+    assert slice_compression_ratio(_SIMPLE_BODY, _SIMPLE_BODY) == 1.0
+
+
+def test_compression_ratio_empty_original_is_one() -> None:
+    assert slice_compression_ratio("", "anything") == 1.0
+
+
+def test_compression_ratio_below_one_when_sliced() -> None:
+    sink_idx = find_sink_line(_SIMPLE_BODY, "system")
+    assert sink_idx is not None
+    sliced = extract_backward_slice(_SIMPLE_BODY, sink_idx, max_lines=3)
+    assert slice_compression_ratio(_SIMPLE_BODY, sliced) < 1.0
diff --git a/tests/test_finding_diversity_gate.py b/tests/test_finding_diversity_gate.py
new file mode 100644
index 0000000..4decc70
--- /dev/null
+++ b/tests/test_finding_diversity_gate.py
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import cast
+
+import pytest
+
+from aiedge.quality_policy import (
+    QUALITY_GATE_DIVERSITY_MISS,
+    QUALITY_GATE_INVALID_PAIR_EVAL,
+    QUALITY_GATE_SCHEMA_VERSION,
+    QualityGateError,
+    compute_pair_eval_diversity_index,
+    evaluate_pair_eval_diversity_gate,
+    load_pair_eval_finding_ids,
+)
+
+
+def _measured(result: dict[str, object]) -> dict[str, object]:
+    return cast(dict[str, object], result["measured"])
+
+
+def _errors(result: dict[str, object]) -> list[dict[str, object]]:
+    return cast(list[dict[str, object]], result["errors"])
+
+
+def _policy(result: dict[str, object]) -> dict[str, object]:
+    return cast(dict[str, object], result["policy"])
+
+
+def test_diversity_index_empty_returns_zero() -> None:
+    assert compute_pair_eval_diversity_index([]) == 0.0
+
+
+def test_diversity_index_single_finding_is_degenerate() -> None:
+    finding_ids = ["aiedge.findings.web.exec_sink_overlap"] * 14
+    assert compute_pair_eval_diversity_index(finding_ids) == 1.0
+
+
+def test_diversity_index_all_distinct_is_inverse_n() -> None:
+    finding_ids = [f"finding_{i}" for i in range(8)]
+    # Each appears exactly once, so max share = 1/8
+    assert compute_pair_eval_diversity_index(finding_ids) == 0.125
+
+
+def test_diversity_index_partial_share_is_max_count_over_total() -> None:
+    # 3 of 'a', 1 of 'b', 1 of 'c' → max share = 3/5 = 0.6
+    finding_ids = ["a", "a", "a", "b", "c"]
+    assert compute_pair_eval_diversity_index(finding_ids) == 0.6
+
+
+def test_evaluate_diversity_gate_passes_when_diverse() -> None:
+    finding_ids = ["a", "b", "c", "d", "e"]  # max share = 0.2 < 0.5
+    result = evaluate_pair_eval_diversity_gate(
+        finding_ids=finding_ids,
+        findings_source="test://diverse.csv",
+    )
+    assert result["passed"] is True
+    assert result["verdict"] == "pass"
+    assert _errors(result) == []
+    measured = _measured(result)
+    assert measured["finding_diversity_index"] == 0.2
+    assert measured["sample_size"] == 5
+    assert result["schema_version"] == QUALITY_GATE_SCHEMA_VERSION
+
+
+def test_evaluate_diversity_gate_fails_when_degenerate() -> None:
+    finding_ids = ["aiedge.findings.web.exec_sink_overlap"] * 14
+    result = evaluate_pair_eval_diversity_gate(
+        finding_ids=finding_ids,
+        findings_source="test://degenerate.csv",
+    )
+    assert result["passed"] is False
+    assert result["verdict"] == "fail"
+    errors = _errors(result)
+    assert len(errors) == 1
+    err = errors[0]
+    assert err["error_token"] == QUALITY_GATE_DIVERSITY_MISS
+    assert err["actual"] == 1.0
+    assert err["threshold"] == 0.5
+    assert err["sample_size"] == 14
+    assert "degenerate" in cast(str, err["message"])
+
+
+def test_evaluate_diversity_gate_threshold_env_override(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("AIEDGE_PAIR_DIVERSITY_MAX", "0.7")
+    # diversity index 0.6 should now PASS under 0.7 threshold
+    finding_ids = ["a", "a", "a", "b", "c"]
+    result = evaluate_pair_eval_diversity_gate(
+        finding_ids=finding_ids,
+        findings_source="test://env.csv",
+    )
+    assert result["passed"] is True
+    measured = _measured(result)
+    assert measured["finding_diversity_index"] == 0.6
+    policy = _policy(result)
+    assert policy["finding_diversity_max"] == 0.7
+
+
+def test_evaluate_diversity_gate_empty_sample_passes_with_zero_index() -> None:
+    result = evaluate_pair_eval_diversity_gate(
+        finding_ids=[],
+        findings_source="test://empty.csv",
+    )
+    assert result["passed"] is True
+    measured = _measured(result)
+    assert measured["finding_diversity_index"] == 0.0
+    assert measured["sample_size"] == 0
+
+
+def test_load_pair_eval_finding_ids_filters_blank(tmp_path: Path) -> None:
+    csv_path = tmp_path / "findings.csv"
+    csv_path.write_text(
+        "pair_id,side,finding_id,ground_truth\n"
+        "p1,vulnerable,aiedge.findings.x,tp\n"
+        "p1,patched,,tn\n"  # empty finding_id should be skipped
+        "p2,vulnerable,aiedge.findings.y,fn\n"
+        "p2,patched,aiedge.findings.x,fp\n",
+        encoding="utf-8",
+    )
+    finding_ids = load_pair_eval_finding_ids(csv_path)
+    assert finding_ids == [
+        "aiedge.findings.x",
+        "aiedge.findings.y",
+        "aiedge.findings.x",
+    ]
+
+
+def test_load_pair_eval_finding_ids_filters_by_ground_truth(tmp_path: Path) -> None:
+    csv_path = tmp_path / "findings.csv"
+    csv_path.write_text(
+        "pair_id,side,finding_id,ground_truth\n"
+        "p1,vulnerable,aiedge.findings.x,tp\n"
+        "p1,patched,aiedge.findings.x,fp\n"
+        "p2,vulnerable,aiedge.findings.y,fn\n"
+        "p2,patched,aiedge.findings.z,tn\n",
+        encoding="utf-8",
+    )
+    finding_ids = load_pair_eval_finding_ids(
+        csv_path, only_ground_truth=frozenset({"tp", "fp"})
+    )
+    assert finding_ids == ["aiedge.findings.x", "aiedge.findings.x"]
+
+
+def test_load_pair_eval_finding_ids_missing_file_raises(tmp_path: Path) -> None:
+    missing = tmp_path / "does_not_exist.csv"
+    with pytest.raises(QualityGateError) as exc_info:
+        load_pair_eval_finding_ids(missing)
+    assert exc_info.value.token == QUALITY_GATE_INVALID_PAIR_EVAL
+    assert "not found" in str(exc_info.value)
+
+
+def test_local_7_baseline_is_degenerate() -> None:
+    """Sanity check: the 2026-04-19 local-7 baseline maps every pair-side row to
+    the same finding, so the diversity gate must classify it as fail."""
+    # 14 rows = 7 pairs × 2 sides, all same finding (matches recall_0.142857 lane)
+    finding_ids = ["aiedge.findings.web.exec_sink_overlap"] * 14
+    result = evaluate_pair_eval_diversity_gate(
+        finding_ids=finding_ids,
+        findings_source="benchmark-results/pair-eval/pair_eval_findings.csv",
+    )
+    assert result["passed"] is False
+    measured = _measured(result)
+    assert measured["finding_diversity_index"] == 1.0
+    assert measured["sample_size"] == 14
diff --git a/tests/test_taint_propagation.py b/tests/test_taint_propagation.py
new file mode 100644
index 0000000..7e3096d
--- /dev/null
+++ b/tests/test_taint_propagation.py
@@ -0,0 +1,191 @@
+"""Phase 2C+.3: sink coverage expansion + format-string variable detection.
+
+These tests pin the post-2026-04-19 sink catalogue (≥50 dangerous symbols across
+CWE-78 / 120 / 134 / 22 / 426 / 732 / 377 / 250 / 454) and the strengthened
+format-string variable detector. They do **not** exercise the rest of
+``taint_propagation``; that module's LLM-driven flow has separate coverage
+through the integration suite. The goal here is to lock the catalogue and
+prevent silent regressions when new CWE families are added.
+"""
+
+from __future__ import annotations
+
+from aiedge.taint_propagation import (
+    _FORMAT_STRING_SINKS,
+    _SINK_SYMBOLS,
+    _is_format_string_variable,
+)
+
+# ---------------------------------------------------------------------------
+# Sink catalogue size and CWE coverage
+# ---------------------------------------------------------------------------
+
+
+def test_sink_symbols_total_count_covers_phase_2c_plus_target() -> None:
+    """Phase 2C+.3 raises the floor from 28/29 to >= 50 distinct symbols."""
+    assert len(_SINK_SYMBOLS) >= 50
+
+
+def test_sink_symbols_includes_cwe78_command_injection_extras() -> None:
+    """Beyond the legacy execve family, the new catalogue covers wordexp /
+    posix_spawn-style entry points commonly seen in modern CGI handlers."""
+    new_cwe78 = {"wordexp", "posix_spawn", "posix_spawnp"}
+    assert new_cwe78 <= _SINK_SYMBOLS
+
+
+def test_sink_symbols_includes_cwe22_path_traversal() -> None:
+    new_cwe22 = {"fopen", "open", "openat", "freopen", "chdir"}
+    assert new_cwe22 <= _SINK_SYMBOLS
+
+
+def test_sink_symbols_includes_cwe426_dynamic_loading() -> None:
+    """dlopen was already present; dlsym / dlmopen close the search-path gap."""
+    assert {"dlopen", "dlsym", "dlmopen"} <= _SINK_SYMBOLS
+
+
+def test_sink_symbols_includes_cwe732_permission_calls() -> None:
+    assert {"chmod", "fchmod", "chown", "fchown", "lchown"} <= _SINK_SYMBOLS
+
+
+def test_sink_symbols_includes_cwe377_insecure_tmp_files() -> None:
+    assert {"mktemp", "tmpnam", "tempnam", "tmpfile"} <= _SINK_SYMBOLS
+
+
+def test_sink_symbols_includes_privilege_drop_calls() -> None:
+    """CWE-250 / CWE-269: privilege-management calls that are shipped without
+    dropping or re-elevating privileges correctly are a recurring router-CGI
+    bug class (e.g. setuid(0) without prior chroot)."""
+    assert {"chroot", "setuid", "seteuid", "setgid", "setegid"} <= _SINK_SYMBOLS
+
+
+def test_sink_symbols_includes_environment_injection() -> None:
+    """CWE-454: putenv/setenv variants accept attacker-controlled strings."""
+    assert {"putenv", "setenv", "unsetenv"} <= _SINK_SYMBOLS
+
+
+def test_sink_symbols_preserves_legacy_entries() -> None:
+    """Regression guard: every pre-Phase 2C+.3 symbol stays in the set so
+    existing rules and downstream consumers are not silently weakened."""
+    legacy = {
+        "system",
+        "popen",
+        "execve",
+        "execvp",
+        "execvpe",
+        "execl",
+        "execlp",
+        "execle",
+        "execv",
+        "strcpy",
+        "sprintf",
+        "strcat",
+        "strncpy",
+        "strncat",
+        "gets",
+        "vsprintf",
+        "memcpy",
+        "memmove",
+        "printf",
+        "fprintf",
+        "syslog",
+        "vprintf",
+        "vfprintf",
+        "snprintf",
+        "scanf",
+        "sscanf",
+        "fscanf",
+        "dlopen",
+        "realpath",
+    }
+    assert legacy <= _SINK_SYMBOLS
+
+
+# ---------------------------------------------------------------------------
+# Format-string sinks
+# ---------------------------------------------------------------------------
+
+
+def test_format_string_sinks_count_doubles() -> None:
+    """Phase 2C+.3 brings the format-string sink count from 6 to >=12."""
+    assert len(_FORMAT_STRING_SINKS) >= 12
+
+
+def test_format_string_sinks_cover_size_bounded_and_wide_variants() -> None:
+    """Add the size-bounded (vsnprintf), file-descriptor (dprintf/vdprintf),
+    and wide-char (swprintf, wprintf, fwprintf, ...) variants explicitly."""
+    additions = {
+        "vsnprintf",
+        "dprintf",
+        "vdprintf",
+        "swprintf",
+        "vswprintf",
+        "wprintf",
+        "vwprintf",
+        "fwprintf",
+        "vfwprintf",
+    }
+    assert additions <= _FORMAT_STRING_SINKS
+
+
+# ---------------------------------------------------------------------------
+# Strengthened _is_format_string_variable detector
+# ---------------------------------------------------------------------------
+
+
+def test_format_var_skips_string_literal_first_arg() -> None:
+    assert not _is_format_string_variable("printf", 'printf("hello")')
+    assert not _is_format_string_variable("printf", 'printf("hello %s", name)')
+    # Whitespace before the literal is fine
+    assert not _is_format_string_variable("printf", 'printf( "ok" )')
+
+
+def test_format_var_detects_bare_identifier_first_arg() -> None:
+    """The detector flags any sink call whose first argument is not a string
+    literal, even when the first arg is not the format-string position
+    (e.g. syslog priority constant, fprintf stream). This intentional
+    broadening lets downstream analysis discriminate further; the goal here
+    is just to make sure no candidate is silently dropped."""
+    assert _is_format_string_variable("printf", "printf(buf)")
+    assert _is_format_string_variable("syslog", "syslog(LOG_INFO, message)")
+    assert _is_format_string_variable("syslog", "syslog(user_buf)")
+
+
+def test_format_var_detects_function_call_first_arg() -> None:
+    body = "fprintf(stderr, get_template(name))"
+    # fprintf's first arg is the FILE*, not the format. The detector doesn't
+    # know about argument positions; it flags any non-literal first arg. This
+    # is intentional: it catches the broad pattern and lets later analysis
+    # discriminate.
+    assert _is_format_string_variable("fprintf", body)
+
+
+def test_format_var_detects_struct_field_access() -> None:
+    assert _is_format_string_variable("printf", "printf(obj->field)")
+    assert _is_format_string_variable("printf", "printf(record.fmt)")
+
+
+def test_format_var_detects_array_subscript() -> None:
+    assert _is_format_string_variable("printf", "printf(messages[i])")
+
+
+def test_format_var_detects_c_style_cast() -> None:
+    assert _is_format_string_variable("printf", "printf((char *) buf)")
+
+
+def test_format_var_detects_parenthesised_ternary() -> None:
+    body = "printf((cond ? warn : info))"
+    assert _is_format_string_variable("printf", body)
+
+
+def test_format_var_detects_pointer_dereference_first_arg() -> None:
+    assert _is_format_string_variable("printf", "printf(*p_fmt)")
+    assert _is_format_string_variable("printf", "printf(&buffer[0])")
+
+
+def test_format_var_returns_false_for_non_format_sinks() -> None:
+    """Sinks not in _FORMAT_STRING_SINKS (e.g. system, memcpy) are out of scope
+    for this detector: even if called with a variable arg they don't represent
+    a format-string vulnerability."""
+    assert not _is_format_string_variable("system", "system(buf)")
+    assert not _is_format_string_variable("memcpy", "memcpy(dest, src, n)")
+    assert not _is_format_string_variable("strcpy", "strcpy(dest, src)")
diff --git a/tests/test_uri_source_extraction.py b/tests/test_uri_source_extraction.py
new file mode 100644
index 0000000..f013de1
--- /dev/null
+++ b/tests/test_uri_source_extraction.py
@@ -0,0 +1,133 @@
+"""Phase 2C+.2: LARA-style URI / CGI / config-key source identification.
+
+Locks the new pattern catalogues (URI prefixes, CGI environment variables,
+NVRAM/sysconf config keys) and the ``_extract_uri_key_sources`` helper that
+EnhancedSourceStage now consults per-binary. The helper produces
+``(pattern, kind)`` tuples; the stage wraps each tuple into a source dict
+with ``confidence=0.40`` and ``method="lara_pattern"``.
+"""
+
+from __future__ import annotations
+
+from aiedge.enhanced_source import (
+    _CGI_VAR_PATTERNS,
+    _CONFIG_KEY_PATTERNS,
+    _URI_SOURCE_PATTERNS,
+    _extract_uri_key_sources,
+)
+
+# ---------------------------------------------------------------------------
+# Pattern catalogue size
+# ---------------------------------------------------------------------------
+
+
+def test_pattern_catalogue_total_meets_phase_2c_plus_target() -> None:
+    """Phase 2C+.2 ships ≥30 patterns combined across the three categories."""
+    total = (
+        len(_URI_SOURCE_PATTERNS) + len(_CGI_VAR_PATTERNS) + len(_CONFIG_KEY_PATTERNS)
+    )
+    assert total >= 30
+
+
+def test_uri_patterns_cover_cgi_and_rest_and_upnp() -> None:
+    must_have = {"/cgi-bin/", "/api/", "/upnp/", "/admin/", "/goform/"}
+    assert must_have <= _URI_SOURCE_PATTERNS
+
+
+def test_cgi_var_patterns_cover_rfc3875_essentials() -> None:
+    must_have = {"QUERY_STRING", "REQUEST_METHOD", "HTTP_USER_AGENT", "HTTP_COOKIE"}
+    assert must_have <= _CGI_VAR_PATTERNS
+
+
+def test_config_key_patterns_cover_router_credentials_and_cloud_tokens() -> None:
+    must_have = {"http_passwd", "wpa_psk", "cloud_token", "firmware_url"}
+    assert must_have <= _CONFIG_KEY_PATTERNS
+
+
+# ---------------------------------------------------------------------------
+# _extract_uri_key_sources behaviour
+# ---------------------------------------------------------------------------
+
+
+def test_extract_returns_empty_for_empty_symbols() -> None:
+    assert _extract_uri_key_sources("/usr/sbin/httpd", set()) == []
+
+
+def test_extract_matches_uri_in_bin_path() -> None:
+    matches = _extract_uri_key_sources("/www/cgi-bin/apply.cgi", {"strcpy", "system"})
+    kinds = {kind for _, kind in matches}
+    assert "uri_endpoint" in kinds
+    # Both /cgi-bin/ and /apply.cgi should match
+    patterns = {pat for pat, kind in matches if kind == "uri_endpoint"}
+    assert "/cgi-bin/" in patterns
+    assert "/apply.cgi" in patterns
+
+
+def test_extract_matches_uri_in_ascii_strings() -> None:
+    """Extracted ASCII string literals (e.g. via SBOM `_extract_ascii_runs`)
+    routinely contain URL prefixes hard-coded as `.rodata` strings. The
+    helper accepts them via the optional ``ascii_strings`` parameter."""
+    matches = _extract_uri_key_sources(
+        "/usr/sbin/uhttpd",
+        {"system"},
+        ascii_strings={"GET /cgi-bin/admin?token=", "/upgrade.cgi"},
+    )
+    patterns = {pat for pat, kind in matches if kind == "uri_endpoint"}
+    assert "/cgi-bin/" in patterns
+    assert "/upgrade.cgi" in patterns
+
+
+def test_extract_does_not_match_uri_substring_in_symbol_name() -> None:
+    """Symbols are intentionally NOT searched for URI substrings (slashes are
+    not valid identifier characters, so any substring overlap would be
+    noise). This test pins that policy."""
+    matches = _extract_uri_key_sources(
+        "/usr/sbin/uhttpd", {"system", "handle_cgi_bin_request"}
+    )
+    assert all(kind != "uri_endpoint" for _, kind in matches)
+
+
+def test_extract_matches_cgi_variable_exact_case_insensitive() -> None:
+    matches = _extract_uri_key_sources(
+        "/usr/sbin/httpd",
+        {"strcpy", "query_string", "REQUEST_METHOD"},
+    )
+    kinds_by_pattern = {pat: kind for pat, kind in matches}
+    assert kinds_by_pattern.get("QUERY_STRING") == "cgi_variable"
+    assert kinds_by_pattern.get("REQUEST_METHOD") == "cgi_variable"
+
+
+def test_extract_matches_config_key_in_symbols() -> None:
+    matches = _extract_uri_key_sources(
+        "/usr/sbin/httpd", {"nvram_get", "get_http_passwd_value"}
+    )
+    cfg_matches = [pat for pat, kind in matches if kind == "config_key"]
+    assert "http_passwd" in cfg_matches
+
+
+def test_extract_matches_config_key_in_bin_path() -> None:
+    matches = _extract_uri_key_sources("/etc/config/wifi_psk_loader", {"strcpy"})
+    cfg_matches = [pat for pat, kind in matches if kind == "config_key"]
+    assert "wifi_psk" in cfg_matches
+
+
+def test_extract_returns_multiple_kinds_in_one_call() -> None:
+    matches = _extract_uri_key_sources(
+        "/www/cgi-bin/auth.cgi",
+        {"QUERY_STRING", "get_admin_passwd"},
+    )
+    kinds = {kind for _, kind in matches}
+    assert {"uri_endpoint", "cgi_variable", "config_key"} <= kinds
+
+
+def test_extract_does_not_double_count_same_pattern() -> None:
+    """Even when a symbol name resembles a URI pattern that already matched
+    ``bin_path``, the helper must emit that pattern exactly once."""
+    matches = _extract_uri_key_sources(
+        "/www/cgi-bin/handler",
+        {"cgi_bin_dispatch"},
+    )
+    cgi_bin_hits = [
+        pat for pat, kind in matches if kind == "uri_endpoint" and pat == "/cgi-bin/"
+    ]
+    assert len(cgi_bin_hits) == 1
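
Review note on `tests/test_uri_source_extraction.py`: the tests above pin a matching policy without showing it in one place. The sketch below is one possible reading of that policy, not the code in `aiedge.enhanced_source`. The function name `sketch_extract_sources` and the three small catalogues are hypothetical stand-ins. The assumed rules are: URI patterns are substring-matched against `bin_path` and `ascii_strings` only, CGI variables are matched against symbols exactly but case-insensitively, and config keys are substring-matched against symbols and `bin_path`.

```python
from __future__ import annotations

# Hypothetical, trimmed-down catalogues for illustration only; the real
# pattern sets live in aiedge.enhanced_source and are larger.
URI_PATTERNS = frozenset({"/cgi-bin/", "/apply.cgi", "/upgrade.cgi"})
CGI_VARS = frozenset({"QUERY_STRING", "REQUEST_METHOD"})
CONFIG_KEYS = frozenset({"http_passwd", "admin_passwd", "wifi_psk"})


def sketch_extract_sources(
    bin_path: str,
    symbols: set[str],
    ascii_strings: set[str] | None = None,
) -> list[tuple[str, str]]:
    """Return deduplicated (pattern, kind) tuples under the assumed policy."""
    found: set[tuple[str, str]] = set()

    # URI patterns: search bin_path and string literals, never symbol names.
    uri_haystacks = [bin_path, *(ascii_strings or ())]
    for pat in URI_PATTERNS:
        if any(pat in hay for hay in uri_haystacks):
            found.add((pat, "uri_endpoint"))

    # CGI variables: exact symbol match, ignoring case.
    upper_symbols = {sym.upper() for sym in symbols}
    for var in CGI_VARS:
        if var in upper_symbols:
            found.add((var, "cgi_variable"))

    # Config keys: substring match in symbol names or in bin_path.
    cfg_haystacks = [bin_path.lower(), *(sym.lower() for sym in symbols)]
    for key in CONFIG_KEYS:
        if any(key in hay for hay in cfg_haystacks):
            found.add((key, "config_key"))

    # A set makes duplicates impossible; sorting gives a stable order.
    return sorted(found)
```

Collecting hits in a set is what makes the "exactly once" expectation of the last test hold by construction in this sketch; whether the real helper dedups the same way is not visible from the diff.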