Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions api/models/finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def create_tables(self) -> None:
completed_at TIMESTAMPTZ,
total_findings INTEGER DEFAULT 0,
score INTEGER DEFAULT NULL,
cve_enrichment_status TEXT DEFAULT 'PENDING'
cve_enrichment_status TEXT DEFAULT 'PENDING',
status TEXT DEFAULT 'pending',
error_message TEXT
);
""")
cur.execute("""
Expand Down Expand Up @@ -206,7 +208,9 @@ def run_migrations(self) -> None:
""")
cur.execute("""
ALTER TABLE scans
ADD COLUMN IF NOT EXISTS cve_enrichment_status TEXT DEFAULT 'PENDING'
ADD COLUMN IF NOT EXISTS cve_enrichment_status TEXT DEFAULT 'PENDING',
ADD COLUMN IF NOT EXISTS status TEXT DEFAULT 'pending',
ADD COLUMN IF NOT EXISTS error_message TEXT
""")
conn.commit()
logger.info("CVE migrations applied successfully")
Expand All @@ -224,18 +228,25 @@ def save_scan(self, scan_result: Dict[str, Any]) -> None:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO scans (scan_id, subscription_id, started_at, completed_at, total_findings, score, cve_enrichment_status)
VALUES (%s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (scan_id) DO NOTHING
INSERT INTO scans (scan_id, subscription_id, started_at, completed_at, total_findings, score, cve_enrichment_status, status, error_message)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (scan_id) DO UPDATE SET
completed_at = EXCLUDED.completed_at,
total_findings = EXCLUDED.total_findings,
score = EXCLUDED.score,
status = EXCLUDED.status,
error_message = EXCLUDED.error_message
""",
(
scan_result["scan_id"],
scan_result["subscription_id"],
scan_result["started_at"],
scan_result["completed_at"],
scan_result["total_findings"],
scan_result.get("completed_at"),
scan_result.get("total_findings", 0),
scan_result.get("score"),
scan_result.get("cve_enrichment_status", "PENDING"),
scan_result.get("status", "completed"),
scan_result.get("error_message"),
),
)
for f in scan_result.get("findings", []):
Expand Down Expand Up @@ -362,6 +373,54 @@ def update_scan_enrichment_status(self, scan_id: str, status: str) -> None:
conn.commit()
logger.info("Updated scan %s enrichment status to %s", scan_id, status)

def create_pending_scan(self, scan_id: str, subscription_id: str) -> None:
    """Insert a new scan row in the 'pending' state.

    ``started_at`` is set to the current UTC time (ISO-8601 string); every
    other column is left to its table default.

    Args:
        scan_id: Value stored in the ``scan_id`` column.
        subscription_id: Value stored in the ``subscription_id`` column.

    Raises:
        Exception: Any error raised by the database driver is re-raised
            after the open transaction has been rolled back.
    """
    from datetime import datetime, timezone

    conn = self._get_conn()
    started_at = datetime.now(timezone.utc).isoformat()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO scans (scan_id, subscription_id, started_at, status)
                VALUES (%s, %s, %s, 'pending')
                """,
                (scan_id, subscription_id, started_at),
            )
        conn.commit()
    except Exception:
        # A failed statement leaves the transaction aborted; roll back so the
        # connection can still be used by later calls.
        conn.rollback()
        raise
    logger.info("Created pending scan %s for %s", scan_id, subscription_id)

def update_scan_status(self, scan_id: str, status: str, error_message: Optional[str] = None) -> None:
    """Update the status of a scan (running, completed, failed).

    When ``status`` is ``"completed"`` the row's ``completed_at`` is set to
    the database's current timestamp and ``error_message`` is left
    untouched. For any other status, ``error_message`` is overwritten with
    the supplied value (``None`` clears it).

    Args:
        scan_id: Identifier of the scan row to update.
        status: New value for the ``status`` column.
        error_message: Optional error text; only written when ``status`` is
            not ``"completed"``.

    Raises:
        Exception: Any error raised by the database driver is re-raised
            after the open transaction has been rolled back.
    """
    conn = self._get_conn()
    try:
        with conn.cursor() as cur:
            if status == "completed":
                cur.execute(
                    "UPDATE scans SET status = %s, completed_at = CURRENT_TIMESTAMP WHERE scan_id = %s",
                    (status, scan_id),
                )
            else:
                cur.execute(
                    "UPDATE scans SET status = %s, error_message = %s WHERE scan_id = %s",
                    (status, error_message, scan_id),
                )
            # Capture before the cursor closes so we can tell whether the
            # scan_id actually matched a row.
            updated_rows = cur.rowcount
        conn.commit()
    except Exception:
        # A failed statement leaves the transaction aborted; roll back so the
        # connection can still be used by later calls.
        conn.rollback()
        raise

    if updated_rows == 0:
        # Do not report success for an UPDATE that touched nothing.
        logger.warning("Scan %s not found; status was not updated to %s", scan_id, status)
        return
    logger.info("Updated scan %s status to %s", scan_id, status)

def get_pending_scans(self) -> List[Dict[str, Any]]:
    """Fetch every scan whose status is 'pending', oldest ``started_at`` first.

    Returns:
        A list of plain dicts, one per matching row, keyed by column name.
    """
    query = "SELECT * FROM scans WHERE status = 'pending' ORDER BY started_at ASC"
    connection = self._get_conn()
    # RealDictCursor yields mapping rows, converted to plain dicts below.
    with connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
        cursor.execute(query)
        pending = []
        for record in cursor.fetchall():
            pending.append(dict(record))
        return pending

def get_scan(self, scan_id: str) -> Optional[Dict[str, Any]]:
    """Look up one scan record by its ``scan_id``.

    Returns:
        The row as a plain dict keyed by column name, or ``None`` when no
        row matches.
    """
    connection = self._get_conn()
    with connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
        cursor.execute("SELECT * FROM scans WHERE scan_id = %s", (scan_id,))
        record = cursor.fetchone()
        if not record:
            return None
        return dict(record)

def get_scans(self) -> List[Dict[str, Any]]:
"""Return all scan records ordered by most recent first."""
conn = self._get_conn()
Expand Down
50 changes: 27 additions & 23 deletions api/routes/scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import os
import uuid
from flask import Blueprint, g, jsonify, request

from api.models.finding import DatabaseManager
Expand Down Expand Up @@ -33,21 +34,29 @@ def list_scans():
return jsonify({"error": "Failed to retrieve scans", "detail": str(exc)}), 500


@scans_bp.get("/api/scans/<scan_id>")
def get_scan_status(scan_id):
    """Return the details and status of a specific scan.

    Responds with the scan record as JSON, with 404 when the database
    returns no record for ``scan_id``, or with 500 when the lookup (or
    serialising its result) raises.
    """
    try:
        db = _get_db()
        scan = db.get_scan(scan_id)
        if not scan:
            return jsonify({"error": "Scan not found"}), 404
        return jsonify(scan)
    except Exception as exc:
        # Include the traceback, as the other handlers in this module do.
        logger.error("Failed to get scan status: %s", exc, exc_info=True)
        return jsonify({"error": "Database error", "detail": str(exc)}), 500


@scans_bp.post("/api/scans/trigger")
def trigger_scan():
"""Trigger a synchronous scan against the configured subscription.
"""Trigger an asynchronous scan against the configured subscription.

Accepts an optional JSON body with ``subscription_id``. Falls back to the
``AZURE_SUBSCRIPTION_ID`` environment variable if not provided.

Note: For production use, replace this with an async task queue (e.g.
Celery or Azure Functions) to avoid request timeouts on large subscriptions.
Returns 202 Accepted with the scan_id immediately.
"""
try:
from scanner.engine import ScanEngine
except ImportError:
return jsonify({"error": "Scanner module is not available"}), 500

try:
body = request.get_json(silent=True) or {}
subscription_id = body.get("subscription_id") or os.environ.get(
Expand All @@ -57,26 +66,21 @@ def trigger_scan():
if not subscription_id:
return jsonify({"error": "subscription_id is required"}), 400

logger.info("Scan triggered for subscription %s", subscription_id)

try:
engine = ScanEngine(subscription_id)
result = engine.run_scan()
except Exception as exc:
logger.error("Scan engine execution failed: %s", exc, exc_info=True)
return jsonify({"error": "Scan failed", "detail": str(exc)}), 500

if not isinstance(result, dict) or "scan_id" not in result:
return jsonify({"error": "Invalid scan result returned"}), 500
scan_id = str(uuid.uuid4())
logger.info("Async scan triggered for subscription %s (id: %s)", subscription_id, scan_id)

try:
db = _get_db()
db.save_scan(result)
db.create_pending_scan(scan_id, subscription_id)
except Exception as exc:
logger.error("Failed to save scan result: %s", exc, exc_info=True)
return jsonify({"error": "Database save failed", "detail": str(exc)}), 500
logger.error("Failed to create pending scan: %s", exc, exc_info=True)
return jsonify({"error": "Database error", "detail": str(exc)}), 500

return jsonify(result), 201
return jsonify({
"scan_id": scan_id,
"status": "pending",
"message": "Scan has been queued and will start shortly."
}), 202

except Exception as exc:
logger.error("Critical error in trigger_scan route: %s", exc, exc_info=True)
Expand Down
54 changes: 26 additions & 28 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,32 @@ Example response:

---

## GET /api/scans/&lt;scan_id&gt;

Returns the details and current status of a specific scan.

Path parameters: `scan_id` — UUID of the scan.

Example response:

```json
{
"scan_id": "6f4a08ac-7d3a-4d9a-a4b4-2a26e5f63c8a",
"subscription_id": "00000000-0000-0000-0000-000000000000",
"status": "completed",
"started_at": "2026-05-09T12:00:00Z",
"completed_at": "2026-05-09T12:02:00Z",
"total_findings": 3,
"score": 85,
"error_message": null
}
```

---

## POST /api/scans/trigger

Runs a synchronous scan and saves the result to PostgreSQL. The request body may include `subscription_id`; otherwise the API uses `AZURE_SUBSCRIPTION_ID`.
Triggers an asynchronous scan against the configured subscription. Returns `202 Accepted` with the `scan_id` immediately. The actual scan execution happens in a background worker process.

Request body:

Expand All @@ -150,32 +173,8 @@ Example response:
```json
{
"scan_id": "6f4a08ac-7d3a-4d9a-a4b4-2a26e5f63c8a",
"subscription_id": "00000000-0000-0000-0000-000000000000",
"started_at": "2026-05-09T12:00:00+00:00",
"completed_at": "2026-05-09T12:02:00+00:00",
"total_findings": 1,
"findings": [
{
"rule_id": "AZ-STOR-001",
"rule_name": "Public Blob Access Enabled on Storage Account",
"severity": "HIGH",
"category": "Storage",
"resource_id": "/subscriptions/example/resourceGroups/rg/providers/Microsoft.Storage/storageAccounts/example",
"resource_name": "example",
"resource_type": "Microsoft.Storage/storageAccounts",
"description": "Storage accounts with public blob access enabled allow unauthenticated read access to blob data over the internet.",
"remediation": "Disable public blob access on the storage account.",
"playbook": "playbooks/cli/fix_az_stor_001.sh",
"frameworks": {
"CIS": "3.5",
"NIST": "PR.AC-3",
"ISO27001": "A.9.4.1"
},
"metadata": {},
"detected_at": "2026-05-09T12:00:00+00:00",
"scan_id": "6f4a08ac-7d3a-4d9a-a4b4-2a26e5f63c8a"
}
]
"status": "pending",
"message": "Scan has been queued and will start shortly."
}
```

Expand Down Expand Up @@ -437,4 +436,3 @@ The following endpoints are called by the frontend but have no backend implement
| Endpoint | Used by | Status |
|---|---|---|
| `GET /api/monitoring` | Monitoring page — score trend chart, category distribution | Deferred. Score and findings data come from `GET /api/score` and `GET /api/findings` instead. |
| `GET /api/scans/<scan_id>` | Header scan poller | Deferred. The frontend falls back to `GET /api/scans` and matches by `scan_id` in the response list. The poller is rarely entered because `POST /api/scans/trigger` now returns `status: completed` immediately. |
47 changes: 47 additions & 0 deletions docs/async-scan-architecture.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Asynchronous Scan Architecture

## Overview

OpenShield uses an asynchronous execution model for Azure posture scans. This architecture ensures the system can handle large subscriptions with thousands of resources without hitting web server timeouts or degrading frontend performance.

## The Problem: Synchronous Bottlenecks

In the legacy synchronous model, POST /api/scans/trigger would block the HTTP request until the scan completed. For large environments, this led to several critical issues. First, Gunicorn or load balancer timeouts would kill the scan mid execution. Second, web workers were tied up for minutes, preventing other users from accessing the dashboard. Third, the UI would hang or show generic Network Error messages while waiting for the response.

## The Solution: DB Backed Background Worker

OpenShield now employs a decoupled, database backed worker architecture. This is the industry standard for long running security tasks where reliability and state persistence are critical.

### 1. The API (Flask)
When a scan is triggered, the API performs minimal work. It validates the subscription_id, creates a record in the scans table with status set to pending, and returns 202 Accepted and the scan_id immediately.

### 2. The Queue (PostgreSQL)
The scans table acts as a persistent task queue. This avoids the need for additional infrastructure like Redis or RabbitMQ while providing ACID compliance, visibility, and auditability. Scan states are never lost during crashes, status polling is a simple SQL query, and every scan has a persistent record of its error state.

### 3. The Worker (Python)
The scanner/worker.py process runs independently of the web server. Its lifecycle involves several steps. It queries the DB for scans where status is pending. It updates the status to running to prevent other workers from picking it up. It invokes ScanEngine.run_scan(scan_id). On success, it saves findings and sets status to completed. On failure, it captures the traceback and sets status to failed with the error_message.

## Technical Rationale

### Why not Celery or Redis
While Celery is powerful, it introduces external dependencies and operational complexity. CSPM scans are macro tasks taking minutes rather than milliseconds. A database backed model is more resilient for these workloads because the state is persisted at the source of truth in PostgreSQL.

### Why not Threading
Python background threads are ephemeral. If the web server process restarts, all in-flight scans are killed instantly and remain marked as running forever in the DB. A separate worker process ensures that the scan lifecycle is independent of the web server lifecycle.

## Testing Suite

The asynchronous transition is verified through a multi layered testing strategy.

### 1. Unit Tests
Located in tests/test_cve_correlator.py, tests/test_nvd_client.py, and tests/test_worker.py. These tests verify the core logic in isolation by mocking all network calls to Azure and NVD.

### 2. Smoke Tests
Located in tests/smoke_test.py. These tests verify the full integration. TC 13 verifies that POST /api/scans/trigger returns 202 Accepted. TC 14 verifies that the response contains a valid scan_id. TC 40 verifies that `GET /api/scans/<scan_id>` returns a valid status object, enabling frontend polling.

### 3. CI Validation
The ci checks job in .github/workflows/ci.yml ensures that worker syntax is valid, new database methods maintain schema integrity, and cross references between compliance mappings and rule files remain intact.

## Integrating with the Frontend

The frontend should follow this pattern for a smooth user experience. Call POST /api/scans/trigger. Extract the scan_id. Show a Scan Queued notification. Poll `GET /api/scans/<scan_id>` every 5 to 10 seconds until status is completed or failed. Refresh the dashboard once the status is completed; if the status is failed, show the returned error_message instead.
11 changes: 7 additions & 4 deletions scanner/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from scanner.azure_client import AzureClient

Expand Down Expand Up @@ -89,14 +89,17 @@ def load_rules(self) -> None:
# Scan execution #
# ------------------------------------------------------------------ #

def run_scan(self) -> Dict[str, Any]:
def run_scan(self, scan_id: Optional[str] = None) -> Dict[str, Any]:
"""Execute all loaded rules and return a normalised scan result.

Args:
scan_id: Optional existing UUID. If not provided, a new one is generated.

Returns:
dict with keys: scan_id, subscription_id, started_at,
completed_at, total_findings, findings.
"""
scan_id = str(uuid.uuid4())
scan_id = scan_id or str(uuid.uuid4())
started_at = datetime.now(timezone.utc).isoformat()
findings: List[Dict[str, Any]] = []
detected_at = datetime.now(timezone.utc).isoformat()
Expand Down Expand Up @@ -149,4 +152,4 @@ def run_scan(self) -> Dict[str, Any]:
"Scan %s complete — %d total finding(s). Normalising results...", scan_id, len(findings)
)

return make_serializable(result)
return result
Loading
Loading