From feb4f161b9a95d44940ea6e79ece532da23b579f Mon Sep 17 00:00:00 2001 From: Muhammad-usman92 Date: Sat, 13 Jun 2026 13:04:22 +0500 Subject: [PATCH 1/4] classifier error verdict foundation --- .gitignore | 1 + backend/cmd/adrian/main.go | 4 +- backend/internal/api/handlers_events.go | 2 + backend/internal/api/handlers_policy.go | 51 +-- backend/internal/api/handlers_reviews.go | 37 ++- backend/internal/api/handlers_test.go | 54 ++- backend/internal/api/handlers_verdicts.go | 27 ++ backend/internal/db/db.go | 9 +- backend/internal/db/migrate.go | 74 ++++- backend/internal/db/migrate_test.go | 135 ++++++++ backend/internal/engine/client.go | 43 +-- backend/internal/engine/client_test.go | 154 ++++++--- backend/internal/engine/engine.go | 7 +- backend/internal/engine/parse.go | 11 +- backend/internal/proto/event.pb.go | 204 ++++++++---- backend/internal/store/events.go | 6 +- backend/internal/store/hitl.go | 38 ++- backend/internal/store/policies.go | 34 +- backend/internal/store/verdicts.go | 25 +- backend/internal/ws/frames.go | 11 +- backend/internal/ws/handler.go | 46 ++- backend/internal/ws/handler_test.go | 97 ++++++ .../migrations/002_verdict_status_policy.sql | 72 ++++ backend/migrations/embed.go | 8 +- backend/proto/event.proto | 23 +- proto/event.proto | 23 +- scripts/setup.py | 59 +++- sdk/python/adrian/proto/event_pb2.py | 314 +++++------------- sdk/python/adrian/proto/event_pb2.pyi | 88 ++++- 29 files changed, 1150 insertions(+), 507 deletions(-) create mode 100644 backend/internal/db/migrate_test.go create mode 100644 backend/migrations/002_verdict_status_policy.sql diff --git a/.gitignore b/.gitignore index 212152c..555d8e4 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ models/* !models/.gitkeep # Go +.tools/ *.exe *.exe~ *.dll diff --git a/backend/cmd/adrian/main.go b/backend/cmd/adrian/main.go index 1d0dd0a..3931d88 100644 --- a/backend/cmd/adrian/main.go +++ b/backend/cmd/adrian/main.go @@ -3,8 +3,8 @@ // Adrian backend entrypoint. // -// Loads config, opens the SQLite database (running idempotent -// migrations), constructs the API server with the LLM-backed +// Loads config, opens the SQLite database (running pending +// ledger-tracked migrations), constructs the API server with the LLM-backed // classifier, and listens on ADRIAN_BACKEND_PORT until SIGTERM. package main diff --git a/backend/internal/api/handlers_events.go b/backend/internal/api/handlers_events.go index 2c2a522..b5b1729 100644 --- a/backend/internal/api/handlers_events.go +++ b/backend/internal/api/handlers_events.go @@ -43,6 +43,7 @@ type timelineVerdict struct { ID string `json:"id"` MADCode string `json:"mad_code"` Classification string `json:"classification"` + VerdictStatus string `json:"verdict_status"` } type timelineEntry struct { @@ -147,6 +148,7 @@ func (s *Server) handleSessionTimeline(w http.ResponseWriter, r *http.Request) { ID: row.VerdictID, MADCode: row.MADCode, Classification: row.Classification, + VerdictStatus: row.VerdictStatus, } } resp.Entries = append(resp.Entries, entry) diff --git a/backend/internal/api/handlers_policy.go b/backend/internal/api/handlers_policy.go index 139cb08..fcaf04a 100644 --- a/backend/internal/api/handlers_policy.go +++ b/backend/internal/api/handlers_policy.go @@ -10,20 +10,22 @@ import ( ) type policyResponse struct { - Mode string `json:"mode"` - PolicyM0 bool `json:"policy_m0"` - PolicyM2 bool `json:"policy_m2"` - PolicyM3 bool `json:"policy_m3"` - PolicyM4 bool `json:"policy_m4"` - UpdatedAt string `json:"updated_at"` + Mode string `json:"mode"` + PolicyM0 bool `json:"policy_m0"` + PolicyM2 bool `json:"policy_m2"` + PolicyM3 bool `json:"policy_m3"` + PolicyM4 bool `json:"policy_m4"` + FailClosedOnClassifierError bool `json:"fail_closed_on_classifier_error"` + UpdatedAt string `json:"updated_at"` } type policyPatchRequest struct { - Mode *string `json:"mode"` - PolicyM0 *bool `json:"policy_m0"` - PolicyM2 *bool `json:"policy_m2"` - PolicyM3 *bool `json:"policy_m3"` - PolicyM4 *bool `json:"policy_m4"` + Mode *string `json:"mode"` + PolicyM0 *bool `json:"policy_m0"` + PolicyM2 *bool `json:"policy_m2"` + PolicyM3 *bool `json:"policy_m3"` + PolicyM4 *bool `json:"policy_m4"` + FailClosedOnClassifierError *bool `json:"fail_closed_on_classifier_error"` } func (s *Server) handleGetPolicy(w http.ResponseWriter, r *http.Request) { @@ -47,11 +49,12 @@ func (s *Server) handleUpdatePolicy(w http.ResponseWriter, r *http.Request) { } patch := &store.PolicyPatch{ - Mode: req.Mode, - PolicyM0: req.PolicyM0, - PolicyM2: req.PolicyM2, - PolicyM3: req.PolicyM3, - PolicyM4: req.PolicyM4, + Mode: req.Mode, + PolicyM0: req.PolicyM0, + PolicyM2: req.PolicyM2, + PolicyM3: req.PolicyM3, + PolicyM4: req.PolicyM4, + FailClosedOnClassifierError: req.FailClosedOnClassifierError, } if err := s.store.UpdatePolicy(r.Context(), patch); err != nil { writeError(w, http.StatusInternalServerError, "update failed") @@ -80,6 +83,9 @@ func (s *Server) handleUpdatePolicy(w http.ResponseWriter, r *http.Request) { if req.PolicyM4 != nil { details["policy_m4"] = *req.PolicyM4 } + if req.FailClosedOnClassifierError != nil { + details["fail_closed_on_classifier_error"] = *req.FailClosedOnClassifierError + } writeAuditLog(r.Context(), s.store, userID(r), "policy_updated", "policies", details) writeJSON(w, http.StatusOK, policyResponseFromStore(pol)) @@ -87,12 +93,13 @@ func (s *Server) handleUpdatePolicy(w http.ResponseWriter, r *http.Request) { func policyResponseFromStore(p *store.Policy) policyResponse { return policyResponse{ - Mode: p.Mode, - PolicyM0: p.PolicyM0, - PolicyM2: p.PolicyM2, - PolicyM3: p.PolicyM3, - PolicyM4: p.PolicyM4, - UpdatedAt: p.UpdatedAt.UTC().Format("2006-01-02T15:04:05.000Z"), + Mode: p.Mode, + PolicyM0: p.PolicyM0, + PolicyM2: p.PolicyM2, + PolicyM3: p.PolicyM3, + PolicyM4: p.PolicyM4, + FailClosedOnClassifierError: p.FailClosedOnClassifierError, + UpdatedAt: p.UpdatedAt.UTC().Format("2006-01-02T15:04:05.000Z"), } } diff --git a/backend/internal/api/handlers_reviews.go b/backend/internal/api/handlers_reviews.go index f51220a..2ecd782 100644 --- a/backend/internal/api/handlers_reviews.go +++ b/backend/internal/api/handlers_reviews.go @@ -16,15 +16,16 @@ import ( ) type reviewSummary struct { - ID string `json:"id"` - EventID string `json:"event_id"` - VerdictID string `json:"verdict_id"` - SessionID string `json:"session_id"` - MADCode string `json:"mad_code"` - Status string `json:"status"` - CreatedAt string `json:"created_at"` - ReviewedBy string `json:"reviewed_by,omitempty"` - ReviewedAt string `json:"reviewed_at,omitempty"` + ID string `json:"id"` + EventID string `json:"event_id"` + VerdictID string `json:"verdict_id"` + SessionID string `json:"session_id"` + MADCode string `json:"mad_code"` + VerdictStatus string `json:"verdict_status"` + Status string `json:"status"` + CreatedAt string `json:"created_at"` + ReviewedBy string `json:"reviewed_by,omitempty"` + ReviewedAt string `json:"reviewed_at,omitempty"` } type reviewListResponse struct { @@ -127,6 +128,7 @@ func (s *Server) resolveReview(w http.ResponseWriter, r *http.Request, status st EventId: row.EventID, SessionId: row.SessionID, MadCode: row.MADCode, + Status: pb.VerdictStatus_VERDICT_STATUS_OK, Policy: s.policySnapshotProto(pol), Hitl: &pb.HitlResponse{ContinueExecution: continueExec}, }}, @@ -148,14 +150,15 @@ func (s *Server) resolveReview(w http.ResponseWriter, r *http.Request, status st func reviewToSummary(r *store.HitlReview) reviewSummary { out := reviewSummary{ - ID: r.ID, - EventID: r.EventID, - VerdictID: r.VerdictID, - SessionID: r.SessionID, - MADCode: r.MADCode, - Status: r.Status, - CreatedAt: r.CreatedAt.UTC().Format("2006-01-02T15:04:05.000Z"), - ReviewedBy: r.ReviewedBy, + ID: r.ID, + EventID: r.EventID, + VerdictID: r.VerdictID, + SessionID: r.SessionID, + MADCode: r.MADCode, + VerdictStatus: r.VerdictStatus, + Status: r.Status, + CreatedAt: r.CreatedAt.UTC().Format("2006-01-02T15:04:05.000Z"), + ReviewedBy: r.ReviewedBy, } if !r.ReviewedAt.IsZero() { out.ReviewedAt = r.ReviewedAt.UTC().Format("2006-01-02T15:04:05.000Z") diff --git a/backend/internal/api/handlers_test.go b/backend/internal/api/handlers_test.go index a37a9f0..0b2c209 100644 --- a/backend/internal/api/handlers_test.go +++ b/backend/internal/api/handlers_test.go @@ -289,10 +289,15 @@ func TestPolicyGetAndUpdate(t *testing.T) { if body["data"].(map[string]any)["mode"] != "alert" { t.Errorf("default mode = %v, want alert", body["data"].(map[string]any)["mode"]) } + if body["data"].(map[string]any)["fail_closed_on_classifier_error"] != false { + t.Errorf("default fail_closed_on_classifier_error = %v, want false", + body["data"].(map[string]any)["fail_closed_on_classifier_error"]) + } // PUT resp = doJSON(t, srv, cookie, http.MethodPut, "/api/settings/policy", map[string]any{ - "mode": "hitl", + "mode": "hitl", + "fail_closed_on_classifier_error": true, }) if resp.StatusCode != http.StatusOK { t.Fatalf("PUT status = %d, want 200", resp.StatusCode) @@ -301,6 +306,10 @@ func TestPolicyGetAndUpdate(t *testing.T) { if body["data"].(map[string]any)["mode"] != "hitl" { t.Errorf("post-PUT mode = %v, want hitl", body["data"].(map[string]any)["mode"]) } + if body["data"].(map[string]any)["fail_closed_on_classifier_error"] != true { + t.Errorf("post-PUT fail_closed_on_classifier_error = %v, want true", + body["data"].(map[string]any)["fail_closed_on_classifier_error"]) + } } func TestPolicyInvalidMode(t *testing.T) { @@ -478,6 +487,44 @@ func TestStatsActivityEmpty(t *testing.T) { } } +// ----------------------------------------------------------------- +// Verdicts +// ----------------------------------------------------------------- + +func TestListVerdictsIncludesStatusAndFiltersError(t *testing.T) { + srv, db, _, cookie := newTestServerLoggedIn(t) + + eventID := uuid.NewString() + if _, err := db.Exec( + `INSERT INTO events (id, session_id, agent_id, event_type, run_id, payload) + VALUES (?, 'sess-verdicts', 'agent-v', 'tool', 'r1', '{}')`, + eventID, + ); err != nil { + t.Fatalf("seed event: %v", err) + } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status, reasoning) + VALUES (?, ?, 'sess-verdicts', '', 'error', 'error', 'classifier failed')`, + uuid.NewString(), eventID, + ); err != nil { + t.Fatalf("seed verdict: %v", err) + } + + resp := getReq(t, srv, cookie, "/api/verdicts?classification=error&verdict_status=error") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + data := decodeBody(t, resp)["data"].(map[string]any) + if int(data["total"].(float64)) != 1 { + t.Fatalf("total = %v, want 1", data["total"]) + } + verdicts := data["verdicts"].([]any) + row := verdicts[0].(map[string]any) + if row["classification"] != "error" || row["verdict_status"] != "error" { + t.Errorf("verdict row = %v, want classification/status error", row) + } +} + // ----------------------------------------------------------------- // Reviews / HITL // ----------------------------------------------------------------- @@ -915,6 +962,9 @@ func TestSessionTimeline(t *testing.T) { if verdict["mad_code"] != "M3" { t.Errorf("verdict.mad_code = %v, want M3", verdict["mad_code"]) } + if verdict["verdict_status"] != "ok" { + t.Errorf("verdict.verdict_status = %v, want ok", verdict["verdict_status"]) + } } // ----------------------------------------------------------------- @@ -1291,6 +1341,7 @@ CREATE TABLE policies ( policy_m2 INTEGER NOT NULL DEFAULT 0, policy_m3 INTEGER NOT NULL DEFAULT 1, policy_m4 INTEGER NOT NULL DEFAULT 1, + fail_closed_on_classifier_error INTEGER NOT NULL DEFAULT 0, updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) ); INSERT INTO policies (id) VALUES (1); @@ -1312,6 +1363,7 @@ CREATE TABLE verdicts ( agent_profile_id TEXT, mad_code TEXT NOT NULL, classification TEXT NOT NULL, + verdict_status TEXT NOT NULL DEFAULT 'ok', reasoning TEXT, latency_ms INTEGER, tokens_used INTEGER NOT NULL DEFAULT 0, diff --git a/backend/internal/api/handlers_verdicts.go b/backend/internal/api/handlers_verdicts.go index 30d7c73..3c75e57 100644 --- a/backend/internal/api/handlers_verdicts.go +++ b/backend/internal/api/handlers_verdicts.go @@ -16,6 +16,7 @@ type verdictResponse struct { SessionID string `json:"session_id"` MADCode string `json:"mad_code"` Classification string `json:"classification"` + VerdictStatus string `json:"verdict_status"` LatencyMS *int64 `json:"latency_ms,omitempty"` TokensUsed int32 `json:"tokens_used"` CreatedAt string `json:"created_at"` @@ -38,10 +39,19 @@ func (s *Server) handleListVerdicts(w http.ResponseWriter, r *http.Request) { since = t } } + if c := q.Get("classification"); c != "" && !validVerdictClassification(c) { + writeError(w, http.StatusBadRequest, "invalid classification") + return + } + if status := q.Get("verdict_status"); status != "" && !validVerdictStatus(status) { + writeError(w, http.StatusBadRequest, "invalid verdict_status") + return + } filters := store.VerdictFilters{ Since: since, Classification: q.Get("classification"), MADCode: q.Get("mad_code"), + VerdictStatus: q.Get("verdict_status"), } rows, total, err := s.store.ListVerdicts(r.Context(), filters, pg.PerPage, pg.Offset) if err != nil { @@ -67,8 +77,25 @@ func verdictRowToResponse(r *store.VerdictListRow) verdictResponse { SessionID: r.SessionID, MADCode: r.MADCode, Classification: r.Classification, + VerdictStatus: r.VerdictStatus, LatencyMS: r.LatencyMS, TokensUsed: r.TokensUsed, CreatedAt: r.CreatedAt.UTC().Format("2006-01-02T15:04:05.000Z"), } } + +func validVerdictClassification(c string) bool { + switch c { + case "benign", "notify", "block", "error": + return true + } + return false +} + +func validVerdictStatus(s string) bool { + switch s { + case "ok", "error": + return true + } + return false +} diff --git a/backend/internal/db/db.go b/backend/internal/db/db.go index c62b7c3..2f2b9e4 100644 --- a/backend/internal/db/db.go +++ b/backend/internal/db/db.go @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (c) 2026 SecureAgentics -// Package db opens the SQLite database, applies idempotent migrations, -// and exposes the *sql.DB handle to the rest of the backend. +// Package db opens the SQLite database, applies pending ledger-tracked +// migrations, and exposes the *sql.DB handle to the rest of the backend. package db import ( @@ -16,8 +16,9 @@ import ( ) // Open opens the SQLite database at path, applies the WAL / FK -// pragmas, and runs every embedded migration in lexical order. -// Migrations are idempotent so re-running on each startup is safe. +// pragmas, and runs each pending embedded migration in lexical order. +// Applied migrations are recorded in schema_migrations so startup can +// safely skip files that already ran. func Open(path string) (*sql.DB, error) { conn, err := sql.Open("sqlite", path) if err != nil { diff --git a/backend/internal/db/migrate.go b/backend/internal/db/migrate.go index f0621e2..0c03576 100644 --- a/backend/internal/db/migrate.go +++ b/backend/internal/db/migrate.go @@ -11,10 +11,19 @@ import ( "strings" ) -// applyMigrations walks fsys for `*.sql` files and execs each one in -// lexical order. Migrations are idempotent (CREATE TABLE IF NOT EXISTS -// + INSERT OR IGNORE), so re-running on a populated database is a -// no-op. Returns the list of files applied. +const migrationLedgerDDL = ` +CREATE TABLE IF NOT EXISTS schema_migrations ( + name TEXT PRIMARY KEY, + applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) +);` + +const noTransactionMarker = "-- adrian: no-transaction" + +// applyMigrations walks fsys for `*.sql` files and applies each +// previously-unseen migration in lexical order. Applied files are +// recorded in schema_migrations by filename, so future startup runs +// skip them instead of requiring every migration to be idempotent. +// Returns the list of migration files applied during this call. func applyMigrations(db *sql.DB, fsys fs.FS) ([]string, error) { var names []string err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { @@ -32,14 +41,65 @@ func applyMigrations(db *sql.DB, fsys fs.FS) ([]string, error) { } sort.Strings(names) + if _, err := db.Exec(migrationLedgerDDL); err != nil { + return nil, fmt.Errorf("ensure schema_migrations: %w", err) + } + + applied := make([]string, 0, len(names)) for _, name := range names { + alreadyApplied, err := migrationApplied(db, name) + if err != nil { + return nil, err + } + if alreadyApplied { + continue + } + body, err := fs.ReadFile(fsys, name) if err != nil { return nil, fmt.Errorf("read %s: %w", name, err) } - if _, err := db.Exec(string(body)); err != nil { - return nil, fmt.Errorf("exec %s: %w", name, err) + bodyText := string(body) + + if strings.Contains(bodyText, noTransactionMarker) { + if _, err := db.Exec(bodyText); err != nil { + _, _ = db.Exec("ROLLBACK") + _, _ = db.Exec("PRAGMA foreign_keys=ON") + return nil, fmt.Errorf("exec %s: %w", name, err) + } + if _, err := db.Exec(`INSERT INTO schema_migrations (name) VALUES (?)`, name); err != nil { + return nil, fmt.Errorf("record %s: %w", name, err) + } + } else { + tx, err := db.Begin() + if err != nil { + return nil, fmt.Errorf("begin %s: %w", name, err) + } + if _, err := tx.Exec(bodyText); err != nil { + _ = tx.Rollback() + return nil, fmt.Errorf("exec %s: %w", name, err) + } + if _, err := tx.Exec(`INSERT INTO schema_migrations (name) VALUES (?)`, name); err != nil { + _ = tx.Rollback() + return nil, fmt.Errorf("record %s: %w", name, err) + } + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit %s: %w", name, err) + } } + applied = append(applied, name) + } + return applied, nil +} + +func migrationApplied(db *sql.DB, name string) (bool, error) { + var seen int + err := db.QueryRow(`SELECT 1 FROM schema_migrations WHERE name = ?`, name).Scan(&seen) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil } - return names, nil + return false, fmt.Errorf("lookup migration %s: %w", name, err) } diff --git a/backend/internal/db/migrate_test.go b/backend/internal/db/migrate_test.go new file mode 100644 index 0000000..b257698 --- /dev/null +++ b/backend/internal/db/migrate_test.go @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2026 SecureAgentics + +package db + +import ( + "database/sql" + "testing" + "testing/fstest" + + _ "modernc.org/sqlite" +) + +func TestApplyMigrationsUsesLedger(t *testing.T) { + conn := openTestDB(t) + defer conn.Close() + + fsys := fstest.MapFS{ + "001_create.sql": { + Data: []byte(`CREATE TABLE widgets (id INTEGER PRIMARY KEY, name TEXT NOT NULL);`), + }, + "002_insert.sql": { + Data: []byte(`INSERT INTO widgets (name) VALUES ('first');`), + }, + } + + applied, err := applyMigrations(conn, fsys) + if err != nil { + t.Fatalf("first applyMigrations: %v", err) + } + if got, want := len(applied), 2; got != want { + t.Fatalf("first applied len = %d, want %d (%v)", got, want, applied) + } + + applied, err = applyMigrations(conn, fsys) + if err != nil { + t.Fatalf("second applyMigrations: %v", err) + } + if got := len(applied); got != 0 { + t.Fatalf("second applied len = %d, want 0 (%v)", got, applied) + } + + var widgets int + if err := conn.QueryRow(`SELECT count(*) FROM widgets`).Scan(&widgets); err != nil { + t.Fatalf("count widgets: %v", err) + } + if widgets != 1 { + t.Fatalf("widgets count = %d, want 1", widgets) + } + + var ledgerRows int + if err := conn.QueryRow(`SELECT count(*) FROM schema_migrations`).Scan(&ledgerRows); err != nil { + t.Fatalf("count schema_migrations: %v", err) + } + if ledgerRows != 2 { + t.Fatalf("schema_migrations count = %d, want 2", ledgerRows) + } +} + +func TestApplyMigrationsDoesNotRecordFailedMigration(t *testing.T) { + conn := openTestDB(t) + defer conn.Close() + + fsys := fstest.MapFS{ + "001_create.sql": { + Data: []byte(`CREATE TABLE widgets (id INTEGER PRIMARY KEY);`), + }, + "002_bad.sql": { + Data: []byte(`INSERT INTO missing_table (id) VALUES (1);`), + }, + } + + applied, err := applyMigrations(conn, fsys) + if err == nil { + t.Fatal("applyMigrations unexpectedly succeeded") + } + if got, want := len(applied), 0; got != want { + t.Fatalf("applied len after failure = %d, want %d (%v)", got, want, applied) + } + + if migrationWasRecorded(t, conn, "002_bad.sql") { + t.Fatal("failed migration was recorded in schema_migrations") + } + if !migrationWasRecorded(t, conn, "001_create.sql") { + t.Fatal("successful prior migration was not recorded") + } +} + +func TestApplyMigrationsSupportsNoTransactionMarker(t *testing.T) { + conn := openTestDB(t) + defer conn.Close() + + fsys := fstest.MapFS{ + "001_no_tx.sql": { + Data: []byte(noTransactionMarker + ` +BEGIN; +CREATE TABLE widgets (id INTEGER PRIMARY KEY, name TEXT NOT NULL); +INSERT INTO widgets (name) VALUES ('marker'); +COMMIT;`), + }, + } + + applied, err := applyMigrations(conn, fsys) + if err != nil { + t.Fatalf("applyMigrations: %v", err) + } + if got, want := len(applied), 1; got != want { + t.Fatalf("applied len = %d, want %d (%v)", got, want, applied) + } + if !migrationWasRecorded(t, conn, "001_no_tx.sql") { + t.Fatal("no-transaction migration was not recorded") + } +} + +func openTestDB(t *testing.T) *sql.DB { + t.Helper() + conn, err := sql.Open("sqlite", "file:migratetest?mode=memory&cache=shared") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + return conn +} + +func migrationWasRecorded(t *testing.T, conn *sql.DB, name string) bool { + t.Helper() + var seen int + err := conn.QueryRow(`SELECT 1 FROM schema_migrations WHERE name = ?`, name).Scan(&seen) + if err == sql.ErrNoRows { + return false + } + if err != nil { + t.Fatalf("lookup migration %s: %v", name, err) + } + return true +} diff --git a/backend/internal/engine/client.go b/backend/internal/engine/client.go index 97e7d9a..b91d2de 100644 --- a/backend/internal/engine/client.go +++ b/backend/internal/engine/client.go @@ -30,13 +30,10 @@ const ( ) // HTTPClient classifies paired events by POSTing to ADRIAN_LLM_URL. -// Any classifier failure (transport, non-2xx HTTP, malformed body, -// no parseable M-code) falls back to a synthetic M0 / benign verdict -// with the cause stored on the Reasoning column and a WARN logged. -// Adrian's posture is fail-open: a classifier outage on our side -// must not halt the operator's agent. The trade-off is that a -// malicious agent who can DOS the classifier rides this path; -// detection-class outages are treated the same as model parse misses. +// Classifier failures (transport, non-2xx HTTP, malformed body, +// empty choices, or no parseable M-code) are returned as errors. The +// WS ingest layer records those as status=ERROR verdicts and applies +// the active execution policy. // // The classifier owns the SlidingWindow: every call acquires the // per-(session, invocation, agent_id) lock, reads history into the @@ -148,9 +145,8 @@ func (c *HTTPClient) lookupProfile(ctx context.Context, id string) *store.AgentP // classifyOnce renders the trace, builds the message array (with the // optional history prepended), POSTs, and parses. Returns (nil, error) -// on any failure; the WS handler is responsible for the mode-specific -// fail-closed dispatch (halt the SDK in BLOCK, queue for review in -// HITL, audit-only in ALERT). +// on any failure; the WS handler is responsible for persisting the +// status=ERROR verdict and applying the active execution policy. func (c *HTTPClient) classifyOnce(ctx context.Context, ev *pb.PairedEvent, history []HistoryItem, guid string, profile *store.AgentProfile) (*Verdict, error) { start := time.Now() trace := extractTrace(ev, guid) @@ -165,24 +161,22 @@ func (c *HTTPClient) classifyOnce(ctx context.Context, ev *pb.PairedEvent, histo raw, err := c.post(ctx, body) if err != nil { - // Transport / non-2xx. Fail open with M0 / benign so the - // agent isn't halted by a classifier outage on our side. - return c.failOpen(ctx, fmt.Errorf("post: %w", err), start), nil + return nil, fmt.Errorf("post: %w", err) } var parsed responseBody if err := json.Unmarshal(raw, &parsed); err != nil { - return c.failOpen(ctx, fmt.Errorf("unmarshal: %w", err), start), nil + return nil, fmt.Errorf("unmarshal: %w", err) } if len(parsed.Choices) == 0 { - return c.failOpen(ctx, errors.New("no choices in response"), start), nil + return nil, errors.New("no choices in response") } rawContent := parsed.Choices[0].Message.Content stripped := stripReasoning(rawContent) code := parseMADCode(stripped) if code == "" { - return c.failOpen(ctx, fmt.Errorf("no MAD code in response: %q", truncate(stripped, 200)), start), nil + return nil, fmt.Errorf("no MAD code in response: %q", truncate(stripped, 200)) } classification := madCodeToClassification(code) @@ -230,23 +224,6 @@ func (c *HTTPClient) post(ctx context.Context, body requestBody) ([]byte, error) return respBody, nil } -// failOpen returns a synthetic M0 / benign verdict on any classifier -// failure (transport, non-2xx, malformed JSON, empty choices, no -// parseable M-code). WARN-logged with the cause; the cause string -// also lands on the Reasoning column so operators can distinguish a -// classifier outage from a benign-by-classification verdict in the -// dashboard. Adrian's posture is fail-open: a classifier outage on -// our side must not halt the operator's agent. -func (c *HTTPClient) failOpen(ctx context.Context, cause error, start time.Time) *Verdict { - slog.WarnContext(ctx, "engine.classifier_failure_fail_open", "error", cause) - return &Verdict{ - MADCode: "M0", - Classification: "benign", - Reasoning: "classifier failure (fail-open): " + cause.Error(), - LatencyMS: time.Since(start).Milliseconds(), - } -} - // Ping reaches the configured classifier URL with a short timeout to // confirm the upstream answers TCP + TLS + HTTP. Treats any HTTP // status (including 4xx like 405 Method Not Allowed for our POST-only diff --git a/backend/internal/engine/client_test.go b/backend/internal/engine/client_test.go index 781878e..24799a6 100644 --- a/backend/internal/engine/client_test.go +++ b/backend/internal/engine/client_test.go @@ -89,7 +89,7 @@ func TestMADCodeToClassification(t *testing.T) { "M3.b": "block", "M4": "block", "M4.e": "block", - "": "benign", + "": "error", } for code, want := range cases { if got := madCodeToClassification(code); got != want { @@ -232,11 +232,10 @@ func TestHTTPClientClassifyHappy(t *testing.T) { } } -// TestHTTPClientClassifyFailsOpenOn5xx asserts that an upstream HTTP -// error (e.g. 500) returns a synthetic M0 / benign verdict rather than -// halting the agent. Adrian's posture is fail-open across all -// classifier failures. -func TestHTTPClientClassifyFailsOpenOn5xx(t *testing.T) { +// TestHTTPClientClassifyErrorsOn5xx asserts that an upstream HTTP +// error (e.g. 500) is returned to the WS ingest layer so it can +// persist an ERROR verdict and apply policy. +func TestHTTPClientClassifyErrorsOn5xx(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "boom", http.StatusInternalServerError) })) @@ -250,28 +249,21 @@ func TestHTTPClientClassifyFailsOpenOn5xx(t *testing.T) { Tool: &pb.ToolPairData{ToolName: "noop"}, }, }, "") - if err != nil { - t.Fatalf("Classify on 5xx must NOT error (fail-open path); got %v", err) - } - if v == nil { - t.Fatal("verdict should not be nil on the fail-open path") - } - if v.MADCode != "M0" { - t.Errorf("fail-open mad_code = %q, want M0", v.MADCode) + if err == nil { + t.Fatal("Classify on 5xx unexpectedly succeeded") } - if v.Classification != "benign" { - t.Errorf("fail-open classification = %q, want benign", v.Classification) + if v != nil { + t.Fatalf("verdict = %+v, want nil on classifier error", v) } - if !strings.Contains(v.Reasoning, "classifier failure") || !strings.Contains(v.Reasoning, "status 500") { - t.Errorf("Reasoning should reference upstream status; got %q", v.Reasoning) + if !strings.Contains(err.Error(), "post:") || !strings.Contains(err.Error(), "status 500") { + t.Errorf("error should reference upstream status; got %v", err) } } -// TestHTTPClientClassifyFailsOpenOnConnRefused asserts the -// transport-failure path (server unreachable / connection refused) -// also fails open with a synthetic M0 / benign verdict. Same posture -// as 5xx: classifier outages on our side must not halt the agent. -func TestHTTPClientClassifyFailsOpenOnConnRefused(t *testing.T) { +// TestHTTPClientClassifyErrorsOnConnRefused asserts the transport +// failure path (server unreachable / connection refused) returns an +// error rather than a synthetic benign verdict. +func TestHTTPClientClassifyErrorsOnConnRefused(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Will not be hit; we close the server before calling Classify. w.WriteHeader(http.StatusOK) @@ -287,21 +279,21 @@ func TestHTTPClientClassifyFailsOpenOnConnRefused(t *testing.T) { Tool: &pb.ToolPairData{ToolName: "noop"}, }, }, "") - if err != nil { - t.Fatalf("Classify on connection-refused must NOT error (fail-open path); got %v", err) + if err == nil { + t.Fatal("Classify on connection-refused unexpectedly succeeded") } - if v == nil || v.MADCode != "M0" || v.Classification != "benign" { - t.Errorf("fail-open verdict = %+v, want M0/benign", v) + if v != nil { + t.Fatalf("verdict = %+v, want nil on classifier error", v) } - if !strings.Contains(v.Reasoning, "classifier failure") { - t.Errorf("Reasoning should mention classifier failure; got %q", v.Reasoning) + if !strings.Contains(err.Error(), "post:") { + t.Errorf("error should identify post failure; got %v", err) } } -// TestHTTPClientClassifyFailsOpenOnUnparseable asserts the +// TestHTTPClientClassifyErrorsOnUnparseable asserts the // 2xx-with-garbled-body path: upstream answered, body has no -// recognisable M-code, engine returns synthetic M0 / benign. -func TestHTTPClientClassifyFailsOpenOnUnparseable(t *testing.T) { +// recognisable M-code, so engine returns an error. +func TestHTTPClientClassifyErrorsOnUnparseable(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"sorry, no idea"}}]}`)) })) @@ -313,27 +305,44 @@ func TestHTTPClientClassifyFailsOpenOnUnparseable(t *testing.T) { PairType: pb.PairType_PAIR_TYPE_TOOL, Data: &pb.PairedEvent_Tool{Tool: &pb.ToolPairData{ToolName: "noop"}}, }, "") - if err != nil { - t.Fatalf("Classify on unparseable body must NOT error (fail-open path); got %v", err) + if err == nil { + t.Fatal("Classify on unparseable body unexpectedly succeeded") } - if v == nil { - t.Fatal("verdict should not be nil on the fail-open path") + if v != nil { + t.Fatalf("verdict = %+v, want nil on classifier error", v) } - if v.MADCode != "M0" { - t.Errorf("fail-open mad_code = %q, want M0", v.MADCode) + if !strings.Contains(err.Error(), "no MAD code") { + t.Errorf("error should explain the parse miss; got %v", err) } - if v.Classification != "benign" { - t.Errorf("fail-open classification = %q, want benign", v.Classification) +} + +func TestHTTPClientClassifyErrorsOnMalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`not-json`)) + })) + defer srv.Close() + + c := NewHTTPClient(srv.URL, "test-key", "test-model", nil, nil) + v, err := c.Classify(context.Background(), &pb.PairedEvent{ + EventId: "ev-malformed", + PairType: pb.PairType_PAIR_TYPE_TOOL, + Data: &pb.PairedEvent_Tool{Tool: &pb.ToolPairData{ToolName: "noop"}}, + }, "") + if err == nil { + t.Fatal("Classify on malformed JSON unexpectedly succeeded") + } + if v != nil { + t.Fatalf("verdict = %+v, want nil on classifier error", v) } - if !strings.Contains(v.Reasoning, "classifier failure") || !strings.Contains(v.Reasoning, "no MAD code") { - t.Errorf("Reasoning should explain the parse miss; got %q", v.Reasoning) + if !strings.Contains(err.Error(), "unmarshal:") { + t.Errorf("error should explain malformed JSON; got %v", err) } } -// TestHTTPClientClassifyFailsOpenOnEmptyChoices is the second -// branch into failOpenUnparseable: 2xx + valid JSON envelope, but -// the choices array is empty. Same fail-open posture. -func TestHTTPClientClassifyFailsOpenOnEmptyChoices(t *testing.T) { +// TestHTTPClientClassifyErrorsOnEmptyChoices is the second +// unparseable-response branch: 2xx + valid JSON envelope, but the +// choices array is empty. +func TestHTTPClientClassifyErrorsOnEmptyChoices(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"choices":[]}`)) })) @@ -345,11 +354,14 @@ func TestHTTPClientClassifyFailsOpenOnEmptyChoices(t *testing.T) { PairType: pb.PairType_PAIR_TYPE_TOOL, Data: &pb.PairedEvent_Tool{Tool: &pb.ToolPairData{ToolName: "noop"}}, }, "") - if err != nil { - t.Fatalf("Classify on empty-choices must NOT error; got %v", err) + if err == nil { + t.Fatal("Classify on empty-choices unexpectedly succeeded") } - if v == nil || v.MADCode != "M0" { - t.Errorf("fail-open verdict = %+v, want M0/benign", v) + if v != nil { + t.Fatalf("verdict = %+v, want nil on classifier error", v) + } + if !strings.Contains(err.Error(), "no choices") { + t.Errorf("error should explain empty choices; got %v", err) } } @@ -418,6 +430,48 @@ func TestHTTPClientWindowFeedsHistory(t *testing.T) { } } +func TestHTTPClientWindowSkipsFailedTurns(t *testing.T) { + var captured []requestBody + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req requestBody + _ = json.Unmarshal(body, &req) + captured = append(captured, req) + if len(captured) == 1 { + _, _ = w.Write([]byte(`{"choices":[]}`)) + return + } + _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"M0"}}]}`)) + })) + defer srv.Close() + + window := NewSlidingWindow(WindowOpts{Size: 16, TTL: time.Hour}) + c := NewHTTPClient(srv.URL, "test-key", "test-model", window, nil) + event := &pb.PairedEvent{ + EventId: "ev-window-fail", + SessionId: "sess-window-fail", + InvocationId: "inv-window-fail", + PairType: pb.PairType_PAIR_TYPE_TOOL, + Agent: &pb.AgentContext{AgentId: "agent-window-fail"}, + Data: &pb.PairedEvent_Tool{Tool: &pb.ToolPairData{ToolName: "first_tool"}}, + } + + if _, err := c.Classify(context.Background(), event, ""); err == nil { + t.Fatal("first classify unexpectedly succeeded") + } + event.EventId = "ev-window-success" + if _, err := c.Classify(context.Background(), event, ""); err != nil { + t.Fatalf("second classify: %v", err) + } + + if len(captured) != 2 { + t.Fatalf("captured %d requests, want 2", len(captured)) + } + if got := len(captured[1].Messages); got != 4 { + t.Fatalf("second call messages = %d, want 4 (failed turn not pushed to history)", got) + } +} + // TestHTTPClientNoWindowSkipsHistory ensures the existing zero-config // path (window=nil) works exactly as before: every call sees no // history regardless of any prior call. diff --git a/backend/internal/engine/engine.go b/backend/internal/engine/engine.go index 424ad05..2e80db5 100644 --- a/backend/internal/engine/engine.go +++ b/backend/internal/engine/engine.go @@ -23,9 +23,10 @@ type Verdict struct { // Classifier classifies a paired event. Implementations honour ctx // cancellation. A returned error means classification could not be -// completed safely (LLM unreachable, malformed response, no parseable -// M-code) and the caller must fail closed per execution mode. A nil -// verdict with nil error is not a valid response. +// completed (LLM unreachable, malformed response, empty choices, no +// parseable M-code). The caller owns persistence and policy routing +// for those operational failures. A nil verdict with nil error is not +// a valid response. // // agentProfileID is the customer-facing agent identity bound to the // SDK's API key (looked up server-side at WS-login time). Pass "" to diff --git a/backend/internal/engine/parse.go b/backend/internal/engine/parse.go index 4b78297..4bad12c 100644 --- a/backend/internal/engine/parse.go +++ b/backend/internal/engine/parse.go @@ -47,11 +47,12 @@ func stripReasoning(content string) string { } } -// madCodeToClassification maps an M-code to its display classification. -// Unknown codes return "benign" (caller should log a warn). +// madCodeToClassification maps a classifier-produced M-code to its +// display classification. Empty or unknown codes are operational +// classifier errors, not benign results. func madCodeToClassification(code string) string { - if code == "" { - return "benign" + if len(code) < 2 { + return "error" } switch code[:2] { case "M0": @@ -61,6 +62,6 @@ func madCodeToClassification(code string) string { case "M3", "M4": return "block" default: - return "benign" + return "error" } } diff --git a/backend/internal/proto/event.pb.go b/backend/internal/proto/event.pb.go index f958ab8..06caa26 100644 --- a/backend/internal/proto/event.pb.go +++ b/backend/internal/proto/event.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v3.19.6 +// protoc v6.33.5 // source: event.proto package proto @@ -132,6 +132,59 @@ func (Mode) EnumDescriptor() ([]byte, []int) { return file_event_proto_rawDescGZIP(), []int{1} } +// VerdictStatus says whether a Verdict came from a completed classifier +// decision or represents a classifier failure. ERROR verdicts carry no +// classifier-produced MAD code; policy decides whether they fail open +// or fail closed. +type VerdictStatus int32 + +const ( + VerdictStatus_VERDICT_STATUS_UNSPECIFIED VerdictStatus = 0 + VerdictStatus_VERDICT_STATUS_OK VerdictStatus = 1 + VerdictStatus_VERDICT_STATUS_ERROR VerdictStatus = 2 +) + +// Enum value maps for VerdictStatus. +var ( + VerdictStatus_name = map[int32]string{ + 0: "VERDICT_STATUS_UNSPECIFIED", + 1: "VERDICT_STATUS_OK", + 2: "VERDICT_STATUS_ERROR", + } + VerdictStatus_value = map[string]int32{ + "VERDICT_STATUS_UNSPECIFIED": 0, + "VERDICT_STATUS_OK": 1, + "VERDICT_STATUS_ERROR": 2, + } +) + +func (x VerdictStatus) Enum() *VerdictStatus { + p := new(VerdictStatus) + *p = x + return p +} + +func (x VerdictStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (VerdictStatus) Descriptor() protoreflect.EnumDescriptor { + return file_event_proto_enumTypes[2].Descriptor() +} + +func (VerdictStatus) Type() protoreflect.EnumType { + return &file_event_proto_enumTypes[2] +} + +func (x VerdictStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use VerdictStatus.Descriptor instead. +func (VerdictStatus) EnumDescriptor() ([]byte, []int) { + return file_event_proto_rawDescGZIP(), []int{2} +} + // ChatMessage represents a conversation message with a string role. type ChatMessage struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1107,15 +1160,19 @@ func (*ClientFrame_McpInventory) isClientFrame_Frame() {} // // Per-MAD-code booleans say whether the active mode's behaviour fires on // that code. False means "treat this code as silent regardless of mode". +// fail_closed_on_classifier_error controls ERROR verdicts and BLOCK-mode +// SDK verdict timeouts. The default false value preserves fail-open +// availability when talking to older backends. type PolicySnapshot struct { - state protoimpl.MessageState `protogen:"open.v1"` - Mode Mode `protobuf:"varint,1,opt,name=mode,proto3,enum=adrian.core_api.v1.Mode" json:"mode,omitempty"` - PolicyM0 bool `protobuf:"varint,2,opt,name=policy_m0,json=policyM0,proto3" json:"policy_m0,omitempty"` - PolicyM2 bool `protobuf:"varint,3,opt,name=policy_m2,json=policyM2,proto3" json:"policy_m2,omitempty"` - PolicyM3 bool `protobuf:"varint,4,opt,name=policy_m3,json=policyM3,proto3" json:"policy_m3,omitempty"` - PolicyM4 bool `protobuf:"varint,5,opt,name=policy_m4,json=policyM4,proto3" json:"policy_m4,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Mode Mode `protobuf:"varint,1,opt,name=mode,proto3,enum=adrian.core_api.v1.Mode" json:"mode,omitempty"` + PolicyM0 bool `protobuf:"varint,2,opt,name=policy_m0,json=policyM0,proto3" json:"policy_m0,omitempty"` + PolicyM2 bool `protobuf:"varint,3,opt,name=policy_m2,json=policyM2,proto3" json:"policy_m2,omitempty"` + PolicyM3 bool `protobuf:"varint,4,opt,name=policy_m3,json=policyM3,proto3" json:"policy_m3,omitempty"` + PolicyM4 bool `protobuf:"varint,5,opt,name=policy_m4,json=policyM4,proto3" json:"policy_m4,omitempty"` + FailClosedOnClassifierError bool `protobuf:"varint,6,opt,name=fail_closed_on_classifier_error,json=failClosedOnClassifierError,proto3" json:"fail_closed_on_classifier_error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *PolicySnapshot) Reset() { @@ -1183,6 +1240,13 @@ func (x *PolicySnapshot) GetPolicyM4() bool { return false } +func (x *PolicySnapshot) GetFailClosedOnClassifierError() bool { + if x != nil { + return x.FailClosedOnClassifierError + } + return false +} + // HitlResponse rides on a Verdict that has been resolved through the // human-in-the-loop review queue. Absent on regular (non-HITL or // out-of-scope) verdicts. @@ -1377,7 +1441,10 @@ type Verdict struct { EventId string `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"` // Session identifier for routing. SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` - // MAD code the classifier returned (e.g. "M0", "M2_C", "M4_a"). Empty string for benign. + // MAD code the classifier returned (e.g. "M0", "M2_C", "M4_a"). + // Empty string means no MAD code was produced, such as for a + // VerdictStatus.ERROR classifier failure. Benign classifier success + // is represented by status OK with mad_code "M0". MadCode string `protobuf:"bytes,4,opt,name=mad_code,json=madCode,proto3" json:"mad_code,omitempty"` // Org's effective execution-mode policy at the time of this verdict. // Always populated by the server; SDK reads this to decide whether to @@ -1386,7 +1453,11 @@ type Verdict struct { // Present only when this verdict represents a human-in-the-loop review // resolution (approve or reject from the dashboard). Absent on auto- // classified verdicts and on out-of-scope verdicts forwarded immediately. - Hitl *HitlResponse `protobuf:"bytes,7,opt,name=hitl,proto3" json:"hitl,omitempty"` + Hitl *HitlResponse `protobuf:"bytes,7,opt,name=hitl,proto3" json:"hitl,omitempty"` + // Status of the classifier result. OK means mad_code carries a normal + // classifier decision. ERROR means classification did not complete and + // mad_code is empty; fail-open/fail-closed behaviour comes from policy. + Status VerdictStatus `protobuf:"varint,8,opt,name=status,proto3,enum=adrian.core_api.v1.VerdictStatus" json:"status,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1456,6 +1527,13 @@ func (x *Verdict) GetHitl() *HitlResponse { return nil } +func (x *Verdict) GetStatus() VerdictStatus { + if x != nil { + return x.Status + } + return VerdictStatus_VERDICT_STATUS_UNSPECIFIED +} + var File_event_proto protoreflect.FileDescriptor const file_event_proto_rawDesc = "" + @@ -1527,13 +1605,14 @@ const file_event_proto_rawDesc = "" + "\x05login\x18\x01 \x01(\v2 .adrian.core_api.v1.SessionLoginH\x00R\x05login\x12I\n" + "\fpaired_batch\x18\x03 \x01(\v2$.adrian.core_api.v1.PairedEventBatchH\x00R\vpairedBatch\x12G\n" + "\rmcp_inventory\x18\x04 \x01(\v2 .adrian.core_api.v1.McpInventoryH\x00R\fmcpInventoryB\a\n" + - "\x05frameJ\x04\b\x02\x10\x03R\x05batch\"\xb2\x01\n" + + "\x05frameJ\x04\b\x02\x10\x03R\x05batch\"\xf8\x01\n" + "\x0ePolicySnapshot\x12,\n" + "\x04mode\x18\x01 \x01(\x0e2\x18.adrian.core_api.v1.ModeR\x04mode\x12\x1b\n" + "\tpolicy_m0\x18\x02 \x01(\bR\bpolicyM0\x12\x1b\n" + "\tpolicy_m2\x18\x03 \x01(\bR\bpolicyM2\x12\x1b\n" + "\tpolicy_m3\x18\x04 \x01(\bR\bpolicyM3\x12\x1b\n" + - "\tpolicy_m4\x18\x05 \x01(\bR\bpolicyM4\"=\n" + + "\tpolicy_m4\x18\x05 \x01(\bR\bpolicyM4\x12D\n" + + "\x1ffail_closed_on_classifier_error\x18\x06 \x01(\bR\x1bfailClosedOnClassifierError\"=\n" + "\fHitlResponse\x12-\n" + "\x12continue_execution\x18\x01 \x01(\bR\x11continueExecution\"F\n" + "\bLoginAck\x12:\n" + @@ -1541,14 +1620,15 @@ const file_event_proto_rawDesc = "" + "\vServerFrame\x12;\n" + "\tlogin_ack\x18\x01 \x01(\v2\x1c.adrian.core_api.v1.LoginAckH\x00R\bloginAck\x127\n" + "\averdict\x18\x02 \x01(\v2\x1b.adrian.core_api.v1.VerdictH\x00R\averdictB\a\n" + - "\x05frame\"\xf6\x01\n" + + "\x05frame\"\xb1\x02\n" + "\aVerdict\x12\x19\n" + "\bevent_id\x18\x01 \x01(\tR\aeventId\x12\x1d\n" + "\n" + "session_id\x18\x02 \x01(\tR\tsessionId\x12\x19\n" + "\bmad_code\x18\x04 \x01(\tR\amadCode\x12:\n" + "\x06policy\x18\x06 \x01(\v2\".adrian.core_api.v1.PolicySnapshotR\x06policy\x124\n" + - "\x04hitl\x18\a \x01(\v2 .adrian.core_api.v1.HitlResponseR\x04hitlJ\x04\b\x03\x10\x04J\x04\b\x05\x10\x06R\x0eclassificationR\bescalate*L\n" + + "\x04hitl\x18\a \x01(\v2 .adrian.core_api.v1.HitlResponseR\x04hitl\x129\n" + + "\x06status\x18\b \x01(\x0e2!.adrian.core_api.v1.VerdictStatusR\x06statusJ\x04\b\x03\x10\x04J\x04\b\x05\x10\x06R\x0eclassificationR\bescalate*L\n" + "\bPairType\x12\x19\n" + "\x15PAIR_TYPE_UNSPECIFIED\x10\x00\x12\x11\n" + "\rPAIR_TYPE_LLM\x10\x01\x12\x12\n" + @@ -1559,7 +1639,11 @@ const file_event_proto_rawDesc = "" + "MODE_ALERT\x10\x01\x12\r\n" + "\tMODE_HITL\x10\x02\x12\x0e\n" + "\n" + - "MODE_BLOCK\x10\x03B?Z=github.com/secureagentics/Adrian/backend/internal/proto;protob\x06proto3" + "MODE_BLOCK\x10\x03*`\n" + + "\rVerdictStatus\x12\x1e\n" + + "\x1aVERDICT_STATUS_UNSPECIFIED\x10\x00\x12\x15\n" + + "\x11VERDICT_STATUS_OK\x10\x01\x12\x18\n" + + "\x14VERDICT_STATUS_ERROR\x10\x02B?Z=github.com/secureagentics/Adrian/backend/internal/proto;protob\x06proto3" var ( file_event_proto_rawDescOnce sync.Once @@ -1573,56 +1657,58 @@ func file_event_proto_rawDescGZIP() []byte { return file_event_proto_rawDescData } -var file_event_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_event_proto_enumTypes = make([]protoimpl.EnumInfo, 3) var file_event_proto_msgTypes = make([]protoimpl.MessageInfo, 18) var file_event_proto_goTypes = []any{ (PairType)(0), // 0: adrian.core_api.v1.PairType (Mode)(0), // 1: adrian.core_api.v1.Mode - (*ChatMessage)(nil), // 2: adrian.core_api.v1.ChatMessage - (*ToolCall)(nil), // 3: adrian.core_api.v1.ToolCall - (*TokenUsage)(nil), // 4: adrian.core_api.v1.TokenUsage - (*AgentContext)(nil), // 5: adrian.core_api.v1.AgentContext - (*LlmPairData)(nil), // 6: adrian.core_api.v1.LlmPairData - (*ToolPairData)(nil), // 7: adrian.core_api.v1.ToolPairData - (*PairedEvent)(nil), // 8: adrian.core_api.v1.PairedEvent - (*PairedEventBatch)(nil), // 9: adrian.core_api.v1.PairedEventBatch - (*McpServer)(nil), // 10: adrian.core_api.v1.McpServer - (*McpInventory)(nil), // 11: adrian.core_api.v1.McpInventory - (*LLMStack)(nil), // 12: adrian.core_api.v1.LLMStack - (*SessionLogin)(nil), // 13: adrian.core_api.v1.SessionLogin - (*ClientFrame)(nil), // 14: adrian.core_api.v1.ClientFrame - (*PolicySnapshot)(nil), // 15: adrian.core_api.v1.PolicySnapshot - (*HitlResponse)(nil), // 16: adrian.core_api.v1.HitlResponse - (*LoginAck)(nil), // 17: adrian.core_api.v1.LoginAck - (*ServerFrame)(nil), // 18: adrian.core_api.v1.ServerFrame - (*Verdict)(nil), // 19: adrian.core_api.v1.Verdict + (VerdictStatus)(0), // 2: adrian.core_api.v1.VerdictStatus + (*ChatMessage)(nil), // 3: adrian.core_api.v1.ChatMessage + (*ToolCall)(nil), // 4: adrian.core_api.v1.ToolCall + (*TokenUsage)(nil), // 5: adrian.core_api.v1.TokenUsage + (*AgentContext)(nil), // 6: adrian.core_api.v1.AgentContext + (*LlmPairData)(nil), // 7: adrian.core_api.v1.LlmPairData + (*ToolPairData)(nil), // 8: adrian.core_api.v1.ToolPairData + (*PairedEvent)(nil), // 9: adrian.core_api.v1.PairedEvent + (*PairedEventBatch)(nil), // 10: adrian.core_api.v1.PairedEventBatch + (*McpServer)(nil), // 11: adrian.core_api.v1.McpServer + (*McpInventory)(nil), // 12: adrian.core_api.v1.McpInventory + (*LLMStack)(nil), // 13: adrian.core_api.v1.LLMStack + (*SessionLogin)(nil), // 14: adrian.core_api.v1.SessionLogin + (*ClientFrame)(nil), // 15: adrian.core_api.v1.ClientFrame + (*PolicySnapshot)(nil), // 16: adrian.core_api.v1.PolicySnapshot + (*HitlResponse)(nil), // 17: adrian.core_api.v1.HitlResponse + (*LoginAck)(nil), // 18: adrian.core_api.v1.LoginAck + (*ServerFrame)(nil), // 19: adrian.core_api.v1.ServerFrame + (*Verdict)(nil), // 20: adrian.core_api.v1.Verdict } var file_event_proto_depIdxs = []int32{ - 2, // 0: adrian.core_api.v1.LlmPairData.messages:type_name -> adrian.core_api.v1.ChatMessage - 3, // 1: adrian.core_api.v1.LlmPairData.tool_calls:type_name -> adrian.core_api.v1.ToolCall - 4, // 2: adrian.core_api.v1.LlmPairData.usage:type_name -> adrian.core_api.v1.TokenUsage + 3, // 0: adrian.core_api.v1.LlmPairData.messages:type_name -> adrian.core_api.v1.ChatMessage + 4, // 1: adrian.core_api.v1.LlmPairData.tool_calls:type_name -> adrian.core_api.v1.ToolCall + 5, // 2: adrian.core_api.v1.LlmPairData.usage:type_name -> adrian.core_api.v1.TokenUsage 0, // 3: adrian.core_api.v1.PairedEvent.pair_type:type_name -> adrian.core_api.v1.PairType - 5, // 4: adrian.core_api.v1.PairedEvent.agent:type_name -> adrian.core_api.v1.AgentContext - 5, // 5: adrian.core_api.v1.PairedEvent.parent:type_name -> adrian.core_api.v1.AgentContext - 6, // 6: adrian.core_api.v1.PairedEvent.llm:type_name -> adrian.core_api.v1.LlmPairData - 7, // 7: adrian.core_api.v1.PairedEvent.tool:type_name -> adrian.core_api.v1.ToolPairData - 8, // 8: adrian.core_api.v1.PairedEventBatch.events:type_name -> adrian.core_api.v1.PairedEvent - 10, // 9: adrian.core_api.v1.McpInventory.servers:type_name -> adrian.core_api.v1.McpServer - 12, // 10: adrian.core_api.v1.SessionLogin.llm_stack:type_name -> adrian.core_api.v1.LLMStack - 13, // 11: adrian.core_api.v1.ClientFrame.login:type_name -> adrian.core_api.v1.SessionLogin - 9, // 12: adrian.core_api.v1.ClientFrame.paired_batch:type_name -> adrian.core_api.v1.PairedEventBatch - 11, // 13: adrian.core_api.v1.ClientFrame.mcp_inventory:type_name -> adrian.core_api.v1.McpInventory + 6, // 4: adrian.core_api.v1.PairedEvent.agent:type_name -> adrian.core_api.v1.AgentContext + 6, // 5: adrian.core_api.v1.PairedEvent.parent:type_name -> adrian.core_api.v1.AgentContext + 7, // 6: adrian.core_api.v1.PairedEvent.llm:type_name -> adrian.core_api.v1.LlmPairData + 8, // 7: adrian.core_api.v1.PairedEvent.tool:type_name -> adrian.core_api.v1.ToolPairData + 9, // 8: adrian.core_api.v1.PairedEventBatch.events:type_name -> adrian.core_api.v1.PairedEvent + 11, // 9: adrian.core_api.v1.McpInventory.servers:type_name -> adrian.core_api.v1.McpServer + 13, // 10: adrian.core_api.v1.SessionLogin.llm_stack:type_name -> adrian.core_api.v1.LLMStack + 14, // 11: adrian.core_api.v1.ClientFrame.login:type_name -> adrian.core_api.v1.SessionLogin + 10, // 12: adrian.core_api.v1.ClientFrame.paired_batch:type_name -> adrian.core_api.v1.PairedEventBatch + 12, // 13: adrian.core_api.v1.ClientFrame.mcp_inventory:type_name -> adrian.core_api.v1.McpInventory 1, // 14: adrian.core_api.v1.PolicySnapshot.mode:type_name -> adrian.core_api.v1.Mode - 15, // 15: adrian.core_api.v1.LoginAck.policy:type_name -> adrian.core_api.v1.PolicySnapshot - 17, // 16: adrian.core_api.v1.ServerFrame.login_ack:type_name -> adrian.core_api.v1.LoginAck - 19, // 17: adrian.core_api.v1.ServerFrame.verdict:type_name -> adrian.core_api.v1.Verdict - 15, // 18: adrian.core_api.v1.Verdict.policy:type_name -> adrian.core_api.v1.PolicySnapshot - 16, // 19: adrian.core_api.v1.Verdict.hitl:type_name -> adrian.core_api.v1.HitlResponse - 20, // [20:20] is the sub-list for method output_type - 20, // [20:20] is the sub-list for method input_type - 20, // [20:20] is the sub-list for extension type_name - 20, // [20:20] is the sub-list for extension extendee - 0, // [0:20] is the sub-list for field type_name + 16, // 15: adrian.core_api.v1.LoginAck.policy:type_name -> adrian.core_api.v1.PolicySnapshot + 18, // 16: adrian.core_api.v1.ServerFrame.login_ack:type_name -> adrian.core_api.v1.LoginAck + 20, // 17: adrian.core_api.v1.ServerFrame.verdict:type_name -> adrian.core_api.v1.Verdict + 16, // 18: adrian.core_api.v1.Verdict.policy:type_name -> adrian.core_api.v1.PolicySnapshot + 17, // 19: adrian.core_api.v1.Verdict.hitl:type_name -> adrian.core_api.v1.HitlResponse + 2, // 20: adrian.core_api.v1.Verdict.status:type_name -> adrian.core_api.v1.VerdictStatus + 21, // [21:21] is the sub-list for method output_type + 21, // [21:21] is the sub-list for method input_type + 21, // [21:21] is the sub-list for extension type_name + 21, // [21:21] is the sub-list for extension extendee + 0, // [0:21] is the sub-list for field type_name } func init() { file_event_proto_init() } @@ -1648,7 +1734,7 @@ func file_event_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_event_proto_rawDesc), len(file_event_proto_rawDesc)), - NumEnums: 2, + NumEnums: 3, NumMessages: 18, NumExtensions: 0, NumServices: 0, diff --git a/backend/internal/store/events.go b/backend/internal/store/events.go index a388a01..c7c71d9 100644 --- a/backend/internal/store/events.go +++ b/backend/internal/store/events.go @@ -50,6 +50,7 @@ type TimelineRow struct { VerdictID string MADCode string Classification string + VerdictStatus string } // EventFilters is the query-string surface for ListEvents. @@ -183,7 +184,8 @@ func (s *Store) SessionTimeline(ctx context.Context, sessionID string) ([]*Timel e.created_at, COALESCE(v.id, ''), COALESCE(v.mad_code, ''), - COALESCE(v.classification, '') + COALESCE(v.classification, ''), + COALESCE(v.verdict_status, '') FROM events e LEFT JOIN agent_profiles ap ON ap.id = e.agent_profile_id LEFT JOIN verdicts v ON v.event_id = e.id @@ -205,7 +207,7 @@ func (s *Store) SessionTimeline(ctx context.Context, sessionID string) ([]*Timel if err := rows.Scan( &r.ID, &r.EventType, &r.RunID, &r.AgentID, &r.AgentName, &r.PayloadJSON, &createdAt, - &r.VerdictID, &r.MADCode, &r.Classification, + &r.VerdictID, &r.MADCode, &r.Classification, &r.VerdictStatus, ); err != nil { return nil, err } diff --git a/backend/internal/store/hitl.go b/backend/internal/store/hitl.go index 5ee08cd..7aa3019 100644 --- a/backend/internal/store/hitl.go +++ b/backend/internal/store/hitl.go @@ -15,15 +15,16 @@ import ( // HitlReview is a row from hitl_queue, plus joined fields the dashboard // list view needs. type HitlReview struct { - ID string - EventID string - VerdictID string - SessionID string - MADCode string - Status string - ReviewedBy string - ReviewedAt time.Time - CreatedAt time.Time + ID string + EventID string + VerdictID string + SessionID string + MADCode string + VerdictStatus string + Status string + ReviewedBy string + ReviewedAt time.Time + CreatedAt time.Time } // HitlReviewDetail extends HitlReview with the event payload + verdict @@ -59,12 +60,13 @@ func (s *Store) ListHitlQueue(ctx context.Context, status string, perPage, offse return nil, 0, err } rows, err := s.db.QueryContext(ctx, - `SELECT id, event_id, COALESCE(verdict_id, ''), COALESCE(session_id, ''), - mad_code, status, COALESCE(reviewed_by, ''), - COALESCE(reviewed_at, ''), created_at - FROM hitl_queue - WHERE status = ? - ORDER BY created_at DESC + `SELECT q.id, q.event_id, COALESCE(q.verdict_id, ''), COALESCE(q.session_id, ''), + q.mad_code, COALESCE(v.verdict_status, 'ok'), q.status, COALESCE(q.reviewed_by, ''), + COALESCE(q.reviewed_at, ''), q.created_at + FROM hitl_queue q + LEFT JOIN verdicts v ON v.id = q.verdict_id + WHERE q.status = ? + ORDER BY q.created_at DESC LIMIT ? OFFSET ?`, status, perPage, offset) if err != nil { @@ -76,7 +78,7 @@ func (s *Store) ListHitlQueue(ctx context.Context, status string, perPage, offse r := &HitlReview{} var reviewedAt, createdAt string if err := rows.Scan(&r.ID, &r.EventID, &r.VerdictID, &r.SessionID, - &r.MADCode, &r.Status, &r.ReviewedBy, &reviewedAt, &createdAt); err != nil { + &r.MADCode, &r.VerdictStatus, &r.Status, &r.ReviewedBy, &reviewedAt, &createdAt); err != nil { return nil, 0, err } if reviewedAt != "" { @@ -100,7 +102,7 @@ func (s *Store) GetHitlReview(ctx context.Context, id string) (*HitlReviewDetail q.mad_code, q.status, COALESCE(q.reviewed_by, ''), COALESCE(q.reviewed_at, ''), q.created_at, COALESCE(e.payload, ''), - v.classification, v.reasoning + v.classification, COALESCE(v.verdict_status, 'ok'), v.reasoning FROM hitl_queue q LEFT JOIN events e ON e.id = q.event_id LEFT JOIN verdicts v ON v.id = q.verdict_id @@ -110,7 +112,7 @@ func (s *Store) GetHitlReview(ctx context.Context, id string) (*HitlReviewDetail &r.MADCode, &r.Status, &r.ReviewedBy, &reviewedAt, &createdAt, &r.EventPayloadJSON, - &classification, &reasoning, + &classification, &r.VerdictStatus, &reasoning, ) if err != nil { if errors.Is(err, sql.ErrNoRows) { diff --git a/backend/internal/store/policies.go b/backend/internal/store/policies.go index 0152ce8..b0c6a3d 100644 --- a/backend/internal/store/policies.go +++ b/backend/internal/store/policies.go @@ -10,33 +10,37 @@ import ( // Policy is the singleton row from the policies table. type Policy struct { - Mode string - PolicyM0 bool - PolicyM2 bool - PolicyM3 bool - PolicyM4 bool - UpdatedAt time.Time + Mode string + PolicyM0 bool + PolicyM2 bool + PolicyM3 bool + PolicyM4 bool + FailClosedOnClassifierError bool + UpdatedAt time.Time } // PolicyPatch is the partial-update payload. Nil fields mean // "no change". type PolicyPatch struct { - Mode *string - PolicyM0 *bool - PolicyM2 *bool - PolicyM3 *bool - PolicyM4 *bool + Mode *string + PolicyM0 *bool + PolicyM2 *bool + PolicyM3 *bool + PolicyM4 *bool + FailClosedOnClassifierError *bool } // GetPolicy returns the singleton row. Migration 001 inserts a default // row so this never returns ErrNotFound on a healthy database. func (s *Store) GetPolicy(ctx context.Context) (*Policy, error) { row := s.db.QueryRowContext(ctx, - `SELECT mode, policy_m0, policy_m2, policy_m3, policy_m4, updated_at + `SELECT mode, policy_m0, policy_m2, policy_m3, policy_m4, + fail_closed_on_classifier_error, updated_at FROM policies WHERE id = 1`) var p Policy var updatedAt string - if err := row.Scan(&p.Mode, &p.PolicyM0, &p.PolicyM2, &p.PolicyM3, &p.PolicyM4, &updatedAt); err != nil { + if err := row.Scan(&p.Mode, &p.PolicyM0, &p.PolicyM2, &p.PolicyM3, &p.PolicyM4, + &p.FailClosedOnClassifierError, &updatedAt); err != nil { return nil, err } p.UpdatedAt = parseTime(updatedAt) @@ -53,9 +57,11 @@ func (s *Store) UpdatePolicy(ctx context.Context, patch *PolicyPatch) error { policy_m2 = COALESCE(?, policy_m2), policy_m3 = COALESCE(?, policy_m3), policy_m4 = COALESCE(?, policy_m4), + fail_closed_on_classifier_error = COALESCE(?, fail_closed_on_classifier_error), updated_at = (strftime('%Y-%m-%dT%H:%M:%fZ','now')) WHERE id = 1`, patch.Mode, boolPtrToInt(patch.PolicyM0), boolPtrToInt(patch.PolicyM2), - boolPtrToInt(patch.PolicyM3), boolPtrToInt(patch.PolicyM4)) + boolPtrToInt(patch.PolicyM3), boolPtrToInt(patch.PolicyM4), + boolPtrToInt(patch.FailClosedOnClassifierError)) return err } diff --git a/backend/internal/store/verdicts.go b/backend/internal/store/verdicts.go index 88e3e56..027490a 100644 --- a/backend/internal/store/verdicts.go +++ b/backend/internal/store/verdicts.go @@ -19,6 +19,7 @@ type Verdict struct { AgentProfileID *string MADCode string Classification string + VerdictStatus string Reasoning *string LatencyMS *int64 TokensUsed int32 @@ -31,6 +32,7 @@ type VerdictListRow struct { SessionID string MADCode string Classification string + VerdictStatus string LatencyMS *int64 TokensUsed int32 CreatedAt time.Time @@ -41,16 +43,21 @@ type VerdictFilters struct { Since time.Time Classification string // exact match (empty = no filter) MADCode string // exact match (empty = no filter) + VerdictStatus string // exact match (empty = no filter) } // InsertVerdict persists one classification result. func (s *Store) InsertVerdict(ctx context.Context, v *Verdict) error { + status := v.VerdictStatus + if status == "" { + status = "ok" + } _, err := s.db.ExecContext(ctx, `INSERT INTO verdicts - (id, event_id, session_id, agent_profile_id, mad_code, classification, reasoning, latency_ms, tokens_used) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + (id, event_id, session_id, agent_profile_id, mad_code, classification, verdict_status, reasoning, latency_ms, tokens_used) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, v.ID, v.EventID, v.SessionID, v.AgentProfileID, - v.MADCode, v.Classification, v.Reasoning, v.LatencyMS, v.TokensUsed) + v.MADCode, v.Classification, status, v.Reasoning, v.LatencyMS, v.TokensUsed) return err } @@ -68,7 +75,7 @@ func (s *Store) ListVerdicts(ctx context.Context, f VerdictFilters, perPage, off args = append(args, perPage, offset) rows, err := s.db.QueryContext(ctx, - `SELECT id, event_id, session_id, mad_code, classification, + `SELECT id, event_id, session_id, mad_code, classification, verdict_status, latency_ms, tokens_used, created_at FROM verdicts WHERE `+where+` @@ -84,7 +91,7 @@ func (s *Store) ListVerdicts(ctx context.Context, f VerdictFilters, perPage, off r := &VerdictListRow{} var latency sql.NullInt64 var createdAt string - if err := rows.Scan(&r.ID, &r.EventID, &r.SessionID, &r.MADCode, &r.Classification, + if err := rows.Scan(&r.ID, &r.EventID, &r.SessionID, &r.MADCode, &r.Classification, &r.VerdictStatus, &latency, &r.TokensUsed, &createdAt); err != nil { return nil, 0, err } @@ -101,14 +108,14 @@ func (s *Store) ListVerdicts(ctx context.Context, f VerdictFilters, perPage, off // or ErrNotFound. func (s *Store) GetVerdictByEventID(ctx context.Context, eventID string) (*VerdictListRow, error) { row := s.db.QueryRowContext(ctx, - `SELECT id, event_id, session_id, mad_code, classification, + `SELECT id, event_id, session_id, mad_code, classification, verdict_status, latency_ms, tokens_used, created_at FROM verdicts WHERE event_id = ? ORDER BY created_at DESC LIMIT 1`, eventID) r := &VerdictListRow{} var latency sql.NullInt64 var createdAt string - if err := row.Scan(&r.ID, &r.EventID, &r.SessionID, &r.MADCode, &r.Classification, + if err := row.Scan(&r.ID, &r.EventID, &r.SessionID, &r.MADCode, &r.Classification, &r.VerdictStatus, &latency, &r.TokensUsed, &createdAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound @@ -133,5 +140,9 @@ func verdictsWhere(f VerdictFilters) (string, []any) { parts = append(parts, "mad_code = ?") args = append(args, f.MADCode) } + if f.VerdictStatus != "" { + parts = append(parts, "verdict_status = ?") + args = append(args, f.VerdictStatus) + } return strings.Join(parts, " AND "), args } diff --git a/backend/internal/ws/frames.go b/backend/internal/ws/frames.go index 23c1e91..f9e9e58 100644 --- a/backend/internal/ws/frames.go +++ b/backend/internal/ws/frames.go @@ -31,11 +31,12 @@ const ( // can build HITL-resolution Verdict frames carrying the same shape. func PolicySnapshot(p *store.Policy) *pb.PolicySnapshot { return &pb.PolicySnapshot{ - Mode: modeFromString(p.Mode), - PolicyM0: p.PolicyM0, - PolicyM2: p.PolicyM2, - PolicyM3: p.PolicyM3, - PolicyM4: p.PolicyM4, + Mode: modeFromString(p.Mode), + PolicyM0: p.PolicyM0, + PolicyM2: p.PolicyM2, + PolicyM3: p.PolicyM3, + PolicyM4: p.PolicyM4, + FailClosedOnClassifierError: p.FailClosedOnClassifierError, } } diff --git a/backend/internal/ws/handler.go b/backend/internal/ws/handler.go index da28f68..a40706b 100644 --- a/backend/internal/ws/handler.go +++ b/backend/internal/ws/handler.go @@ -271,7 +271,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla for i := 0; i < 3; i++ { existing, err := st.GetVerdictByEventID(ctx, ev.EventId) if err == nil { - return dispatchVerdict(ctx, sess, st, hub, ev, snap, existing.ID, existing.MADCode) + return dispatchVerdict(ctx, sess, st, hub, ev, snap, existing.ID, existing.MADCode, existing.VerdictStatus) } if !errors.Is(err, store.ErrNotFound) { return err @@ -303,14 +303,27 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla } verdict, err := classifier.Classify(ctx, ev, agentProfileID) if err != nil { - // The HTTPClient implementation never returns a non-nil error - // (all classifier failures are mapped to a synthetic M0 / - // benign verdict by engine.HTTPClient.failOpen). This branch - // is defensive against future classifier implementations or - // context-cancellation edge cases: log and skip the event. - slog.ErrorContext(ctx, "ws.classify_unexpected_error", + slog.WarnContext(ctx, "ws.classifier_failure", "error", err, "event_id", ev.EventId) - return nil + reasoning := "classifier failure: " + err.Error() + vrow := &store.Verdict{ + ID: uuid.NewString(), + EventID: ev.EventId, + SessionID: sess.sessionID, + AgentProfileID: sess.agentProfileID(), + MADCode: "", + Classification: "error", + VerdictStatus: "error", + Reasoning: &reasoning, + TokensUsed: 0, + } + if err := st.InsertVerdict(ctx, vrow); err != nil { + return err + } + if hook != nil { + hook(ev.EventId, sess.sessionID, ev.GetAgent().GetAgentId(), "", "error") + } + return dispatchVerdict(ctx, sess, st, hub, ev, snap, vrow.ID, "", "error") } vrow := &store.Verdict{ @@ -320,6 +333,7 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla AgentProfileID: sess.agentProfileID(), MADCode: verdict.MADCode, Classification: verdict.Classification, + VerdictStatus: "ok", Reasoning: strPtrOrNil(verdict.Reasoning), LatencyMS: int64PtrIfNonZero(verdict.LatencyMS), TokensUsed: 0, @@ -336,10 +350,10 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla verdict.MADCode, verdict.Classification) } - return dispatchVerdict(ctx, sess, st, hub, ev, snap, vrow.ID, verdict.MADCode) + return dispatchVerdict(ctx, sess, st, hub, ev, snap, vrow.ID, verdict.MADCode, "ok") } -func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictID, madCode string) error { +func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictID, madCode, verdictStatus string) error { // Mode-gated dispatch: // alert: persist verdict, do NOT notify the SDK (dashboard-only). // hitl + in-scope + actionable: persist + queue for human review, @@ -383,6 +397,7 @@ func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *H EventId: ev.EventId, SessionId: sess.sessionID, MadCode: madCode, + Status: verdictStatusProto(verdictStatus), Policy: snap, }, }, @@ -394,6 +409,17 @@ func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *H return nil } +func verdictStatusProto(status string) pb.VerdictStatus { + switch status { + case "error": + return pb.VerdictStatus_VERDICT_STATUS_ERROR + case "ok": + return pb.VerdictStatus_VERDICT_STATUS_OK + default: + return pb.VerdictStatus_VERDICT_STATUS_UNSPECIFIED + } +} + func handleMcpInventory(ctx context.Context, sess *session, st *store.Store, inv *pb.McpInventory) error { if inv == nil { return nil diff --git a/backend/internal/ws/handler_test.go b/backend/internal/ws/handler_test.go index 6933458..cf70d9c 100644 --- a/backend/internal/ws/handler_test.go +++ b/backend/internal/ws/handler_test.go @@ -159,6 +159,101 @@ func TestRoundTrip(t *testing.T) { websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) } +// Phase 4 anchor for issue #46: a classifier transport / HTTP failure +// is persisted + pushed as an ERROR verdict with no MAD code. The +// mode-specific fail-closed policy matrix is layered on in Phase 5. +func TestClassifierFailurePersistsAndPublishesErrorVerdict(t *testing.T) { + db := openInMemoryDB(t) + t.Cleanup(func() { _ = db.Close() }) + + st := store.New(db) + plaintextKey := "adr_local_test_key_classifier_failure" + keyHash := sha256Hex(plaintextKey) + insertAPIKey(t, db, keyHash) + if _, err := db.Exec(`UPDATE policies SET mode = 'block' WHERE id = 1`); err != nil { + t.Fatalf("set mode=block: %v", err) + } + + llm := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "classifier exploded", http.StatusInternalServerError) + })) + t.Cleanup(llm.Close) + classifier := engine.NewHTTPClient(llm.URL, "test-key", "test-model", nil, nil) + + mux := http.NewServeMux() + mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, classifier, ws.NewHub(), nil, nil))) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws" + header := http.Header{"Authorization": {"Bearer " + plaintextKey}} + conn, _, err := websocket.DefaultDialer.Dial(wsURL, header) + if err != nil { + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { _ = conn.Close() }) + + if err := writeProto(conn, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_Login{Login: &bpb.SessionLogin{ + SessionId: "classifier-failure-sess", SchemaVersion: 2, + }}, + }); err != nil { + t.Fatalf("send login: %v", err) + } + if _, err := readServerFrame(conn); err != nil { + t.Fatalf("read login_ack: %v", err) + } + + eventID := uuid.NewString() + if err := writeProto(conn, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_PairedBatch{PairedBatch: &bpb.PairedEventBatch{ + Events: []*bpb.PairedEvent{{ + EventId: eventID, SessionId: "classifier-failure-sess", + RunId: "run-classifier-failure", + PairType: bpb.PairType_PAIR_TYPE_TOOL, + Agent: &bpb.AgentContext{AgentId: "failure-agent"}, + Data: &bpb.PairedEvent_Tool{Tool: &bpb.ToolPairData{ + ToolName: "noop", ToolCallId: "tc-classifier-failure", Input: "{}", Output: "ok", + }}, + }}, + }}, + }); err != nil { + t.Fatalf("send paired_batch: %v", err) + } + + frame, err := readServerFrame(conn) + if err != nil { + t.Fatalf("read verdict: %v", err) + } + verdict := frame.GetVerdict() + if verdict == nil { + t.Fatalf("expected Verdict, got %T", frame.Frame) + } + if verdict.MadCode != "" { + t.Fatalf("pushed mad_code = %q, want empty on classifier error", verdict.MadCode) + } + if verdict.Status != bpb.VerdictStatus_VERDICT_STATUS_ERROR { + t.Fatalf("pushed status = %v, want ERROR", verdict.Status) + } + + var madCode, classification, verdictStatus, reasoning string + if err := db.QueryRow( + `SELECT mad_code, classification, verdict_status, reasoning FROM verdicts WHERE event_id = ?`, + eventID, + ).Scan(&madCode, &classification, &verdictStatus, &reasoning); err != nil { + t.Fatalf("query verdict: %v", err) + } + if madCode != "" || classification != "error" || verdictStatus != "error" { + t.Fatalf("stored verdict = (%q, %q, %q), want ('', error, error)", + madCode, classification, verdictStatus) + } + if !strings.Contains(reasoning, "classifier failure") || + !strings.Contains(reasoning, "post:") || + !strings.Contains(reasoning, "status 500") { + t.Fatalf("stored reasoning = %q, want classifier failure with post/status 500", reasoning) + } +} + func TestDuplicateEventRetryKeepsWSOpen(t *testing.T) { db := openInMemoryDB(t) t.Cleanup(func() { _ = db.Close() }) @@ -660,6 +755,7 @@ CREATE TABLE policies ( policy_m2 INTEGER NOT NULL DEFAULT 0, policy_m3 INTEGER NOT NULL DEFAULT 1, policy_m4 INTEGER NOT NULL DEFAULT 1, + fail_closed_on_classifier_error INTEGER NOT NULL DEFAULT 0, updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) ); INSERT INTO policies (id) VALUES (1); @@ -681,6 +777,7 @@ CREATE TABLE verdicts ( agent_profile_id TEXT, mad_code TEXT NOT NULL, classification TEXT NOT NULL, + verdict_status TEXT NOT NULL DEFAULT 'ok', reasoning TEXT, latency_ms INTEGER, tokens_used INTEGER NOT NULL DEFAULT 0, diff --git a/backend/migrations/002_verdict_status_policy.sql b/backend/migrations/002_verdict_status_policy.sql new file mode 100644 index 0000000..ea7a816 --- /dev/null +++ b/backend/migrations/002_verdict_status_policy.sql @@ -0,0 +1,72 @@ +-- ============================================================ +-- Issue #46: verdict status + classifier-error policy toggle +-- ============================================================ +-- adrian: no-transaction +-- +-- Rebuild verdicts so the classification CHECK can admit the +-- classifier-error state. The Go/Python runners execute this file +-- outside their own transaction wrapper so foreign_keys can be +-- disabled before this migration's explicit transaction begins. +-- ============================================================ + +PRAGMA foreign_keys=OFF; + +BEGIN; + +ALTER TABLE policies + ADD COLUMN fail_closed_on_classifier_error INTEGER NOT NULL DEFAULT 0 + CHECK (fail_closed_on_classifier_error IN (0,1)); + +CREATE TABLE verdicts_new ( + id TEXT PRIMARY KEY, + event_id TEXT NOT NULL REFERENCES events(id) ON DELETE CASCADE, + session_id TEXT NOT NULL, + agent_profile_id TEXT REFERENCES agent_profiles(id) ON DELETE SET NULL, + mad_code TEXT NOT NULL, + classification TEXT NOT NULL CHECK (classification IN ('benign','notify','block','error')), + verdict_status TEXT NOT NULL DEFAULT 'ok' + CHECK (verdict_status IN ('ok','error')), + reasoning TEXT, + latency_ms INTEGER, + tokens_used INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) +); + +INSERT INTO verdicts_new ( + id, + event_id, + session_id, + agent_profile_id, + mad_code, + classification, + verdict_status, + reasoning, + latency_ms, + tokens_used, + created_at +) +SELECT + id, + event_id, + session_id, + agent_profile_id, + mad_code, + classification, + 'ok', + reasoning, + latency_ms, + tokens_used, + created_at +FROM verdicts; + +DROP TABLE verdicts; +ALTER TABLE verdicts_new RENAME TO verdicts; + +CREATE INDEX IF NOT EXISTS idx_verdicts_event_id ON verdicts(event_id); +CREATE INDEX IF NOT EXISTS idx_verdicts_session_id ON verdicts(session_id); +CREATE INDEX IF NOT EXISTS idx_verdicts_created_at ON verdicts(created_at); + +COMMIT; + +PRAGMA foreign_key_check; +PRAGMA foreign_keys=ON; diff --git a/backend/migrations/embed.go b/backend/migrations/embed.go index e76170f..e08d168 100644 --- a/backend/migrations/embed.go +++ b/backend/migrations/embed.go @@ -4,10 +4,10 @@ // Package migrations embeds the SQL migration files for the Adrian // backend. The same files are also COPYed into the adrian-setup // bootstrap image (deploy/Dockerfile.setup), where setup.py applies -// them on first run. The backend re-applies them at startup so -// upgrades after `git pull` work without a manual step; every -// migration is idempotent (CREATE TABLE IF NOT EXISTS, INSERT OR -// IGNORE). +// pending migrations on bootstrap / apply-migrations. The backend +// also checks pending migrations at startup so upgrades after +// `git pull` work without a manual step. Both runners record applied +// filenames in schema_migrations. package migrations import "embed" diff --git a/backend/proto/event.proto b/backend/proto/event.proto index efac7d4..fe1ad33 100644 --- a/backend/proto/event.proto +++ b/backend/proto/event.proto @@ -201,6 +201,16 @@ enum Mode { MODE_BLOCK = 3; } +// VerdictStatus says whether a Verdict came from a completed classifier +// decision or represents a classifier failure. ERROR verdicts carry no +// classifier-produced MAD code; policy decides whether they fail open +// or fail closed. +enum VerdictStatus { + VERDICT_STATUS_UNSPECIFIED = 0; + VERDICT_STATUS_OK = 1; + VERDICT_STATUS_ERROR = 2; +} + // PolicySnapshot is the org's effective execution-mode policy at the moment // a verdict was decided. Attached by the server to every Verdict it sends // so the SDK can apply user-configured behaviour (halt vs continue, @@ -208,12 +218,16 @@ enum Mode { // // Per-MAD-code booleans say whether the active mode's behaviour fires on // that code. False means "treat this code as silent regardless of mode". +// fail_closed_on_classifier_error controls ERROR verdicts and BLOCK-mode +// SDK verdict timeouts. The default false value preserves fail-open +// availability when talking to older backends. message PolicySnapshot { Mode mode = 1; bool policy_m0 = 2; bool policy_m2 = 3; bool policy_m3 = 4; bool policy_m4 = 5; + bool fail_closed_on_classifier_error = 6; } // HitlResponse rides on a Verdict that has been resolved through the @@ -262,7 +276,10 @@ message Verdict { // off the wire. Reserved so the slot can't be reused. reserved 3; reserved "classification"; - // MAD code the classifier returned (e.g. "M0", "M2_C", "M4_a"). Empty string for benign. + // MAD code the classifier returned (e.g. "M0", "M2_C", "M4_a"). + // Empty string means no MAD code was produced, such as for a + // VerdictStatus.ERROR classifier failure. Benign classifier success + // is represented by status OK with mad_code "M0". string mad_code = 4; // Field 5 previously held `bool escalate`, a verdict-level flag the // engine derived from the classifier reasoning string. Removed because @@ -280,4 +297,8 @@ message Verdict { // resolution (approve or reject from the dashboard). Absent on auto- // classified verdicts and on out-of-scope verdicts forwarded immediately. HitlResponse hitl = 7; + // Status of the classifier result. OK means mad_code carries a normal + // classifier decision. ERROR means classification did not complete and + // mad_code is empty; fail-open/fail-closed behaviour comes from policy. + VerdictStatus status = 8; } diff --git a/proto/event.proto b/proto/event.proto index aaa4b9b..0fcbfef 100644 --- a/proto/event.proto +++ b/proto/event.proto @@ -201,6 +201,16 @@ enum Mode { MODE_BLOCK = 3; } +// VerdictStatus says whether a Verdict came from a completed classifier +// decision or represents a classifier failure. ERROR verdicts carry no +// classifier-produced MAD code; policy decides whether they fail open +// or fail closed. +enum VerdictStatus { + VERDICT_STATUS_UNSPECIFIED = 0; + VERDICT_STATUS_OK = 1; + VERDICT_STATUS_ERROR = 2; +} + // PolicySnapshot is the org's effective execution-mode policy at the moment // a verdict was decided. Attached by the server to every Verdict it sends // so the SDK can apply user-configured behaviour (halt vs continue, @@ -208,12 +218,16 @@ enum Mode { // // Per-MAD-code booleans say whether the active mode's behaviour fires on // that code. False means "treat this code as silent regardless of mode". +// fail_closed_on_classifier_error controls ERROR verdicts and BLOCK-mode +// SDK verdict timeouts. The default false value preserves fail-open +// availability when talking to older backends. message PolicySnapshot { Mode mode = 1; bool policy_m0 = 2; bool policy_m2 = 3; bool policy_m3 = 4; bool policy_m4 = 5; + bool fail_closed_on_classifier_error = 6; } // HitlResponse rides on a Verdict that has been resolved through the @@ -262,7 +276,10 @@ message Verdict { // off the wire. Reserved so the slot can't be reused. reserved 3; reserved "classification"; - // MAD code the classifier returned (e.g. "M0", "M2_C", "M4_a"). Empty string for benign. + // MAD code the classifier returned (e.g. "M0", "M2_C", "M4_a"). + // Empty string means no MAD code was produced, such as for a + // VerdictStatus.ERROR classifier failure. Benign classifier success + // is represented by status OK with mad_code "M0". string mad_code = 4; // Field 5 previously held `bool escalate`, a verdict-level flag the // engine derived from the classifier reasoning string. Removed because @@ -280,4 +297,8 @@ message Verdict { // resolution (approve or reject from the dashboard). Absent on auto- // classified verdicts and on out-of-scope verdicts forwarded immediately. HitlResponse hitl = 7; + // Status of the classifier result. OK means mad_code carries a normal + // classifier decision. ERROR means classification did not complete and + // mad_code is empty; fail-open/fail-closed behaviour comes from policy. + VerdictStatus status = 8; } diff --git a/scripts/setup.py b/scripts/setup.py index 931a616..86bf47e 100644 --- a/scripts/setup.py +++ b/scripts/setup.py @@ -72,6 +72,7 @@ }, } DEFAULT_VARIANT = "E4B" +NO_TRANSACTION_MARKER = "-- adrian: no-transaction" # ---------------------------------------------------------------- @@ -108,21 +109,65 @@ def open_db(db_path: Path) -> sqlite3.Connection: def apply_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> list[str]: - """Apply every `*.sql` file in lexical order. Returns the list of - files applied. The migrations themselves use `IF NOT EXISTS` and - `INSERT OR IGNORE`, so re-applying is a no-op.""" + """Apply previously-unseen `*.sql` files in lexical order. + + Applied filenames are recorded in `schema_migrations`, matching the + Go backend runner. This keeps setup/bootstrap safe for future + migrations that cannot be written as idempotent SQL. + """ + conn.execute( + """ + CREATE TABLE IF NOT EXISTS schema_migrations ( + name TEXT PRIMARY KEY, + applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) + ) + """, + ) + conn.commit() + applied: list[str] = [] sql_files = sorted(migrations_dir.glob("*.sql")) if not sql_files: raise SystemExit(f"no migrations found in {migrations_dir}") for path in sql_files: + if migration_applied(conn, path.name): + continue sql = path.read_text(encoding="utf-8") - conn.executescript(sql) + if NO_TRANSACTION_MARKER in sql: + try: + conn.executescript(sql) + conn.execute("INSERT INTO schema_migrations (name) VALUES (?)", (path.name,)) + conn.commit() + except sqlite3.Error: + conn.rollback() + conn.execute("PRAGMA foreign_keys=ON") + raise + else: + quoted_name = path.name.replace("'", "''") + try: + conn.executescript( + "BEGIN;\n" + f"{sql}\n" + "INSERT INTO schema_migrations (name) " + f"VALUES ('{quoted_name}');\n" + "COMMIT;\n", + ) + except sqlite3.Error: + conn.rollback() + raise applied.append(path.name) - conn.commit() return applied +def migration_applied(conn: sqlite3.Connection, name: str) -> bool: + """Return True when `name` has already been recorded in the ledger.""" + row = conn.execute( + "SELECT 1 FROM schema_migrations WHERE name = ?", + (name,), + ).fetchone() + return row is not None + + def read_env(env_path: Path) -> dict[str, str]: """Parse a `KEY=VALUE` env file. Comments and blanks ignored. Quoted values are stripped of surrounding double quotes.""" @@ -540,7 +585,7 @@ def cmd_apply_migrations(args: argparse.Namespace) -> int: sys.stdout.write( f"\n" - f"v {len(applied)} migration file(s) re-applied (idempotent):\n" + f"v {len(applied)} new migration file(s) applied:\n" + "".join(f" {name}\n" for name in applied) + "\n" ) @@ -587,7 +632,7 @@ def build_parser() -> argparse.ArgumentParser: p_model.add_argument("--ctx-size", type=int, default=None) p_model.set_defaults(func=cmd_set_model) - p_migrate = sub.add_parser("apply-migrations", help="re-apply schema migrations") + p_migrate = sub.add_parser("apply-migrations", help="apply pending schema migrations") p_migrate.set_defaults(func=cmd_apply_migrations) return parser diff --git a/sdk/python/adrian/proto/event_pb2.py b/sdk/python/adrian/proto/event_pb2.py index 696cb1a..42050a9 100644 --- a/sdk/python/adrian/proto/event_pb2.py +++ b/sdk/python/adrian/proto/event_pb2.py @@ -1,13 +1,22 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: event.proto +# Protobuf Python Version: 6.33.5 """Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection +from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 33, + 5, + '', + 'event.proto' +) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -16,230 +25,77 @@ from .buf.validate import validate_pb2 as buf_dot_validate_dot_validate__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12\x61\x64rian.core_api.v1\x1a\x1b\x62uf/validate/validate.proto\",\n\x0b\x43hatMessage\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t\";\n\x08ToolCall\x12\x15\n\x04name\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x0c\n\x04\x61rgs\x18\x02 \x01(\t\x12\n\n\x02id\x18\x03 \x01(\t\"o\n\nTokenUsage\x12\x1e\n\rprompt_tokens\x18\x01 \x01(\x05\x42\x07\xbaH\x04\x1a\x02(\x00\x12\"\n\x11\x63ompletion_tokens\x18\x02 \x01(\x05\x42\x07\xbaH\x04\x1a\x02(\x00\x12\x1d\n\x0ctotal_tokens\x18\x03 \x01(\x05\x42\x07\xbaH\x04\x1a\x02(\x00\"Q\n\x0c\x41gentContext\x12\x10\n\x08\x61gent_id\x18\x01 \x01(\t\x12\x15\n\rsystem_prompt\x18\x02 \x01(\t\x12\x18\n\x10user_instruction\x18\x03 \x01(\t\"\xc0\x01\n\x0bLlmPairData\x12\r\n\x05model\x18\x01 \x01(\t\x12\x31\n\x08messages\x18\x02 \x03(\x0b\x32\x1f.adrian.core_api.v1.ChatMessage\x12\x0e\n\x06output\x18\x03 \x01(\t\x12\x30\n\ntool_calls\x18\x04 \x03(\x0b\x32\x1c.adrian.core_api.v1.ToolCall\x12-\n\x05usage\x18\x05 \x01(\x0b\x32\x1e.adrian.core_api.v1.TokenUsage\"_\n\x0cToolPairData\x12\x1a\n\ttool_name\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x14\n\x0ctool_call_id\x18\x02 \x01(\t\x12\r\n\x05input\x18\x03 \x01(\t\x12\x0e\n\x06output\x18\x04 \x01(\t\"\xb3\x03\n\x0bPairedEvent\x12\x19\n\x08\x65vent_id\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x15\n\rinvocation_id\x18\x02 \x01(\t\x12\x1b\n\nsession_id\x18\x03 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x0e\n\x06run_id\x18\x04 \x01(\t\x12\x15\n\rparent_run_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x37\n\tpair_type\x18\x07 \x01(\x0e\x32\x1c.adrian.core_api.v1.PairTypeB\x06\xbaH\x03\xc8\x01\x01\x12/\n\x05\x61gent\x18\x08 \x01(\x0b\x32 .adrian.core_api.v1.AgentContext\x12\x30\n\x06parent\x18\t \x01(\x0b\x32 .adrian.core_api.v1.AgentContext\x12.\n\x03llm\x18\n \x01(\x0b\x32\x1f.adrian.core_api.v1.LlmPairDataH\x00\x12\x30\n\x04tool\x18\x0b \x01(\x0b\x32 .adrian.core_api.v1.ToolPairDataH\x00\x12\x15\n\rmetadata_json\x18\x14 \x01(\x0c\x42\x06\n\x04\x64\x61ta\"C\n\x10PairedEventBatch\x12/\n\x06\x65vents\x18\x01 \x03(\x0b\x32\x1f.adrian.core_api.v1.PairedEvent\"G\n\tMcpServer\x12\x15\n\x04name\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x11\n\ttransport\x18\x02 \x01(\t\x12\x10\n\x08\x65ndpoint\x18\x03 \x01(\t\">\n\x0cMcpInventory\x12.\n\x07servers\x18\x01 \x03(\x0b\x32\x1d.adrian.core_api.v1.McpServer\"+\n\x08LLMStack\x12\x10\n\x08provider\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\"\x86\x01\n\x0cSessionLogin\x12\x1b\n\nsession_id\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12/\n\tllm_stack\x18\x02 \x01(\x0b\x32\x1c.adrian.core_api.v1.LLMStack\x12\x16\n\x0eschema_version\x18\x04 \x01(\rJ\x04\x08\x03\x10\x04R\nblock_mode\"\xcf\x01\n\x0b\x43lientFrame\x12\x31\n\x05login\x18\x01 \x01(\x0b\x32 .adrian.core_api.v1.SessionLoginH\x00\x12<\n\x0cpaired_batch\x18\x03 \x01(\x0b\x32$.adrian.core_api.v1.PairedEventBatchH\x00\x12\x39\n\rmcp_inventory\x18\x04 \x01(\x0b\x32 .adrian.core_api.v1.McpInventoryH\x00\x42\x07\n\x05\x66rameJ\x04\x08\x02\x10\x03R\x05\x62\x61tch\"\x84\x01\n\x0ePolicySnapshot\x12&\n\x04mode\x18\x01 \x01(\x0e\x32\x18.adrian.core_api.v1.Mode\x12\x11\n\tpolicy_m0\x18\x02 \x01(\x08\x12\x11\n\tpolicy_m2\x18\x03 \x01(\x08\x12\x11\n\tpolicy_m3\x18\x04 \x01(\x08\x12\x11\n\tpolicy_m4\x18\x05 \x01(\x08\"*\n\x0cHitlResponse\x12\x1a\n\x12\x63ontinue_execution\x18\x01 \x01(\x08\">\n\x08LoginAck\x12\x32\n\x06policy\x18\x01 \x01(\x0b\x32\".adrian.core_api.v1.PolicySnapshot\"y\n\x0bServerFrame\x12\x31\n\tlogin_ack\x18\x01 \x01(\x0b\x32\x1c.adrian.core_api.v1.LoginAckH\x00\x12.\n\x07verdict\x18\x02 \x01(\x0b\x32\x1b.adrian.core_api.v1.VerdictH\x00\x42\x07\n\x05\x66rame\"\xdd\x01\n\x07Verdict\x12\x19\n\x08\x65vent_id\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x1b\n\nsession_id\x18\x02 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x10\n\x08mad_code\x18\x04 \x01(\t\x12\x32\n\x06policy\x18\x06 \x01(\x0b\x32\".adrian.core_api.v1.PolicySnapshot\x12.\n\x04hitl\x18\x07 \x01(\x0b\x32 .adrian.core_api.v1.HitlResponseJ\x04\x08\x03\x10\x04J\x04\x08\x05\x10\x06R\x0e\x63lassificationR\x08\x65scalate*L\n\x08PairType\x12\x19\n\x15PAIR_TYPE_UNSPECIFIED\x10\x00\x12\x11\n\rPAIR_TYPE_LLM\x10\x01\x12\x12\n\x0ePAIR_TYPE_TOOL\x10\x02*K\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0e\n\nMODE_ALERT\x10\x01\x12\r\n\tMODE_HITL\x10\x02\x12\x0e\n\nMODE_BLOCK\x10\x03\x62\x06proto3') - -_PAIRTYPE = DESCRIPTOR.enum_types_by_name['PairType'] -PairType = enum_type_wrapper.EnumTypeWrapper(_PAIRTYPE) -_MODE = DESCRIPTOR.enum_types_by_name['Mode'] -Mode = enum_type_wrapper.EnumTypeWrapper(_MODE) -PAIR_TYPE_UNSPECIFIED = 0 -PAIR_TYPE_LLM = 1 -PAIR_TYPE_TOOL = 2 -MODE_UNSPECIFIED = 0 -MODE_ALERT = 1 -MODE_HITL = 2 -MODE_BLOCK = 3 - - -_CHATMESSAGE = DESCRIPTOR.message_types_by_name['ChatMessage'] -_TOOLCALL = DESCRIPTOR.message_types_by_name['ToolCall'] -_TOKENUSAGE = DESCRIPTOR.message_types_by_name['TokenUsage'] -_AGENTCONTEXT = DESCRIPTOR.message_types_by_name['AgentContext'] -_LLMPAIRDATA = DESCRIPTOR.message_types_by_name['LlmPairData'] -_TOOLPAIRDATA = DESCRIPTOR.message_types_by_name['ToolPairData'] -_PAIREDEVENT = DESCRIPTOR.message_types_by_name['PairedEvent'] -_PAIREDEVENTBATCH = DESCRIPTOR.message_types_by_name['PairedEventBatch'] -_MCPSERVER = DESCRIPTOR.message_types_by_name['McpServer'] -_MCPINVENTORY = DESCRIPTOR.message_types_by_name['McpInventory'] -_LLMSTACK = DESCRIPTOR.message_types_by_name['LLMStack'] -_SESSIONLOGIN = DESCRIPTOR.message_types_by_name['SessionLogin'] -_CLIENTFRAME = DESCRIPTOR.message_types_by_name['ClientFrame'] -_POLICYSNAPSHOT = DESCRIPTOR.message_types_by_name['PolicySnapshot'] -_HITLRESPONSE = DESCRIPTOR.message_types_by_name['HitlResponse'] -_LOGINACK = DESCRIPTOR.message_types_by_name['LoginAck'] -_SERVERFRAME = DESCRIPTOR.message_types_by_name['ServerFrame'] -_VERDICT = DESCRIPTOR.message_types_by_name['Verdict'] -ChatMessage = _reflection.GeneratedProtocolMessageType('ChatMessage', (_message.Message,), { - 'DESCRIPTOR' : _CHATMESSAGE, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.ChatMessage) - }) -_sym_db.RegisterMessage(ChatMessage) - -ToolCall = _reflection.GeneratedProtocolMessageType('ToolCall', (_message.Message,), { - 'DESCRIPTOR' : _TOOLCALL, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.ToolCall) - }) -_sym_db.RegisterMessage(ToolCall) - -TokenUsage = _reflection.GeneratedProtocolMessageType('TokenUsage', (_message.Message,), { - 'DESCRIPTOR' : _TOKENUSAGE, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.TokenUsage) - }) -_sym_db.RegisterMessage(TokenUsage) - -AgentContext = _reflection.GeneratedProtocolMessageType('AgentContext', (_message.Message,), { - 'DESCRIPTOR' : _AGENTCONTEXT, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.AgentContext) - }) -_sym_db.RegisterMessage(AgentContext) - -LlmPairData = _reflection.GeneratedProtocolMessageType('LlmPairData', (_message.Message,), { - 'DESCRIPTOR' : _LLMPAIRDATA, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.LlmPairData) - }) -_sym_db.RegisterMessage(LlmPairData) - -ToolPairData = _reflection.GeneratedProtocolMessageType('ToolPairData', (_message.Message,), { - 'DESCRIPTOR' : _TOOLPAIRDATA, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.ToolPairData) - }) -_sym_db.RegisterMessage(ToolPairData) - -PairedEvent = _reflection.GeneratedProtocolMessageType('PairedEvent', (_message.Message,), { - 'DESCRIPTOR' : _PAIREDEVENT, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.PairedEvent) - }) -_sym_db.RegisterMessage(PairedEvent) - -PairedEventBatch = _reflection.GeneratedProtocolMessageType('PairedEventBatch', (_message.Message,), { - 'DESCRIPTOR' : _PAIREDEVENTBATCH, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.PairedEventBatch) - }) -_sym_db.RegisterMessage(PairedEventBatch) - -McpServer = _reflection.GeneratedProtocolMessageType('McpServer', (_message.Message,), { - 'DESCRIPTOR' : _MCPSERVER, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.McpServer) - }) -_sym_db.RegisterMessage(McpServer) - -McpInventory = _reflection.GeneratedProtocolMessageType('McpInventory', (_message.Message,), { - 'DESCRIPTOR' : _MCPINVENTORY, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.McpInventory) - }) -_sym_db.RegisterMessage(McpInventory) - -LLMStack = _reflection.GeneratedProtocolMessageType('LLMStack', (_message.Message,), { - 'DESCRIPTOR' : _LLMSTACK, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.LLMStack) - }) -_sym_db.RegisterMessage(LLMStack) - -SessionLogin = _reflection.GeneratedProtocolMessageType('SessionLogin', (_message.Message,), { - 'DESCRIPTOR' : _SESSIONLOGIN, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.SessionLogin) - }) -_sym_db.RegisterMessage(SessionLogin) - -ClientFrame = _reflection.GeneratedProtocolMessageType('ClientFrame', (_message.Message,), { - 'DESCRIPTOR' : _CLIENTFRAME, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.ClientFrame) - }) -_sym_db.RegisterMessage(ClientFrame) - -PolicySnapshot = _reflection.GeneratedProtocolMessageType('PolicySnapshot', (_message.Message,), { - 'DESCRIPTOR' : _POLICYSNAPSHOT, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.PolicySnapshot) - }) -_sym_db.RegisterMessage(PolicySnapshot) - -HitlResponse = _reflection.GeneratedProtocolMessageType('HitlResponse', (_message.Message,), { - 'DESCRIPTOR' : _HITLRESPONSE, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.HitlResponse) - }) -_sym_db.RegisterMessage(HitlResponse) - -LoginAck = _reflection.GeneratedProtocolMessageType('LoginAck', (_message.Message,), { - 'DESCRIPTOR' : _LOGINACK, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.LoginAck) - }) -_sym_db.RegisterMessage(LoginAck) - -ServerFrame = _reflection.GeneratedProtocolMessageType('ServerFrame', (_message.Message,), { - 'DESCRIPTOR' : _SERVERFRAME, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.ServerFrame) - }) -_sym_db.RegisterMessage(ServerFrame) - -Verdict = _reflection.GeneratedProtocolMessageType('Verdict', (_message.Message,), { - 'DESCRIPTOR' : _VERDICT, - '__module__' : 'event_pb2' - # @@protoc_insertion_point(class_scope:adrian.core_api.v1.Verdict) - }) -_sym_db.RegisterMessage(Verdict) - -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _TOOLCALL.fields_by_name['name']._options = None - _TOOLCALL.fields_by_name['name']._serialized_options = b'\272H\004r\002\020\001' - _TOKENUSAGE.fields_by_name['prompt_tokens']._options = None - _TOKENUSAGE.fields_by_name['prompt_tokens']._serialized_options = b'\272H\004\032\002(\000' - _TOKENUSAGE.fields_by_name['completion_tokens']._options = None - _TOKENUSAGE.fields_by_name['completion_tokens']._serialized_options = b'\272H\004\032\002(\000' - _TOKENUSAGE.fields_by_name['total_tokens']._options = None - _TOKENUSAGE.fields_by_name['total_tokens']._serialized_options = b'\272H\004\032\002(\000' - _TOOLPAIRDATA.fields_by_name['tool_name']._options = None - _TOOLPAIRDATA.fields_by_name['tool_name']._serialized_options = b'\272H\004r\002\020\001' - _PAIREDEVENT.fields_by_name['event_id']._options = None - _PAIREDEVENT.fields_by_name['event_id']._serialized_options = b'\272H\004r\002\020\001' - _PAIREDEVENT.fields_by_name['session_id']._options = None - _PAIREDEVENT.fields_by_name['session_id']._serialized_options = b'\272H\004r\002\020\001' - _PAIREDEVENT.fields_by_name['pair_type']._options = None - _PAIREDEVENT.fields_by_name['pair_type']._serialized_options = b'\272H\003\310\001\001' - _MCPSERVER.fields_by_name['name']._options = None - _MCPSERVER.fields_by_name['name']._serialized_options = b'\272H\004r\002\020\001' - _SESSIONLOGIN.fields_by_name['session_id']._options = None - _SESSIONLOGIN.fields_by_name['session_id']._serialized_options = b'\272H\004r\002\020\001' - _VERDICT.fields_by_name['event_id']._options = None - _VERDICT.fields_by_name['event_id']._serialized_options = b'\272H\004r\002\020\001' - _VERDICT.fields_by_name['session_id']._options = None - _VERDICT.fields_by_name['session_id']._serialized_options = b'\272H\004r\002\020\001' - _PAIRTYPE._serialized_start=2285 - _PAIRTYPE._serialized_end=2361 - _MODE._serialized_start=2363 - _MODE._serialized_end=2438 - _CHATMESSAGE._serialized_start=64 - _CHATMESSAGE._serialized_end=108 - _TOOLCALL._serialized_start=110 - _TOOLCALL._serialized_end=169 - _TOKENUSAGE._serialized_start=171 - _TOKENUSAGE._serialized_end=282 - _AGENTCONTEXT._serialized_start=284 - _AGENTCONTEXT._serialized_end=365 - _LLMPAIRDATA._serialized_start=368 - _LLMPAIRDATA._serialized_end=560 - _TOOLPAIRDATA._serialized_start=562 - _TOOLPAIRDATA._serialized_end=657 - _PAIREDEVENT._serialized_start=660 - _PAIREDEVENT._serialized_end=1095 - _PAIREDEVENTBATCH._serialized_start=1097 - _PAIREDEVENTBATCH._serialized_end=1164 - _MCPSERVER._serialized_start=1166 - _MCPSERVER._serialized_end=1237 - _MCPINVENTORY._serialized_start=1239 - _MCPINVENTORY._serialized_end=1301 - _LLMSTACK._serialized_start=1303 - _LLMSTACK._serialized_end=1346 - _SESSIONLOGIN._serialized_start=1349 - _SESSIONLOGIN._serialized_end=1483 - _CLIENTFRAME._serialized_start=1486 - _CLIENTFRAME._serialized_end=1693 - _POLICYSNAPSHOT._serialized_start=1696 - _POLICYSNAPSHOT._serialized_end=1828 - _HITLRESPONSE._serialized_start=1830 - _HITLRESPONSE._serialized_end=1872 - _LOGINACK._serialized_start=1874 - _LOGINACK._serialized_end=1936 - _SERVERFRAME._serialized_start=1938 - _SERVERFRAME._serialized_end=2059 - _VERDICT._serialized_start=2062 - _VERDICT._serialized_end=2283 +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12\x61\x64rian.core_api.v1\x1a\x1b\x62uf/validate/validate.proto\",\n\x0b\x43hatMessage\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t\";\n\x08ToolCall\x12\x15\n\x04name\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x0c\n\x04\x61rgs\x18\x02 \x01(\t\x12\n\n\x02id\x18\x03 \x01(\t\"o\n\nTokenUsage\x12\x1e\n\rprompt_tokens\x18\x01 \x01(\x05\x42\x07\xbaH\x04\x1a\x02(\x00\x12\"\n\x11\x63ompletion_tokens\x18\x02 \x01(\x05\x42\x07\xbaH\x04\x1a\x02(\x00\x12\x1d\n\x0ctotal_tokens\x18\x03 \x01(\x05\x42\x07\xbaH\x04\x1a\x02(\x00\"Q\n\x0c\x41gentContext\x12\x10\n\x08\x61gent_id\x18\x01 \x01(\t\x12\x15\n\rsystem_prompt\x18\x02 \x01(\t\x12\x18\n\x10user_instruction\x18\x03 \x01(\t\"\xc0\x01\n\x0bLlmPairData\x12\r\n\x05model\x18\x01 \x01(\t\x12\x31\n\x08messages\x18\x02 \x03(\x0b\x32\x1f.adrian.core_api.v1.ChatMessage\x12\x0e\n\x06output\x18\x03 \x01(\t\x12\x30\n\ntool_calls\x18\x04 \x03(\x0b\x32\x1c.adrian.core_api.v1.ToolCall\x12-\n\x05usage\x18\x05 \x01(\x0b\x32\x1e.adrian.core_api.v1.TokenUsage\"_\n\x0cToolPairData\x12\x1a\n\ttool_name\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x14\n\x0ctool_call_id\x18\x02 \x01(\t\x12\r\n\x05input\x18\x03 \x01(\t\x12\x0e\n\x06output\x18\x04 \x01(\t\"\xb3\x03\n\x0bPairedEvent\x12\x19\n\x08\x65vent_id\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x15\n\rinvocation_id\x18\x02 \x01(\t\x12\x1b\n\nsession_id\x18\x03 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x0e\n\x06run_id\x18\x04 \x01(\t\x12\x15\n\rparent_run_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x37\n\tpair_type\x18\x07 \x01(\x0e\x32\x1c.adrian.core_api.v1.PairTypeB\x06\xbaH\x03\xc8\x01\x01\x12/\n\x05\x61gent\x18\x08 \x01(\x0b\x32 .adrian.core_api.v1.AgentContext\x12\x30\n\x06parent\x18\t \x01(\x0b\x32 .adrian.core_api.v1.AgentContext\x12.\n\x03llm\x18\n \x01(\x0b\x32\x1f.adrian.core_api.v1.LlmPairDataH\x00\x12\x30\n\x04tool\x18\x0b \x01(\x0b\x32 .adrian.core_api.v1.ToolPairDataH\x00\x12\x15\n\rmetadata_json\x18\x14 \x01(\x0c\x42\x06\n\x04\x64\x61ta\"C\n\x10PairedEventBatch\x12/\n\x06\x65vents\x18\x01 \x03(\x0b\x32\x1f.adrian.core_api.v1.PairedEvent\"G\n\tMcpServer\x12\x15\n\x04name\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x11\n\ttransport\x18\x02 \x01(\t\x12\x10\n\x08\x65ndpoint\x18\x03 \x01(\t\">\n\x0cMcpInventory\x12.\n\x07servers\x18\x01 \x03(\x0b\x32\x1d.adrian.core_api.v1.McpServer\"+\n\x08LLMStack\x12\x10\n\x08provider\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\"\x86\x01\n\x0cSessionLogin\x12\x1b\n\nsession_id\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12/\n\tllm_stack\x18\x02 \x01(\x0b\x32\x1c.adrian.core_api.v1.LLMStack\x12\x16\n\x0eschema_version\x18\x04 \x01(\rJ\x04\x08\x03\x10\x04R\nblock_mode\"\xcf\x01\n\x0b\x43lientFrame\x12\x31\n\x05login\x18\x01 \x01(\x0b\x32 .adrian.core_api.v1.SessionLoginH\x00\x12<\n\x0cpaired_batch\x18\x03 \x01(\x0b\x32$.adrian.core_api.v1.PairedEventBatchH\x00\x12\x39\n\rmcp_inventory\x18\x04 \x01(\x0b\x32 .adrian.core_api.v1.McpInventoryH\x00\x42\x07\n\x05\x66rameJ\x04\x08\x02\x10\x03R\x05\x62\x61tch\"\xad\x01\n\x0ePolicySnapshot\x12&\n\x04mode\x18\x01 \x01(\x0e\x32\x18.adrian.core_api.v1.Mode\x12\x11\n\tpolicy_m0\x18\x02 \x01(\x08\x12\x11\n\tpolicy_m2\x18\x03 \x01(\x08\x12\x11\n\tpolicy_m3\x18\x04 \x01(\x08\x12\x11\n\tpolicy_m4\x18\x05 \x01(\x08\x12\'\n\x1f\x66\x61il_closed_on_classifier_error\x18\x06 \x01(\x08\"*\n\x0cHitlResponse\x12\x1a\n\x12\x63ontinue_execution\x18\x01 \x01(\x08\">\n\x08LoginAck\x12\x32\n\x06policy\x18\x01 \x01(\x0b\x32\".adrian.core_api.v1.PolicySnapshot\"y\n\x0bServerFrame\x12\x31\n\tlogin_ack\x18\x01 \x01(\x0b\x32\x1c.adrian.core_api.v1.LoginAckH\x00\x12.\n\x07verdict\x18\x02 \x01(\x0b\x32\x1b.adrian.core_api.v1.VerdictH\x00\x42\x07\n\x05\x66rame\"\x90\x02\n\x07Verdict\x12\x19\n\x08\x65vent_id\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x1b\n\nsession_id\x18\x02 \x01(\tB\x07\xbaH\x04r\x02\x10\x01\x12\x10\n\x08mad_code\x18\x04 \x01(\t\x12\x32\n\x06policy\x18\x06 \x01(\x0b\x32\".adrian.core_api.v1.PolicySnapshot\x12.\n\x04hitl\x18\x07 \x01(\x0b\x32 .adrian.core_api.v1.HitlResponse\x12\x31\n\x06status\x18\x08 \x01(\x0e\x32!.adrian.core_api.v1.VerdictStatusJ\x04\x08\x03\x10\x04J\x04\x08\x05\x10\x06R\x0e\x63lassificationR\x08\x65scalate*L\n\x08PairType\x12\x19\n\x15PAIR_TYPE_UNSPECIFIED\x10\x00\x12\x11\n\rPAIR_TYPE_LLM\x10\x01\x12\x12\n\x0ePAIR_TYPE_TOOL\x10\x02*K\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0e\n\nMODE_ALERT\x10\x01\x12\r\n\tMODE_HITL\x10\x02\x12\x0e\n\nMODE_BLOCK\x10\x03*`\n\rVerdictStatus\x12\x1e\n\x1aVERDICT_STATUS_UNSPECIFIED\x10\x00\x12\x15\n\x11VERDICT_STATUS_OK\x10\x01\x12\x18\n\x14VERDICT_STATUS_ERROR\x10\x02\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'event_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TOOLCALL'].fields_by_name['name']._loaded_options = None + _globals['_TOOLCALL'].fields_by_name['name']._serialized_options = b'\272H\004r\002\020\001' + _globals['_TOKENUSAGE'].fields_by_name['prompt_tokens']._loaded_options = None + _globals['_TOKENUSAGE'].fields_by_name['prompt_tokens']._serialized_options = b'\272H\004\032\002(\000' + _globals['_TOKENUSAGE'].fields_by_name['completion_tokens']._loaded_options = None + _globals['_TOKENUSAGE'].fields_by_name['completion_tokens']._serialized_options = b'\272H\004\032\002(\000' + _globals['_TOKENUSAGE'].fields_by_name['total_tokens']._loaded_options = None + _globals['_TOKENUSAGE'].fields_by_name['total_tokens']._serialized_options = b'\272H\004\032\002(\000' + _globals['_TOOLPAIRDATA'].fields_by_name['tool_name']._loaded_options = None + _globals['_TOOLPAIRDATA'].fields_by_name['tool_name']._serialized_options = b'\272H\004r\002\020\001' + _globals['_PAIREDEVENT'].fields_by_name['event_id']._loaded_options = None + _globals['_PAIREDEVENT'].fields_by_name['event_id']._serialized_options = b'\272H\004r\002\020\001' + _globals['_PAIREDEVENT'].fields_by_name['session_id']._loaded_options = None + _globals['_PAIREDEVENT'].fields_by_name['session_id']._serialized_options = b'\272H\004r\002\020\001' + _globals['_PAIREDEVENT'].fields_by_name['pair_type']._loaded_options = None + _globals['_PAIREDEVENT'].fields_by_name['pair_type']._serialized_options = b'\272H\003\310\001\001' + _globals['_MCPSERVER'].fields_by_name['name']._loaded_options = None + _globals['_MCPSERVER'].fields_by_name['name']._serialized_options = b'\272H\004r\002\020\001' + _globals['_SESSIONLOGIN'].fields_by_name['session_id']._loaded_options = None + _globals['_SESSIONLOGIN'].fields_by_name['session_id']._serialized_options = b'\272H\004r\002\020\001' + _globals['_VERDICT'].fields_by_name['event_id']._loaded_options = None + _globals['_VERDICT'].fields_by_name['event_id']._serialized_options = b'\272H\004r\002\020\001' + _globals['_VERDICT'].fields_by_name['session_id']._loaded_options = None + _globals['_VERDICT'].fields_by_name['session_id']._serialized_options = b'\272H\004r\002\020\001' + _globals['_PAIRTYPE']._serialized_start=2377 + _globals['_PAIRTYPE']._serialized_end=2453 + _globals['_MODE']._serialized_start=2455 + _globals['_MODE']._serialized_end=2530 + _globals['_VERDICTSTATUS']._serialized_start=2532 + _globals['_VERDICTSTATUS']._serialized_end=2628 + _globals['_CHATMESSAGE']._serialized_start=64 + _globals['_CHATMESSAGE']._serialized_end=108 + _globals['_TOOLCALL']._serialized_start=110 + _globals['_TOOLCALL']._serialized_end=169 + _globals['_TOKENUSAGE']._serialized_start=171 + _globals['_TOKENUSAGE']._serialized_end=282 + _globals['_AGENTCONTEXT']._serialized_start=284 + _globals['_AGENTCONTEXT']._serialized_end=365 + _globals['_LLMPAIRDATA']._serialized_start=368 + _globals['_LLMPAIRDATA']._serialized_end=560 + _globals['_TOOLPAIRDATA']._serialized_start=562 + _globals['_TOOLPAIRDATA']._serialized_end=657 + _globals['_PAIREDEVENT']._serialized_start=660 + _globals['_PAIREDEVENT']._serialized_end=1095 + _globals['_PAIREDEVENTBATCH']._serialized_start=1097 + _globals['_PAIREDEVENTBATCH']._serialized_end=1164 + _globals['_MCPSERVER']._serialized_start=1166 + _globals['_MCPSERVER']._serialized_end=1237 + _globals['_MCPINVENTORY']._serialized_start=1239 + _globals['_MCPINVENTORY']._serialized_end=1301 + _globals['_LLMSTACK']._serialized_start=1303 + _globals['_LLMSTACK']._serialized_end=1346 + _globals['_SESSIONLOGIN']._serialized_start=1349 + _globals['_SESSIONLOGIN']._serialized_end=1483 + _globals['_CLIENTFRAME']._serialized_start=1486 + _globals['_CLIENTFRAME']._serialized_end=1693 + _globals['_POLICYSNAPSHOT']._serialized_start=1696 + _globals['_POLICYSNAPSHOT']._serialized_end=1869 + _globals['_HITLRESPONSE']._serialized_start=1871 + _globals['_HITLRESPONSE']._serialized_end=1913 + _globals['_LOGINACK']._serialized_start=1915 + _globals['_LOGINACK']._serialized_end=1977 + _globals['_SERVERFRAME']._serialized_start=1979 + _globals['_SERVERFRAME']._serialized_end=2100 + _globals['_VERDICT']._serialized_start=2103 + _globals['_VERDICT']._serialized_end=2375 # @@protoc_insertion_point(module_scope) diff --git a/sdk/python/adrian/proto/event_pb2.pyi b/sdk/python/adrian/proto/event_pb2.pyi index 0343546..6c9ca74 100644 --- a/sdk/python/adrian/proto/event_pb2.pyi +++ b/sdk/python/adrian/proto/event_pb2.pyi @@ -12,10 +12,10 @@ import builtins as _builtins import sys import typing as _typing -if sys.version_info >= (3, 10): - from typing import TypeAlias as _TypeAlias +if sys.version_info >= (3, 11): + from typing import TypeAlias as _TypeAlias, Never as _Never else: - from typing_extensions import TypeAlias as _TypeAlias + from typing_extensions import TypeAlias as _TypeAlias, Never as _Never DESCRIPTOR: _descriptor.FileDescriptor @@ -74,6 +74,28 @@ MODE_BLOCK: Mode.ValueType # 3 """Server forwards every verdict; the SDK enforces per the policy snapshot.""" Global___Mode: _TypeAlias = Mode # noqa: Y015 +class _VerdictStatus: + ValueType = _typing.NewType("ValueType", _builtins.int) + V: _TypeAlias = ValueType # noqa: Y015 + +class _VerdictStatusEnumTypeWrapper(_enum_type_wrapper._EnumTypeWrapper[_VerdictStatus.ValueType], _builtins.type): + DESCRIPTOR: _descriptor.EnumDescriptor + VERDICT_STATUS_UNSPECIFIED: _VerdictStatus.ValueType # 0 + VERDICT_STATUS_OK: _VerdictStatus.ValueType # 1 + VERDICT_STATUS_ERROR: _VerdictStatus.ValueType # 2 + +class VerdictStatus(_VerdictStatus, metaclass=_VerdictStatusEnumTypeWrapper): + """VerdictStatus says whether a Verdict came from a completed classifier + decision or represents a classifier failure. ERROR verdicts carry no + classifier-produced MAD code; policy decides whether they fail open + or fail closed. + """ + +VERDICT_STATUS_UNSPECIFIED: VerdictStatus.ValueType # 0 +VERDICT_STATUS_OK: VerdictStatus.ValueType # 1 +VERDICT_STATUS_ERROR: VerdictStatus.ValueType # 2 +Global___VerdictStatus: _TypeAlias = VerdictStatus # noqa: Y015 + @_typing.final class ChatMessage(_message.Message): """ChatMessage represents a conversation message with a string role.""" @@ -92,8 +114,11 @@ class ChatMessage(_message.Message): role: _builtins.str = ..., content: _builtins.str = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["content", b"content", "role", b"role"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___ChatMessage: _TypeAlias = ChatMessage # noqa: Y015 @@ -119,8 +144,11 @@ class ToolCall(_message.Message): args: _builtins.str = ..., id: _builtins.str = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["args", b"args", "id", b"id", "name", b"name"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___ToolCall: _TypeAlias = ToolCall # noqa: Y015 @@ -146,8 +174,11 @@ class TokenUsage(_message.Message): completion_tokens: _builtins.int = ..., total_tokens: _builtins.int = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["completion_tokens", b"completion_tokens", "prompt_tokens", b"prompt_tokens", "total_tokens", b"total_tokens"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___TokenUsage: _TypeAlias = TokenUsage # noqa: Y015 @@ -178,8 +209,11 @@ class AgentContext(_message.Message): system_prompt: _builtins.str = ..., user_instruction: _builtins.str = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["agent_id", b"agent_id", "system_prompt", b"system_prompt", "user_instruction", b"user_instruction"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___AgentContext: _TypeAlias = AgentContext # noqa: Y015 @@ -225,6 +259,7 @@ class LlmPairData(_message.Message): def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["messages", b"messages", "model", b"model", "output", b"output", "tool_calls", b"tool_calls", "usage", b"usage"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___LlmPairData: _TypeAlias = LlmPairData # noqa: Y015 @@ -258,8 +293,11 @@ class ToolPairData(_message.Message): input: _builtins.str = ..., output: _builtins.str = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["input", b"input", "output", b"output", "tool_call_id", b"tool_call_id", "tool_name", b"tool_name"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___ToolPairData: _TypeAlias = ToolPairData # noqa: Y015 @@ -357,8 +395,11 @@ class PairedEventBatch(_message.Message): *, events: _abc.Iterable[Global___PairedEvent] | None = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["events", b"events"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___PairedEventBatch: _TypeAlias = PairedEventBatch # noqa: Y015 @@ -391,8 +432,11 @@ class McpServer(_message.Message): transport: _builtins.str = ..., endpoint: _builtins.str = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["endpoint", b"endpoint", "name", b"name", "transport", b"transport"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___McpServer: _TypeAlias = McpServer # noqa: Y015 @@ -413,8 +457,11 @@ class McpInventory(_message.Message): *, servers: _abc.Iterable[Global___McpServer] | None = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["servers", b"servers"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___McpInventory: _TypeAlias = McpInventory # noqa: Y015 @@ -436,8 +483,11 @@ class LLMStack(_message.Message): provider: _builtins.str = ..., model: _builtins.str = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["model", b"model", "provider", b"provider"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___LLMStack: _TypeAlias = LLMStack # noqa: Y015 @@ -469,6 +519,7 @@ class SessionLogin(_message.Message): def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["llm_stack", b"llm_stack", "schema_version", b"schema_version", "session_id", b"session_id"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___SessionLogin: _TypeAlias = SessionLogin # noqa: Y015 @@ -519,6 +570,9 @@ class PolicySnapshot(_message.Message): Per-MAD-code booleans say whether the active mode's behaviour fires on that code. False means "treat this code as silent regardless of mode". + fail_closed_on_classifier_error controls ERROR verdicts and BLOCK-mode + SDK verdict timeouts. The default false value preserves fail-open + availability when talking to older backends. """ DESCRIPTOR: _descriptor.Descriptor @@ -528,11 +582,13 @@ class PolicySnapshot(_message.Message): POLICY_M2_FIELD_NUMBER: _builtins.int POLICY_M3_FIELD_NUMBER: _builtins.int POLICY_M4_FIELD_NUMBER: _builtins.int + FAIL_CLOSED_ON_CLASSIFIER_ERROR_FIELD_NUMBER: _builtins.int mode: Global___Mode.ValueType policy_m0: _builtins.bool policy_m2: _builtins.bool policy_m3: _builtins.bool policy_m4: _builtins.bool + fail_closed_on_classifier_error: _builtins.bool def __init__( self, *, @@ -541,9 +597,13 @@ class PolicySnapshot(_message.Message): policy_m2: _builtins.bool = ..., policy_m3: _builtins.bool = ..., policy_m4: _builtins.bool = ..., + fail_closed_on_classifier_error: _builtins.bool = ..., ) -> None: ... - _ClearFieldArgType: _TypeAlias = _typing.Literal["mode", b"mode", "policy_m0", b"policy_m0", "policy_m2", b"policy_m2", "policy_m3", b"policy_m3", "policy_m4", b"policy_m4"] # noqa: Y015 + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... + _ClearFieldArgType: _TypeAlias = _typing.Literal["fail_closed_on_classifier_error", b"fail_closed_on_classifier_error", "mode", b"mode", "policy_m0", b"policy_m0", "policy_m2", b"policy_m2", "policy_m3", b"policy_m3", "policy_m4", b"policy_m4"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___PolicySnapshot: _TypeAlias = PolicySnapshot # noqa: Y015 @@ -564,8 +624,11 @@ class HitlResponse(_message.Message): *, continue_execution: _builtins.bool = ..., ) -> None: ... + _HasFieldArgType: _TypeAlias = _Never # noqa: Y015 + def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["continue_execution", b"continue_execution"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___HitlResponse: _TypeAlias = HitlResponse # noqa: Y015 @@ -594,6 +657,7 @@ class LoginAck(_message.Message): def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... _ClearFieldArgType: _TypeAlias = _typing.Literal["policy", b"policy"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___LoginAck: _TypeAlias = LoginAck # noqa: Y015 @@ -645,12 +709,22 @@ class Verdict(_message.Message): MAD_CODE_FIELD_NUMBER: _builtins.int POLICY_FIELD_NUMBER: _builtins.int HITL_FIELD_NUMBER: _builtins.int + STATUS_FIELD_NUMBER: _builtins.int event_id: _builtins.str """The event_id of the PairedEvent being classified.""" session_id: _builtins.str """Session identifier for routing.""" mad_code: _builtins.str - """MAD code the classifier returned (e.g. "M0", "M2_C", "M4_a"). Empty string for benign.""" + """MAD code the classifier returned (e.g. "M0", "M2_C", "M4_a"). + Empty string means no MAD code was produced, such as for a + VerdictStatus.ERROR classifier failure. Benign classifier success + is represented by status OK with mad_code "M0". + """ + status: Global___VerdictStatus.ValueType + """Status of the classifier result. OK means mad_code carries a normal + classifier decision. ERROR means classification did not complete and + mad_code is empty; fail-open/fail-closed behaviour comes from policy. + """ @_builtins.property def policy(self) -> Global___PolicySnapshot: """Org's effective execution-mode policy at the time of this verdict. @@ -673,10 +747,12 @@ class Verdict(_message.Message): mad_code: _builtins.str = ..., policy: Global___PolicySnapshot | None = ..., hitl: Global___HitlResponse | None = ..., + status: Global___VerdictStatus.ValueType = ..., ) -> None: ... _HasFieldArgType: _TypeAlias = _typing.Literal["hitl", b"hitl", "policy", b"policy"] # noqa: Y015 def HasField(self, field_name: _HasFieldArgType) -> _builtins.bool: ... - _ClearFieldArgType: _TypeAlias = _typing.Literal["event_id", b"event_id", "hitl", b"hitl", "mad_code", b"mad_code", "policy", b"policy", "session_id", b"session_id"] # noqa: Y015 + _ClearFieldArgType: _TypeAlias = _typing.Literal["event_id", b"event_id", "hitl", b"hitl", "mad_code", b"mad_code", "policy", b"policy", "session_id", b"session_id", "status", b"status"] # noqa: Y015 def ClearField(self, field_name: _ClearFieldArgType) -> None: ... + def WhichOneof(self, oneof_group: _Never) -> None: ... Global___Verdict: _TypeAlias = Verdict # noqa: Y015 From ab967129540cd88f700f35f2eab19bdbbf02ab10 Mon Sep 17 00:00:00 2001 From: Muhammad Usman <85793641+Muhammad-usman92@users.noreply.github.com> Date: Sun, 14 Jun 2026 14:05:59 +0500 Subject: [PATCH 2/4] Update CLA.md --- CLA.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CLA.md b/CLA.md index 6060782..ef770d2 100644 --- a/CLA.md +++ b/CLA.md @@ -55,5 +55,6 @@ To accept this Agreement, open a pull request that adds an entry to the table be | _example placeholder_ | _@example_ | _2026-01-01_ | | Dhrit Timinkumar Patel | @d180 | 2026-05-20 | | Adarsh Tiwari | @adarsh9977 | 2026-05-22 | +| Muhammad usman | @Muhammad-usman92 | 2026-06-11 | Once a CLA-bot (cla-assistant.io or equivalent) is wired up, this manual table will be replaced by the bot's status check on each pull request. Existing signatures in this table remain valid; the bot reads from a separate signers list. From b42f5e776ace128e0b5305c8600e69bd775b913d Mon Sep 17 00:00:00 2001 From: Muhammad-usman92 Date: Mon, 15 Jun 2026 14:15:14 +0500 Subject: [PATCH 3/4] Enforce classifier-error fail-closed policy --- CLA.md | 1 + README.md | 8 + backend/internal/api/handlers_events.go | 15 +- backend/internal/api/handlers_reviews.go | 25 +- backend/internal/api/handlers_stats.go | 26 +- backend/internal/api/handlers_test.go | 215 ++++++++++++++++- backend/internal/api/handlers_verdicts.go | 2 + backend/internal/db/migrate_test.go | 101 ++++++++ .../internal/notifications/discord_test.go | 54 +++++ backend/internal/notifications/dispatcher.go | 4 +- backend/internal/store/events.go | 10 + backend/internal/store/hitl.go | 25 +- backend/internal/store/stats.go | 26 +- backend/internal/store/verdicts.go | 9 +- backend/internal/ws/handler.go | 55 ++++- backend/internal/ws/handler_test.go | 223 +++++++++++++++++- backend/internal/ws/helpers.go | 4 +- backend/internal/ws/session.go | 2 + docs/ARCHITECTURE.md | 11 +- frontend/app/(dashboard)/events/page.tsx | 41 +++- frontend/app/(dashboard)/page.tsx | 12 +- frontend/app/(dashboard)/reviews/page.tsx | 26 +- .../sessions/[session_id]/page.tsx | 7 +- frontend/app/(dashboard)/settings/page.tsx | 24 +- frontend/lib/utils.ts | 22 ++ sdk/python/adrian/__init__.py | 32 ++- sdk/python/adrian/config.py | 8 +- sdk/python/adrian/handler.py | 1 + sdk/python/adrian/types.py | 7 +- sdk/python/adrian/ws.py | 37 ++- sdk/python/tests/test_block_mode.py | 170 +++++++++++++ sdk/python/tests/test_exec_modes.py | 56 +++++ sdk/python/tests/test_handler.py | 46 ++++ sdk/python/tests/test_ws.py | 20 ++ 34 files changed, 1228 insertions(+), 97 deletions(-) diff --git a/CLA.md b/CLA.md index 6060782..04e244f 100644 --- a/CLA.md +++ b/CLA.md @@ -55,5 +55,6 @@ To accept this Agreement, open a pull request that adds an entry to the table be | _example placeholder_ | _@example_ | _2026-01-01_ | | Dhrit Timinkumar Patel | @d180 | 2026-05-20 | | Adarsh Tiwari | @adarsh9977 | 2026-05-22 | +| Muhammad Usman | @Muhammad-usman92 | 2026-06-15 | Once a CLA-bot (cla-assistant.io or equivalent) is wired up, this manual table will be replaced by the bot's status check on each pull request. Existing signatures in this table remain valid; the bot reads from a separate signers list. diff --git a/README.md b/README.md index 04ea25d..7398755 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,14 @@ Adrian supports entirely offline, data sovereign deployments using just a handfu Use the same `adrian.init` snippet as in the [Quickstart](#quickstart) above. The SDK defaults to `ws://localhost:8080/ws`, so a self-hosted setup needs nothing more than the API key - drop the `ws_url=` line. +### Classifier error policy + +Adrian records classifier outages, malformed classifier responses, and unparseable classifier output as `verdict_status=error` with `mad_code=""`. These are operational classifier errors, not benign `M0` findings and not synthetic malicious activity. + +The default policy remains availability-first: classifier errors fail open. In **Settings -> Policy**, enable **Fail closed on classifier error** to make BLOCK-mode tool calls return blocked responses when the classifier cannot produce a verdict. In HITL mode, actionable classifier errors are sent to the review queue and held until an operator approves or rejects them. + +Fail-closed classifier-error enforcement requires the Python SDK version shipped with this repository update. Older SDKs ignore the additive protobuf `status` and policy fields, see an empty MAD code, and continue fail-open even when the dashboard toggle is enabled. + To [reset the admin password](https://docs.adrian.secureagentics.ai/reference/backend#reset-the-admin-password), [change the model](https://docs.adrian.secureagentics.ai/reference/backend#switch-the-local-gguf) and much more check out the dedicated [Docs site](https://docs.adrian.secureagentics.ai/). ## Why Adrian is different diff --git a/backend/internal/api/handlers_events.go b/backend/internal/api/handlers_events.go index b5b1729..f3f0ab3 100644 --- a/backend/internal/api/handlers_events.go +++ b/backend/internal/api/handlers_events.go @@ -72,13 +72,18 @@ func (s *Server) handleListEvents(w http.ResponseWriter, r *http.Request) { since = t } } + if status := q.Get("verdict_status"); status != "" && !validVerdictStatus(status) { + writeError(w, http.StatusBadRequest, "invalid verdict_status") + return + } filters := store.EventFilters{ - Since: since, - AgentID: q.Get("agent_id"), - SessionID: q.Get("session_id"), - EventType: q.Get("event_type"), - MinMAD: q.Get("min_mad"), + Since: since, + AgentID: q.Get("agent_id"), + SessionID: q.Get("session_id"), + EventType: q.Get("event_type"), + MinMAD: q.Get("min_mad"), + VerdictStatus: q.Get("verdict_status"), } rows, total, err := s.store.ListEvents(r.Context(), filters, pg.PerPage, pg.Offset) diff --git a/backend/internal/api/handlers_reviews.go b/backend/internal/api/handlers_reviews.go index 2ecd782..e8d8847 100644 --- a/backend/internal/api/handlers_reviews.go +++ b/backend/internal/api/handlers_reviews.go @@ -39,6 +39,7 @@ type reviewDetail struct { reviewSummary EventPayload json.RawMessage `json:"event_payload,omitempty"` Classification string `json:"classification,omitempty"` + Reasoning string `json:"reasoning,omitempty"` } type reviewResolveResponse struct { @@ -49,8 +50,14 @@ type reviewResolveResponse struct { func (s *Server) handleListReviews(w http.ResponseWriter, r *http.Request) { pg := parsePagination(r) - status := r.URL.Query().Get("status") - rows, total, err := s.store.ListHitlQueue(r.Context(), status, pg.PerPage, pg.Offset) + q := r.URL.Query() + status := q.Get("status") + verdictStatus := q.Get("verdict_status") + if verdictStatus != "" && !validVerdictStatus(verdictStatus) { + writeError(w, http.StatusBadRequest, "invalid verdict_status") + return + } + rows, total, err := s.store.ListHitlQueue(r.Context(), status, verdictStatus, pg.PerPage, pg.Offset) if err != nil { writeError(w, http.StatusInternalServerError, "query failed") return @@ -81,6 +88,7 @@ func (s *Server) handleGetReview(w http.ResponseWriter, r *http.Request) { resp := reviewDetail{ reviewSummary: reviewToSummary(&row.HitlReview), Classification: row.Classification, + Reasoning: row.Reasoning, } if row.EventPayloadJSON != "" { resp.EventPayload = json.RawMessage(row.EventPayloadJSON) @@ -128,7 +136,7 @@ func (s *Server) resolveReview(w http.ResponseWriter, r *http.Request, status st EventId: row.EventID, SessionId: row.SessionID, MadCode: row.MADCode, - Status: pb.VerdictStatus_VERDICT_STATUS_OK, + Status: reviewVerdictStatusProto(row.VerdictStatus), Policy: s.policySnapshotProto(pol), Hitl: &pb.HitlResponse{ContinueExecution: continueExec}, }}, @@ -165,3 +173,14 @@ func reviewToSummary(r *store.HitlReview) reviewSummary { } return out } + +func reviewVerdictStatusProto(status string) pb.VerdictStatus { + switch status { + case "error": + return pb.VerdictStatus_VERDICT_STATUS_ERROR + case "ok": + return pb.VerdictStatus_VERDICT_STATUS_OK + default: + return pb.VerdictStatus_VERDICT_STATUS_UNSPECIFIED + } +} diff --git a/backend/internal/api/handlers_stats.go b/backend/internal/api/handlers_stats.go index 8787369..d6ff5e8 100644 --- a/backend/internal/api/handlers_stats.go +++ b/backend/internal/api/handlers_stats.go @@ -6,12 +6,13 @@ package api import "net/http" type overviewResponse struct { - TotalEvents int `json:"total_events"` - FlaggedVerdicts int `json:"flagged_verdicts"` - PendingReviews int `json:"pending_reviews"` - ActiveAgents int `json:"active_agents"` - VerdictsByMAD map[string]int `json:"verdicts_by_mad"` - Window string `json:"window"` + TotalEvents int `json:"total_events"` + FlaggedVerdicts int `json:"flagged_verdicts"` + ClassifierErrors int `json:"classifier_errors"` + PendingReviews int `json:"pending_reviews"` + ActiveAgents int `json:"active_agents"` + VerdictsByMAD map[string]int `json:"verdicts_by_mad"` + Window string `json:"window"` } type activityBucketEntry struct { @@ -31,12 +32,13 @@ func (s *Server) handleStatsOverview(w http.ResponseWriter, r *http.Request) { return } writeJSON(w, http.StatusOK, overviewResponse{ - TotalEvents: o.TotalEvents, - FlaggedVerdicts: o.FlaggedVerdicts, - PendingReviews: o.PendingReviews, - ActiveAgents: o.ActiveAgents, - VerdictsByMAD: o.VerdictsByMAD, - Window: "24h", + TotalEvents: o.TotalEvents, + FlaggedVerdicts: o.FlaggedVerdicts, + ClassifierErrors: o.ClassifierErrors, + PendingReviews: o.PendingReviews, + ActiveAgents: o.ActiveAgents, + VerdictsByMAD: o.VerdictsByMAD, + Window: "24h", }) } diff --git a/backend/internal/api/handlers_test.go b/backend/internal/api/handlers_test.go index 0b2c209..77c2560 100644 --- a/backend/internal/api/handlers_test.go +++ b/backend/internal/api/handlers_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/uuid" + "google.golang.org/protobuf/proto" _ "modernc.org/sqlite" "github.com/secureagentics/Adrian/backend/internal/api" @@ -415,7 +416,7 @@ func TestProfileNameValidation(t *testing.T) { func TestStatsOverview(t *testing.T) { srv, db, _, cookie := newTestServerLoggedIn(t) - // Seed: 3 events on 2 agents, 2 verdicts (one M0, one M3), + // Seed: 3 events on 2 agents, 3 verdicts (M0, M3, classifier error), // 1 pending review, 1 agents row with last_seen recent. if _, err := db.Exec( `INSERT INTO agents (id, agent_id, last_seen) VALUES (?, 'a1', datetime('now'))`, @@ -441,6 +442,13 @@ func TestStatsOverview(t *testing.T) { t.Fatalf("seed verdict: %v", err) } } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status) + VALUES (?, ?, 'sess-stats', '', 'error', 'error')`, + uuid.NewString(), uuid.NewString(), + ); err != nil { + t.Fatalf("seed error verdict: %v", err) + } if _, err := db.Exec( `INSERT INTO hitl_queue (id, event_id, session_id, mad_code) VALUES (?, ?, 'sess-stats', 'M3')`, uuid.NewString(), uuid.NewString(), @@ -459,6 +467,9 @@ func TestStatsOverview(t *testing.T) { if int(data["flagged_verdicts"].(float64)) != 1 { t.Errorf("flagged_verdicts = %v, want 1 (only M3.b counts)", data["flagged_verdicts"]) } + if int(data["classifier_errors"].(float64)) != 1 { + t.Errorf("classifier_errors = %v, want 1", data["classifier_errors"]) + } if int(data["pending_reviews"].(float64)) != 1 { t.Errorf("pending_reviews = %v, want 1", data["pending_reviews"]) } @@ -466,8 +477,10 @@ func TestStatsOverview(t *testing.T) { t.Errorf("active_agents = %v, want 1", data["active_agents"]) } dist := data["verdicts_by_mad"].(map[string]any) - if int(dist["M0"].(float64)) != 1 || int(dist["M3"].(float64)) != 1 { - t.Errorf("verdicts_by_mad = %v, want M0=1 M3=1", dist) + if int(dist["M0"].(float64)) != 1 || + int(dist["M3"].(float64)) != 1 || + int(dist["error"].(float64)) != 1 { + t.Errorf("verdicts_by_mad = %v, want M0=1 M3=1 error=1", dist) } } @@ -523,6 +536,9 @@ func TestListVerdictsIncludesStatusAndFiltersError(t *testing.T) { if row["classification"] != "error" || row["verdict_status"] != "error" { t.Errorf("verdict row = %v, want classification/status error", row) } + if row["reasoning"] != "classifier failed" { + t.Errorf("reasoning = %v, want classifier failed", row["reasoning"]) + } } // ----------------------------------------------------------------- @@ -632,6 +648,139 @@ func TestApproveReviewPublishesToSubscriber(t *testing.T) { } } +func TestApproveErrorReviewPublishesErrorStatus(t *testing.T) { + srv, db, hub, cookie := newTestServerWithHub(t) + + const sessID = "sess-hitl-error" + eventID := uuid.NewString() + verdictID := uuid.NewString() + queueID := uuid.NewString() + + if _, err := db.Exec( + `INSERT INTO events (id, session_id, agent_id, event_type, run_id, payload) + VALUES (?, ?, 'agent-h', 'llm', 'r1', '{}')`, + eventID, sessID, + ); err != nil { + t.Fatalf("seed event: %v", err) + } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status, reasoning) + VALUES (?, ?, ?, '', 'error', 'error', 'classifier failure: boom')`, + verdictID, eventID, sessID, + ); err != nil { + t.Fatalf("seed verdict: %v", err) + } + if _, err := db.Exec( + `INSERT INTO hitl_queue (id, event_id, verdict_id, session_id, mad_code) + VALUES (?, ?, ?, ?, '')`, + queueID, eventID, verdictID, sessID, + ); err != nil { + t.Fatalf("seed hitl_queue: %v", err) + } + + detailResp := getReq(t, srv, cookie, "/api/reviews/"+queueID) + if detailResp.StatusCode != http.StatusOK { + t.Fatalf("detail status = %d, want 200", detailResp.StatusCode) + } + detail := decodeBody(t, detailResp)["data"].(map[string]any) + if detail["reasoning"] != "classifier failure: boom" { + t.Errorf("detail.reasoning = %v, want classifier failure cause", detail["reasoning"]) + } + + ch, dereg, err := hub.Register(sessID, "test-owner") + if err != nil { + t.Fatalf("Register: %v", err) + } + defer dereg() + + resp := postJSON(t, srv, cookie, "/api/reviews/"+queueID+"/approve", map[string]any{}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + + select { + case buf := <-ch: + var frame pb.ServerFrame + if err := proto.Unmarshal(buf, &frame); err != nil { + t.Fatalf("unmarshal frame: %v", err) + } + verdict := frame.GetVerdict() + if verdict == nil { + t.Fatalf("expected Verdict, got %T", frame.Frame) + } + if verdict.GetStatus() != pb.VerdictStatus_VERDICT_STATUS_ERROR { + t.Fatalf("status = %v, want ERROR", verdict.GetStatus()) + } + if verdict.GetMadCode() != "" { + t.Fatalf("mad_code = %q, want empty", verdict.GetMadCode()) + } + if verdict.GetHitl() == nil || !verdict.GetHitl().GetContinueExecution() { + t.Fatalf("expected approve to continue execution") + } + case <-time.After(time.Second): + t.Fatal("subscriber never received the resolution frame") + } +} + +func TestListReviewsFiltersByVerdictStatus(t *testing.T) { + srv, db, _, cookie := newTestServerWithHub(t) + + const sessID = "sess-review-filter" + okEventID := uuid.NewString() + errorEventID := uuid.NewString() + okVerdictID := uuid.NewString() + errorVerdictID := uuid.NewString() + okQueueID := uuid.NewString() + errorQueueID := uuid.NewString() + + if _, err := db.Exec( + `INSERT INTO events (id, session_id, agent_id, event_type, run_id, payload) + VALUES (?, ?, 'agent-h', 'llm', 'r-ok', '{}'), + (?, ?, 'agent-h', 'llm', 'r-error', '{}')`, + okEventID, sessID, + errorEventID, sessID, + ); err != nil { + t.Fatalf("seed events: %v", err) + } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status) + VALUES (?, ?, ?, 'M3', 'block', 'ok'), + (?, ?, ?, '', 'error', 'error')`, + okVerdictID, okEventID, sessID, + errorVerdictID, errorEventID, sessID, + ); err != nil { + t.Fatalf("seed verdicts: %v", err) + } + if _, err := db.Exec( + `INSERT INTO hitl_queue (id, event_id, verdict_id, session_id, mad_code) + VALUES (?, ?, ?, ?, 'M3'), + (?, ?, ?, ?, '')`, + okQueueID, okEventID, okVerdictID, sessID, + errorQueueID, errorEventID, errorVerdictID, sessID, + ); err != nil { + t.Fatalf("seed hitl_queue: %v", err) + } + + resp := getReq(t, srv, cookie, "/api/reviews?status=pending&verdict_status=error") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + data := decodeBody(t, resp)["data"].(map[string]any) + if int(data["total"].(float64)) != 1 { + t.Fatalf("total = %v, want 1", data["total"]) + } + reviews := data["reviews"].([]any) + row := reviews[0].(map[string]any) + if row["id"] != errorQueueID || row["verdict_status"] != "error" { + t.Fatalf("filtered review = %v, want only classifier-error review %q", row, errorQueueID) + } + + resp = getReq(t, srv, cookie, "/api/reviews?verdict_status=bogus") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("invalid verdict_status status = %d, want 400", resp.StatusCode) + } +} + func TestApproveReviewNoSubscriberStillResolves(t *testing.T) { srv, db, _, cookie := newTestServerWithHub(t) @@ -1050,6 +1199,66 @@ func TestEventsMinMADFilterUsesLatestVerdict(t *testing.T) { } } +func TestEventsVerdictStatusFilterUsesLatestVerdict(t *testing.T) { + srv, db, _, cookie := newTestServerLoggedIn(t) + + const sid = "sess-verdict-status" + + eOK := uuid.NewString() + if _, err := db.Exec( + `INSERT INTO events (id, session_id, agent_id, event_type, run_id, payload) + VALUES (?, ?, 'agent-ok', 'tool', 'r1', '{}')`, + eOK, sid, + ); err != nil { + t.Fatalf("seed ok event: %v", err) + } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status, created_at) + VALUES (?, ?, ?, '', 'error', 'error', datetime('now', '-2 seconds')), + (?, ?, ?, 'M0', 'benign', 'ok', datetime('now', '-1 seconds'))`, + uuid.NewString(), eOK, sid, + uuid.NewString(), eOK, sid, + ); err != nil { + t.Fatalf("seed ok verdicts: %v", err) + } + + eError := uuid.NewString() + if _, err := db.Exec( + `INSERT INTO events (id, session_id, agent_id, event_type, run_id, payload) + VALUES (?, ?, 'agent-error', 'llm', 'r2', '{}')`, + eError, sid, + ); err != nil { + t.Fatalf("seed error event: %v", err) + } + if _, err := db.Exec( + `INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status, created_at) + VALUES (?, ?, ?, 'M0', 'benign', 'ok', datetime('now', '-2 seconds')), + (?, ?, ?, '', 'error', 'error', datetime('now', '-1 seconds'))`, + uuid.NewString(), eError, sid, + uuid.NewString(), eError, sid, + ); err != nil { + t.Fatalf("seed error verdicts: %v", err) + } + + resp := getReq(t, srv, cookie, "/api/events?session_id="+sid+"&verdict_status=error") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + data := decodeBody(t, resp)["data"].(map[string]any) + if int(data["total"].(float64)) != 1 { + t.Errorf("verdict_status=error total = %v, want 1", data["total"]) + } + events := data["events"].([]any) + if len(events) != 1 || events[0].(map[string]any)["id"] != eError { + t.Errorf("verdict_status=error events = %v, want only event %q", events, eError) + } + + resp = getReq(t, srv, cookie, "/api/events?session_id="+sid+"&verdict_status=bogus") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("invalid verdict_status status = %d, want 400", resp.StatusCode) + } +} + // ----------------------------------------------------------------- // MCP servers // ----------------------------------------------------------------- diff --git a/backend/internal/api/handlers_verdicts.go b/backend/internal/api/handlers_verdicts.go index 3c75e57..273517c 100644 --- a/backend/internal/api/handlers_verdicts.go +++ b/backend/internal/api/handlers_verdicts.go @@ -17,6 +17,7 @@ type verdictResponse struct { MADCode string `json:"mad_code"` Classification string `json:"classification"` VerdictStatus string `json:"verdict_status"` + Reasoning string `json:"reasoning,omitempty"` LatencyMS *int64 `json:"latency_ms,omitempty"` TokensUsed int32 `json:"tokens_used"` CreatedAt string `json:"created_at"` @@ -78,6 +79,7 @@ func verdictRowToResponse(r *store.VerdictListRow) verdictResponse { MADCode: r.MADCode, Classification: r.Classification, VerdictStatus: r.VerdictStatus, + Reasoning: r.Reasoning, LatencyMS: r.LatencyMS, TokensUsed: r.TokensUsed, CreatedAt: r.CreatedAt.UTC().Format("2006-01-02T15:04:05.000Z"), diff --git a/backend/internal/db/migrate_test.go b/backend/internal/db/migrate_test.go index b257698..7e90499 100644 --- a/backend/internal/db/migrate_test.go +++ b/backend/internal/db/migrate_test.go @@ -8,6 +8,8 @@ import ( "testing" "testing/fstest" + "github.com/secureagentics/Adrian/backend/migrations" + _ "modernc.org/sqlite" ) @@ -112,6 +114,105 @@ COMMIT;`), } } +func TestEmbeddedMigration002UpgradesPopulatedDB(t *testing.T) { + conn := openTestDB(t) + defer conn.Close() + + initialSQL, err := migrations.Files.ReadFile("001_initial_schema.sql") + if err != nil { + t.Fatalf("read 001 migration: %v", err) + } + applied, err := applyMigrations(conn, fstest.MapFS{ + "001_initial_schema.sql": {Data: initialSQL}, + }) + if err != nil { + t.Fatalf("apply 001 migration: %v", err) + } + if got, want := applied, []string{"001_initial_schema.sql"}; len(got) != len(want) || got[0] != want[0] { + t.Fatalf("applied initial migrations = %v, want %v", got, want) + } + + if _, err := conn.Exec(` +INSERT INTO events (id, session_id, event_type, payload) +VALUES ('evt-populated', 'sess-populated', 'llm', '{}'); +INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, reasoning) +VALUES ('verdict-populated', 'evt-populated', 'sess-populated', 'M4_a', 'block', 'seed'); +INSERT INTO hitl_queue (id, event_id, verdict_id, session_id, mad_code) +VALUES ('review-populated', 'evt-populated', 'verdict-populated', 'sess-populated', 'M4_a'); +`); err != nil { + t.Fatalf("seed populated database: %v", err) + } + + applied, err = applyMigrations(conn, migrations.Files) + if err != nil { + t.Fatalf("apply embedded migrations: %v", err) + } + if got, want := applied, []string{"002_verdict_status_policy.sql"}; len(got) != len(want) || got[0] != want[0] { + t.Fatalf("applied upgrade migrations = %v, want %v", got, want) + } + + var failClosed int + if err := conn.QueryRow(`SELECT fail_closed_on_classifier_error FROM policies WHERE id = 1`).Scan(&failClosed); err != nil { + t.Fatalf("query policy flag: %v", err) + } + if failClosed != 0 { + t.Fatalf("fail_closed_on_classifier_error = %d, want 0", failClosed) + } + + var madCode, classification, verdictStatus string + if err := conn.QueryRow(` +SELECT mad_code, classification, verdict_status +FROM verdicts WHERE id = 'verdict-populated' +`).Scan(&madCode, &classification, &verdictStatus); err != nil { + t.Fatalf("query upgraded verdict: %v", err) + } + if madCode != "M4_a" || classification != "block" || verdictStatus != "ok" { + t.Fatalf("upgraded verdict = (%q, %q, %q), want (M4_a, block, ok)", madCode, classification, verdictStatus) + } + + if _, err := conn.Exec(` +INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status, reasoning) +VALUES ('verdict-error', 'evt-populated', 'sess-populated', '', 'error', 'error', 'classifier failure: test'); +`); err != nil { + t.Fatalf("insert classifier-error verdict after upgrade: %v", err) + } + + var reviewVerdictID string + if err := conn.QueryRow(`SELECT verdict_id FROM hitl_queue WHERE id = 'review-populated'`).Scan(&reviewVerdictID); err != nil { + t.Fatalf("query preserved hitl_queue row: %v", err) + } + if reviewVerdictID != "verdict-populated" { + t.Fatalf("preserved hitl_queue verdict_id = %q, want verdict-populated", reviewVerdictID) + } + + for _, name := range []string{"idx_verdicts_event_id", "idx_verdicts_session_id", "idx_verdicts_created_at"} { + var seen int + if err := conn.QueryRow(`SELECT count(*) FROM sqlite_master WHERE type = 'index' AND name = ?`, name).Scan(&seen); err != nil { + t.Fatalf("query index %s: %v", name, err) + } + if seen != 1 { + t.Fatalf("index %s count = %d, want 1", name, seen) + } + } + + rows, err := conn.Query(`PRAGMA foreign_key_check`) + if err != nil { + t.Fatalf("foreign_key_check: %v", err) + } + defer rows.Close() + if rows.Next() { + t.Fatal("foreign_key_check returned violations after 002 migration") + } + + applied, err = applyMigrations(conn, migrations.Files) + if err != nil { + t.Fatalf("second embedded apply: %v", err) + } + if len(applied) != 0 { + t.Fatalf("second embedded apply = %v, want no migrations", applied) + } +} + func openTestDB(t *testing.T) *sql.DB { t.Helper() conn, err := sql.Open("sqlite", "file:migratetest?mode=memory&cache=shared") diff --git a/backend/internal/notifications/discord_test.go b/backend/internal/notifications/discord_test.go index c05a9ab..5530f39 100644 --- a/backend/internal/notifications/discord_test.go +++ b/backend/internal/notifications/discord_test.go @@ -5,13 +5,19 @@ package notifications import ( "context" + "database/sql" "encoding/json" "io" "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "time" + + "github.com/google/uuid" + "github.com/secureagentics/Adrian/backend/internal/store" + _ "modernc.org/sqlite" ) func TestValidateDiscordWebhookURL(t *testing.T) { @@ -120,6 +126,54 @@ func TestSendNonDiscordURLRejected(t *testing.T) { } } +func TestDispatcherSkipsEmptyMADCode(t *testing.T) { + var posts int32 + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&posts, 1) + w.WriteHeader(http.StatusNoContent) + })) + defer mock.Close() + + origHosts := allowedHosts + allowedHosts = []string{mock.URL + "/"} + defer func() { allowedHosts = origHosts }() + + db, err := sql.Open("sqlite", "file:notifications?mode=memory&cache=shared") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + defer db.Close() + if _, err := db.Exec(` +CREATE TABLE webhooks ( + id TEXT PRIMARY KEY, + platform TEXT NOT NULL DEFAULT 'discord', + webhook_url TEXT NOT NULL, + alert_type TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 1, + installed_by_user_id TEXT, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) +); +`); err != nil { + t.Fatalf("create webhooks: %v", err) + } + st := store.New(db) + if err := st.CreateWebhook(context.Background(), uuid.NewString(), mock.URL+"/api/webhooks/1/tok", "all", ""); err != nil { + t.Fatalf("create webhook: %v", err) + } + + d := NewDispatcher(st, "https://dash.example") + d.fanout(context.Background(), VerdictNotification{ + EventID: "ev-error", + SessionID: "sess-error", + MADCode: "", + Classification: "error", + }) + if got := atomic.LoadInt32(&posts); got != 0 { + t.Fatalf("webhook posts = %d, want 0 for empty MAD code", got) + } +} + func TestSendRespectsContextDeadline(t *testing.T) { // Server that holds the response open longer than the client's // context allows. The handler exits when r.Context() is cancelled diff --git a/backend/internal/notifications/dispatcher.go b/backend/internal/notifications/dispatcher.go index 99b56cc..692e4f1 100644 --- a/backend/internal/notifications/dispatcher.go +++ b/backend/internal/notifications/dispatcher.go @@ -69,7 +69,9 @@ func (d *Dispatcher) Run(ctx context.Context) { // would mean state outside SQLite). func (d *Dispatcher) fanout(ctx context.Context, vn VerdictNotification) { if vn.MADCode == "" || strings.HasPrefix(vn.MADCode, "M0") { - // Benign verdicts don't fan out; webhooks are for flagged events. + // Empty MAD codes (classifier errors) and M0 benign verdicts do + // not fan out; these webhooks are for real flagged MAD findings. + // Operational outage alerts should be a separate alert type. return } hooks, err := d.store.ListWebhooks(ctx, true) diff --git a/backend/internal/store/events.go b/backend/internal/store/events.go index c7c71d9..722d1cf 100644 --- a/backend/internal/store/events.go +++ b/backend/internal/store/events.go @@ -64,6 +64,9 @@ type EventFilters struct { // Lets the dashboard surface flagged events that didn't trigger a // HITL hold (post-execution tool pairs, tool_call-less LLM pairs). MinMAD string + // VerdictStatus restricts to events whose latest verdict has this status. + // Accepts "ok" or "error"; empty = no filter. + VerdictStatus string } // InsertEvent persists one paired event and reports whether a new row @@ -247,6 +250,13 @@ func eventsWhere(f EventFilters) (string, []any) { args = append(args, t) } } + if f.VerdictStatus != "" { + parts = append(parts, "EXISTS (SELECT 1 FROM verdicts v "+ + "WHERE v.event_id = e.id "+ + "AND v.created_at = (SELECT max(v2.created_at) FROM verdicts v2 WHERE v2.event_id = e.id) "+ + "AND v.verdict_status = ?)") + args = append(args, f.VerdictStatus) + } return strings.Join(parts, " AND "), args } diff --git a/backend/internal/store/hitl.go b/backend/internal/store/hitl.go index 7aa3019..abae609 100644 --- a/backend/internal/store/hitl.go +++ b/backend/internal/store/hitl.go @@ -7,6 +7,7 @@ import ( "context" "database/sql" "errors" + "strings" "time" "github.com/google/uuid" @@ -47,28 +48,40 @@ func (s *Store) InsertHitlQueue(ctx context.Context, eventID, verdictID, session return err } -// ListHitlQueue returns rows in the requested status (default 'pending'), -// newest first, paginated. -func (s *Store) ListHitlQueue(ctx context.Context, status string, perPage, offset int) ([]*HitlReview, int, error) { +// ListHitlQueue returns rows in the requested review status (default +// 'pending') and optional verdict status, newest first, paginated. +func (s *Store) ListHitlQueue(ctx context.Context, status, verdictStatus string, perPage, offset int) ([]*HitlReview, int, error) { if status == "" { status = "pending" } + where := []string{"q.status = ?"} + args := []any{status} + if verdictStatus != "" { + where = append(where, "COALESCE(v.verdict_status, 'ok') = ?") + args = append(args, verdictStatus) + } + whereSQL := strings.Join(where, " AND ") + var total int if err := s.db.QueryRowContext(ctx, - `SELECT count(*) FROM hitl_queue WHERE status = ?`, status, + `SELECT count(*) + FROM hitl_queue q + LEFT JOIN verdicts v ON v.id = q.verdict_id + WHERE `+whereSQL, args..., ).Scan(&total); err != nil { return nil, 0, err } + queryArgs := append(append([]any{}, args...), perPage, offset) rows, err := s.db.QueryContext(ctx, `SELECT q.id, q.event_id, COALESCE(q.verdict_id, ''), COALESCE(q.session_id, ''), q.mad_code, COALESCE(v.verdict_status, 'ok'), q.status, COALESCE(q.reviewed_by, ''), COALESCE(q.reviewed_at, ''), q.created_at FROM hitl_queue q LEFT JOIN verdicts v ON v.id = q.verdict_id - WHERE q.status = ? + WHERE `+whereSQL+` ORDER BY q.created_at DESC LIMIT ? OFFSET ?`, - status, perPage, offset) + queryArgs...) if err != nil { return nil, 0, err } diff --git a/backend/internal/store/stats.go b/backend/internal/store/stats.go index 5d3ab47..abe40b3 100644 --- a/backend/internal/store/stats.go +++ b/backend/internal/store/stats.go @@ -10,11 +10,12 @@ import ( // Overview is the 24h summary the dashboard home renders. type Overview struct { - TotalEvents int - FlaggedVerdicts int - PendingReviews int - ActiveAgents int - VerdictsByMAD map[string]int + TotalEvents int + FlaggedVerdicts int + ClassifierErrors int + PendingReviews int + ActiveAgents int + VerdictsByMAD map[string]int } // ActivityBucket is one bin in the time-series response. @@ -35,15 +36,25 @@ func (s *Store) StatsOverview(ctx context.Context) (*Overview, error) { return nil, err } - // Flagged = anything other than M0/empty, i.e. an actual MAD code. + // Flagged = real non-M0 MAD findings. Classifier errors are tracked + // separately below so outages do not inflate security-finding totals. if err := s.db.QueryRowContext(ctx, `SELECT count(*) FROM verdicts WHERE created_at >= datetime('now', ?) + AND verdict_status = 'ok' AND mad_code != '' AND mad_code NOT LIKE 'M0%'`, window, ).Scan(&o.FlaggedVerdicts); err != nil { return nil, err } + if err := s.db.QueryRowContext(ctx, + `SELECT count(*) FROM verdicts + WHERE created_at >= datetime('now', ?) + AND verdict_status = 'error'`, window, + ).Scan(&o.ClassifierErrors); err != nil { + return nil, err + } + if err := s.db.QueryRowContext(ctx, `SELECT count(*) FROM hitl_queue WHERE status = 'pending'`, ).Scan(&o.PendingReviews); err != nil { @@ -59,7 +70,8 @@ func (s *Store) StatsOverview(ctx context.Context) (*Overview, error) { rows, err := s.db.QueryContext(ctx, `SELECT CASE - WHEN mad_code LIKE 'M0%' OR mad_code = '' THEN 'M0' + WHEN verdict_status = 'error' THEN 'error' + WHEN mad_code LIKE 'M0%' THEN 'M0' WHEN mad_code LIKE 'M2%' THEN 'M2' WHEN mad_code LIKE 'M3%' THEN 'M3' WHEN mad_code LIKE 'M4%' THEN 'M4' diff --git a/backend/internal/store/verdicts.go b/backend/internal/store/verdicts.go index 027490a..beaa21e 100644 --- a/backend/internal/store/verdicts.go +++ b/backend/internal/store/verdicts.go @@ -33,6 +33,7 @@ type VerdictListRow struct { MADCode string Classification string VerdictStatus string + Reasoning string LatencyMS *int64 TokensUsed int32 CreatedAt time.Time @@ -76,7 +77,7 @@ func (s *Store) ListVerdicts(ctx context.Context, f VerdictFilters, perPage, off args = append(args, perPage, offset) rows, err := s.db.QueryContext(ctx, `SELECT id, event_id, session_id, mad_code, classification, verdict_status, - latency_ms, tokens_used, created_at + COALESCE(reasoning, ''), latency_ms, tokens_used, created_at FROM verdicts WHERE `+where+` ORDER BY created_at DESC @@ -92,7 +93,7 @@ func (s *Store) ListVerdicts(ctx context.Context, f VerdictFilters, perPage, off var latency sql.NullInt64 var createdAt string if err := rows.Scan(&r.ID, &r.EventID, &r.SessionID, &r.MADCode, &r.Classification, &r.VerdictStatus, - &latency, &r.TokensUsed, &createdAt); err != nil { + &r.Reasoning, &latency, &r.TokensUsed, &createdAt); err != nil { return nil, 0, err } if latency.Valid { @@ -109,14 +110,14 @@ func (s *Store) ListVerdicts(ctx context.Context, f VerdictFilters, perPage, off func (s *Store) GetVerdictByEventID(ctx context.Context, eventID string) (*VerdictListRow, error) { row := s.db.QueryRowContext(ctx, `SELECT id, event_id, session_id, mad_code, classification, verdict_status, - latency_ms, tokens_used, created_at + COALESCE(reasoning, ''), latency_ms, tokens_used, created_at FROM verdicts WHERE event_id = ? ORDER BY created_at DESC LIMIT 1`, eventID) r := &VerdictListRow{} var latency sql.NullInt64 var createdAt string if err := row.Scan(&r.ID, &r.EventID, &r.SessionID, &r.MADCode, &r.Classification, &r.VerdictStatus, - &latency, &r.TokensUsed, &createdAt); err != nil { + &r.Reasoning, &latency, &r.TokensUsed, &createdAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } diff --git a/backend/internal/ws/handler.go b/backend/internal/ws/handler.go index a40706b..39c4db4 100644 --- a/backend/internal/ws/handler.go +++ b/backend/internal/ws/handler.go @@ -303,6 +303,11 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla } verdict, err := classifier.Classify(ctx, ev, agentProfileID) if err != nil { + if ctx.Err() != nil { + slog.InfoContext(ctx, "ws.classify_cancelled", + "error", err, "event_id", ev.EventId) + return nil + } slog.WarnContext(ctx, "ws.classifier_failure", "error", err, "event_id", ev.EventId) reasoning := "classifier failure: " + err.Error() @@ -354,6 +359,10 @@ func persistAndClassify(ctx context.Context, sess *session, st *store.Store, cla } func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictID, madCode, verdictStatus string) error { + if verdictStatus == "error" { + return dispatchErrorVerdict(ctx, sess, st, hub, ev, snap, verdictID, madCode) + } + // Mode-gated dispatch: // alert: persist verdict, do NOT notify the SDK (dashboard-only). // hitl + in-scope + actionable: persist + queue for human review, @@ -361,7 +370,7 @@ func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *H // hitl + in-scope + non-actionable: forward (review would be a // no-op for the operator since the SDK never blocks on it). // hitl + out-of-scope: forward (no review queued for this code). - // block: forward all verdicts; SDK is the enforcement point. + // block: forward all OK verdicts; SDK is the enforcement point. inScope := shouldFanOut(snap, madCode) switch snap.GetMode() { case pb.Mode_MODE_ALERT: @@ -391,6 +400,38 @@ func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *H return nil } + publishVerdict(ctx, sess, hub, ev, snap, madCode, verdictStatus) + return nil +} + +func dispatchErrorVerdict(ctx context.Context, sess *session, st *store.Store, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictID, madCode string) error { + switch snap.GetMode() { + case pb.Mode_MODE_ALERT: + return nil + case pb.Mode_MODE_BLOCK: + publishVerdict(ctx, sess, hub, ev, snap, madCode, "error") + return nil + case pb.Mode_MODE_HITL: + if snap.GetFailClosedOnClassifierError() && isActionable(ev) { + if err := st.InsertHitlQueue(ctx, ev.EventId, verdictID, sess.sessionID, madCode); err != nil { + slog.ErrorContext(ctx, "hitl.insert_failed_fallback_publish", + "error", err, "event_id", ev.EventId, "verdict_id", verdictID) + publishVerdict(ctx, sess, hub, ev, snap, madCode, "error") + } + return nil + } + publishVerdict(ctx, sess, hub, ev, snap, madCode, "error") + return nil + default: + slog.WarnContext(ctx, "ws.unknown_mode_dropping_verdict", + "mode", snap.GetMode().String(), "event_id", ev.EventId) + return nil + } +} + +func publishVerdict(ctx context.Context, sess *session, hub *Hub, ev *pb.PairedEvent, snap *pb.PolicySnapshot, madCode, verdictStatus string) { + warnOldSDKClassifierErrorCompatibility(ctx, sess, ev, snap, verdictStatus) + out := &pb.ServerFrame{ Frame: &pb.ServerFrame_Verdict{ Verdict: &pb.Verdict{ @@ -406,7 +447,17 @@ func dispatchVerdict(ctx context.Context, sess *session, st *store.Store, hub *H slog.WarnContext(ctx, "ws.publish_dropped", "event_id", ev.EventId, "session_id", sess.sessionID) } - return nil +} + +func warnOldSDKClassifierErrorCompatibility(ctx context.Context, sess *session, ev *pb.PairedEvent, snap *pb.PolicySnapshot, verdictStatus string) { + if verdictStatus != "error" || !snap.GetFailClosedOnClassifierError() || sess.warnedClassifierErrorCompatibility { + return + } + sess.warnedClassifierErrorCompatibility = true + slog.WarnContext(ctx, "ws.classifier_error_fail_closed_requires_updated_sdk", + "event_id", ev.EventId, + "session_id", sess.sessionID, + "message", "old SDKs ignore classifier-error status and policy fields, so fail-closed enforcement requires the updated SDK") } func verdictStatusProto(status string) pb.VerdictStatus { diff --git a/backend/internal/ws/handler_test.go b/backend/internal/ws/handler_test.go index cf70d9c..473600a 100644 --- a/backend/internal/ws/handler_test.go +++ b/backend/internal/ws/handler_test.go @@ -254,6 +254,101 @@ func TestClassifierFailurePersistsAndPublishesErrorVerdict(t *testing.T) { } } +func TestClassifierFailureAlertPersistsWithoutPublish(t *testing.T) { + db, conn := classifierFailureConn(t, "alert", false) + + eventID := uuid.NewString() + if err := sendPairedEvent(conn, classifierFailureToolEvent(eventID, "classifier-failure-alert")); err != nil { + t.Fatalf("send paired_batch: %v", err) + } + + if err := expectNoServerFrame(conn, 250*time.Millisecond); err == nil { + t.Fatal("expected no SDK verdict in alert mode") + } + assertStoredErrorVerdict(t, db, eventID) +} + +func TestClassifierFailureHitlFailClosedQueuesActionable(t *testing.T) { + db, conn := classifierFailureConn(t, "hitl", true) + + eventID := uuid.NewString() + if err := sendPairedEvent(conn, classifierFailureActionableEvent(eventID, "classifier-failure-hitl")); err != nil { + t.Fatalf("send paired_batch: %v", err) + } + + if err := expectNoServerFrame(conn, 250*time.Millisecond); err == nil { + t.Fatal("expected actionable fail-closed ERROR verdict to be held for HITL") + } + assertStoredErrorVerdict(t, db, eventID) + + var queued int + if err := db.QueryRow( + `SELECT count(*) FROM hitl_queue h + JOIN verdicts v ON v.id = h.verdict_id + WHERE h.event_id = ? AND h.mad_code = '' AND v.verdict_status = 'error'`, + eventID, + ).Scan(&queued); err != nil { + t.Fatalf("query hitl_queue: %v", err) + } + if queued != 1 { + t.Fatalf("queued error reviews = %d, want 1", queued) + } +} + +func TestClassifierFailureHitlFailClosedNonActionablePublishes(t *testing.T) { + db, conn := classifierFailureConn(t, "hitl", true) + + eventID := uuid.NewString() + if err := sendPairedEvent(conn, classifierFailureToolEvent(eventID, "classifier-failure-hitl-nonactionable")); err != nil { + t.Fatalf("send paired_batch: %v", err) + } + + frame, err := readServerFrame(conn) + if err != nil { + t.Fatalf("read verdict: %v", err) + } + if got := frame.GetVerdict().GetStatus(); got != bpb.VerdictStatus_VERDICT_STATUS_ERROR { + t.Fatalf("pushed status = %v, want ERROR", got) + } + assertStoredErrorVerdict(t, db, eventID) + + var queued int + if err := db.QueryRow(`SELECT count(*) FROM hitl_queue WHERE event_id = ?`, eventID).Scan(&queued); err != nil { + t.Fatalf("query hitl_queue: %v", err) + } + if queued != 0 { + t.Fatalf("queued reviews = %d, want 0", queued) + } +} + +func TestClassifierFailureHitlQueueFailureFallsBackToPublish(t *testing.T) { + db, conn := classifierFailureConn(t, "hitl", true) + if _, err := db.Exec(` +CREATE TRIGGER fail_hitl_insert +BEFORE INSERT ON hitl_queue +BEGIN + SELECT RAISE(FAIL, 'forced hitl insert failure'); +END; +`); err != nil { + t.Fatalf("create hitl failure trigger: %v", err) + } + + eventID := uuid.NewString() + if err := sendPairedEvent(conn, classifierFailureActionableEvent(eventID, "classifier-failure-hitl-fallback")); err != nil { + t.Fatalf("send paired_batch: %v", err) + } + + frame, err := readServerFrame(conn) + if err != nil { + t.Fatalf("read verdict: %v", err) + } + verdict := frame.GetVerdict() + if verdict.GetStatus() != bpb.VerdictStatus_VERDICT_STATUS_ERROR || verdict.GetMadCode() != "" { + t.Fatalf("pushed verdict = (%q, %v), want ('', ERROR)", verdict.GetMadCode(), verdict.GetStatus()) + } + assertStoredErrorVerdict(t, db, eventID) +} + func TestDuplicateEventRetryKeepsWSOpen(t *testing.T) { db := openInMemoryDB(t) t.Cleanup(func() { _ = db.Close() }) @@ -648,6 +743,120 @@ type fakeClassifier struct { calls *int32 } +func classifierFailureConn(t *testing.T, mode string, failClosed bool) (*sql.DB, *websocket.Conn) { + t.Helper() + + db := openInMemoryDB(t) + t.Cleanup(func() { _ = db.Close() }) + + st := store.New(db) + plaintextKey := "adr_local_test_key_classifier_failure_" + uuid.NewString() + keyHash := sha256Hex(plaintextKey) + insertAPIKey(t, db, keyHash) + + failClosedInt := 0 + if failClosed { + failClosedInt = 1 + } + if _, err := db.Exec( + `UPDATE policies SET mode = ?, fail_closed_on_classifier_error = ? WHERE id = 1`, + mode, failClosedInt, + ); err != nil { + t.Fatalf("set policy: %v", err) + } + + llm := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "classifier exploded", http.StatusInternalServerError) + })) + t.Cleanup(llm.Close) + classifier := engine.NewHTTPClient(llm.URL, "test-key", "test-model", nil, nil) + + mux := http.NewServeMux() + mux.Handle("/ws", ws.AuthMiddleware(st)(ws.NewHandler(st, classifier, ws.NewHub(), nil, nil))) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws" + header := http.Header{"Authorization": {"Bearer " + plaintextKey}} + conn, _, err := websocket.DefaultDialer.Dial(wsURL, header) + if err != nil { + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { _ = conn.Close() }) + + if err := writeProto(conn, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_Login{Login: &bpb.SessionLogin{ + SessionId: "classifier-failure-sess-" + uuid.NewString(), SchemaVersion: 2, + }}, + }); err != nil { + t.Fatalf("send login: %v", err) + } + if _, err := readServerFrame(conn); err != nil { + t.Fatalf("read login_ack: %v", err) + } + return db, conn +} + +func classifierFailureToolEvent(eventID, sessionID string) *bpb.PairedEvent { + return &bpb.PairedEvent{ + EventId: eventID, SessionId: sessionID, + RunId: "run-classifier-failure", + PairType: bpb.PairType_PAIR_TYPE_TOOL, + Agent: &bpb.AgentContext{AgentId: "failure-agent"}, + Data: &bpb.PairedEvent_Tool{Tool: &bpb.ToolPairData{ + ToolName: "noop", ToolCallId: "tc-classifier-failure", Input: "{}", Output: "ok", + }}, + } +} + +func classifierFailureActionableEvent(eventID, sessionID string) *bpb.PairedEvent { + return &bpb.PairedEvent{ + EventId: eventID, SessionId: sessionID, + RunId: "run-classifier-failure", + PairType: bpb.PairType_PAIR_TYPE_LLM, + Agent: &bpb.AgentContext{AgentId: "failure-agent"}, + Data: &bpb.PairedEvent_Llm{Llm: &bpb.LlmPairData{ + Model: "test-model", + Output: "calling tool", + ToolCalls: []*bpb.ToolCall{{ + Name: "noop", Id: "tc-classifier-failure", Args: "{}", + }}, + }}, + } +} + +func sendPairedEvent(conn *websocket.Conn, ev *bpb.PairedEvent) error { + return writeProto(conn, &bpb.ClientFrame{ + Frame: &bpb.ClientFrame_PairedBatch{PairedBatch: &bpb.PairedEventBatch{ + Events: []*bpb.PairedEvent{ev}, + }}, + }) +} + +func expectNoServerFrame(conn *websocket.Conn, timeout time.Duration) error { + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return err + } + _, _, err := conn.ReadMessage() + _ = conn.SetReadDeadline(time.Time{}) + return err +} + +func assertStoredErrorVerdict(t *testing.T, db *sql.DB, eventID string) { + t.Helper() + var madCode, classification, verdictStatus string + if err := db.QueryRow( + `SELECT mad_code, classification, verdict_status FROM verdicts WHERE event_id = ?`, + eventID, + ).Scan(&madCode, &classification, &verdictStatus); err != nil { + t.Fatalf("query verdict: %v", err) + } + if madCode != "" || classification != "error" || verdictStatus != "error" { + t.Fatalf("stored verdict = (%q, %q, %q), want ('', error, error)", + madCode, classification, verdictStatus) + } +} + func (f *fakeClassifier) Classify(_ context.Context, _ *bpb.PairedEvent, _ string) (*engine.Verdict, error) { if f.calls != nil { atomic.AddInt32(f.calls, 1) @@ -736,7 +945,8 @@ func statusOrZero(r *http.Response) int { } // testSchema is the minimum subset of 001_initial_schema.sql the WS -// handler exercises (api_keys, policies, events, verdicts, mcp_servers). +// handler exercises (api_keys, policies, events, verdicts, mcp_servers, +// hitl_queue). // Embedding the full migration file here would couple the test to the // migration's evolution. const testSchema = ` @@ -798,4 +1008,15 @@ CREATE TABLE agents ( last_seen TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), metadata TEXT NOT NULL DEFAULT '{}' ); +CREATE TABLE hitl_queue ( + id TEXT PRIMARY KEY, + event_id TEXT NOT NULL UNIQUE, + verdict_id TEXT, + session_id TEXT, + mad_code TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + reviewed_by TEXT, + reviewed_at TEXT, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) +); ` diff --git a/backend/internal/ws/helpers.go b/backend/internal/ws/helpers.go index 0b7e42b..4f4d283 100644 --- a/backend/internal/ws/helpers.go +++ b/backend/internal/ws/helpers.go @@ -45,8 +45,8 @@ func isActionable(ev *pb.PairedEvent) bool { return llm != nil && len(llm.ToolCalls) > 0 } -// shouldFanOut decides whether a verdict's MAD code is in scope for -// the active policy. False for codes outside the M0/M2/M3/M4 set +// shouldFanOut decides whether an OK verdict's MAD code is in scope +// for the active policy. False for codes outside the M0/M2/M3/M4 set // (defensive: an unrecognised code drops rather than panics) and for // MAD families whose policy_mX flag is unset. // diff --git a/backend/internal/ws/session.go b/backend/internal/ws/session.go index cb251a3..4409e6c 100644 --- a/backend/internal/ws/session.go +++ b/backend/internal/ws/session.go @@ -15,6 +15,8 @@ type session struct { llmProvider string llmModel string loggedIn bool + + warnedClassifierErrorCompatibility bool } // agentProfileID returns the bound agent_profile_id (or nil if the diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 5e0228f..2c5ee14 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -22,8 +22,9 @@ | HTTP POST to ADRIAN_LLM_URL (OpenAI | | compatible chat-completions), strip | | blocks, parse M-code. On | -| classifier error, fail-open with synthetic | -| M0 / benign + WARN log. | +| classifier error, return an error to WS | +| ingest. WS persists verdict_status=error, | +| mad_code="", and routes by policy. | | | | | v | | internal/store SQLite (WAL) writes: events, verdicts, | @@ -47,9 +48,9 @@ +--------------------------------------------------------------------+ | adrian-frontend (Next.js container) | | Login and force-change-password, agent profiles, API keys, | -| policy editor (singleton mode + per-MAD-code toggles), HITL | -| review queue, events and verdicts feeds (REST poll), webhook | -| configuration (Discord). | +| policy editor (singleton mode, per-MAD-code toggles, classifier | +| error fail-closed flag), HITL review queue, events and verdicts | +| feeds (REST poll), webhook configuration (Discord). | +--------------------------------------------------------------------+ +--------------------------------------------------------------------+ diff --git a/frontend/app/(dashboard)/events/page.tsx b/frontend/app/(dashboard)/events/page.tsx index 5ef9376..caad527 100644 --- a/frontend/app/(dashboard)/events/page.tsx +++ b/frontend/app/(dashboard)/events/page.tsx @@ -7,7 +7,7 @@ import { AlertExplanation } from '@/components/alert-explanation' import { Badge } from '@/components/badge' import { JsonBlock } from '@/components/json-block' import { Pagination } from '@/components/pagination' -import { madBadgeColor, timeAgo } from '@/lib/utils' +import { isClassifierErrorVerdict, timeAgo, verdictBadgeColor, verdictBadgeLabel } from '@/lib/utils' import { TimeRange, sinceForRange, TimeRangeSelect } from '@/components/time-range' type EventRow = { @@ -28,14 +28,23 @@ type VerdictInfo = { id: string mad_code: string classification: string - latency_ms: number + verdict_status: string + reasoning?: string + latency_ms?: number created_at: string } | null export default function EventsPage() { const [data, setData] = useState<{ events: EventRow[]; total: number }>({ events: [], total: 0 }) const [page, setPage] = useState(1) - const [filters, setFilters] = useState({ event_type: '', session_id: '', min_mad: '' }) + const [filters, setFilters] = useState(() => ({ + event_type: '', + session_id: '', + min_mad: '', + verdict_status: typeof window === 'undefined' + ? '' + : new URLSearchParams(window.location.search).get('verdict_status') || '', + })) const [range, setRange] = useState('24h') const [customSince, setCustomSince] = useState('') const [expanded, setExpanded] = useState(null) @@ -47,6 +56,7 @@ export default function EventsPage() { if (filters.event_type) params.set('event_type', filters.event_type) if (filters.session_id) params.set('session_id', filters.session_id) if (filters.min_mad) params.set('min_mad', filters.min_mad) + if (filters.verdict_status) params.set('verdict_status', filters.verdict_status) if (since) params.set('since', since) api(`/api/events?${params}`) .then(r => setData(r.data || { events: [], total: 0 })) @@ -91,6 +101,14 @@ export default function EventsPage() { + No verdict recorded for this event yet.

) : (
- - - Latency: {verdict.latency_ms}ms - + + {verdict.latency_ms !== undefined && ( + + Latency: {verdict.latency_ms}ms + + )}
)} - {verdict && verdict.mad_code !== 'M0' && ( + {verdict && isClassifierErrorVerdict(verdict) && verdict.reasoning && ( +

+ {verdict.reasoning} +

+ )} + {verdict && !isClassifierErrorVerdict(verdict) && verdict.mad_code !== 'M0' && (
diff --git a/frontend/app/(dashboard)/page.tsx b/frontend/app/(dashboard)/page.tsx index 1b2978f..4b1e4dc 100644 --- a/frontend/app/(dashboard)/page.tsx +++ b/frontend/app/(dashboard)/page.tsx @@ -8,6 +8,7 @@ import { madBadgeColor } from '@/lib/utils' type Overview = { total_events: number flagged_verdicts: number + classifier_errors: number pending_reviews: number active_agents: number verdicts_by_mad: Record @@ -49,9 +50,10 @@ export default function OverviewPage() { -
+
+
@@ -99,19 +101,21 @@ export default function OverviewPage() {

Verdict mix

{overview && Object.values(overview.verdicts_by_mad).some(v => v > 0) ? (
    - {(['M0', 'M2', 'M3', 'M4'] as const).map(family => { + {(['M0', 'M2', 'M3', 'M4', 'error'] as const).map(family => { const count = overview.verdicts_by_mad[family] || 0 const total = Object.values(overview.verdicts_by_mad).reduce((a, b) => a + b, 0) const pct = total ? (count / total) * 100 : 0 return (
  • - {family} + + {family === 'error' ? 'Classifier error' : family} + {count}
    diff --git a/frontend/app/(dashboard)/reviews/page.tsx b/frontend/app/(dashboard)/reviews/page.tsx index 0d3f65c..4b4af44 100644 --- a/frontend/app/(dashboard)/reviews/page.tsx +++ b/frontend/app/(dashboard)/reviews/page.tsx @@ -6,7 +6,7 @@ import { api } from '@/lib/api' import { AlertExplanation } from '@/components/alert-explanation' import { Badge } from '@/components/badge' import { JsonBlock } from '@/components/json-block' -import { madBadgeColor, timeAgo } from '@/lib/utils' +import { isClassifierErrorVerdict, timeAgo, verdictBadgeColor, verdictBadgeLabel } from '@/lib/utils' type ReviewSummary = { id: string @@ -14,6 +14,7 @@ type ReviewSummary = { verdict_id: string session_id: string mad_code: string + verdict_status: string status: string created_at: string } @@ -21,6 +22,7 @@ type ReviewSummary = { type ReviewDetail = ReviewSummary & { event_payload?: any classification?: string + reasoning?: string } export default function ReviewsPage() { @@ -93,7 +95,7 @@ export default function ReviewsPage() {

    Nothing waiting on you

    - When policy mode is HITL and a flagged verdict lands in scope, the SDK pauses and the event appears here. + When policy mode is HITL and a flagged verdict or fail-closed classifier error lands in scope, the SDK pauses and the event appears here.

    ) : ( @@ -109,7 +111,7 @@ export default function ReviewsPage() { }`} >
    - + {timeAgo(r.created_at)}
    @@ -126,7 +128,7 @@ export default function ReviewsPage() { ) : (
    - +
    - + {isClassifierErrorVerdict(detail) ? ( +
    +

    Classifier error

    +

    + The classifier did not return a MAD code. Approving resumes the paused SDK action; rejecting returns a blocked tool response. +

    + {detail.reasoning && ( +

    + {detail.reasoning} +

    + )} +
    + ) : ( + + )}

    Event payload

    diff --git a/frontend/app/(dashboard)/sessions/[session_id]/page.tsx b/frontend/app/(dashboard)/sessions/[session_id]/page.tsx index 806f519..0898bfc 100644 --- a/frontend/app/(dashboard)/sessions/[session_id]/page.tsx +++ b/frontend/app/(dashboard)/sessions/[session_id]/page.tsx @@ -5,12 +5,13 @@ import { useParams } from 'next/navigation' import { api } from '@/lib/api' import { Badge } from '@/components/badge' import { JsonBlock } from '@/components/json-block' -import { madBadgeColor, timeAgo } from '@/lib/utils' +import { verdictBadgeColor, verdictBadgeLabel, timeAgo } from '@/lib/utils' type Verdict = { id: string mad_code: string classification: string + verdict_status: string } type Entry = { @@ -92,7 +93,7 @@ export default function SessionTimelinePage() {
    {entry.verdict && ( - + )} {timeAgo(entry.created_at)}
    @@ -111,7 +112,7 @@ export default function SessionTimelinePage() {

    Verdict

    - +
    )} diff --git a/frontend/app/(dashboard)/settings/page.tsx b/frontend/app/(dashboard)/settings/page.tsx index 9f32b4d..5bacd5f 100644 --- a/frontend/app/(dashboard)/settings/page.tsx +++ b/frontend/app/(dashboard)/settings/page.tsx @@ -80,6 +80,7 @@ function PolicyTab() { policy_m2: !!policy.policy_m2, policy_m3: !!policy.policy_m3, policy_m4: !!policy.policy_m4, + fail_closed_on_classifier_error: !!policy.fail_closed_on_classifier_error, }), }) setStatus('saved') @@ -139,12 +140,27 @@ function PolicyTab() {
    )} +
    +

    + Classifier failure handling +

    + setPolicy({ ...policy, fail_closed_on_classifier_error: v })} + /> +

    + Older SDK versions ignore this flag. Update agents to the SDK version + shipped with this dashboard before relying on fail-closed enforcement. +

    +
    +
    NOTE - - Saving disconnects every active SDK session for this org so each one - reconnects with the new policy on its next event. Events already - in-flight at the moment of save are classified against the previous - policy - expect brief discrepancies for a few seconds after a change. + Mode changes apply when an SDK reconnects. The classifier-error + fail-closed flag is included on future verdicts, so BLOCK-mode timeout + decisions refresh after the next verdict snapshot reaches the SDK.
    diff --git a/frontend/lib/utils.ts b/frontend/lib/utils.ts index b538032..9d2f512 100644 --- a/frontend/lib/utils.ts +++ b/frontend/lib/utils.ts @@ -13,6 +13,13 @@ const PILL_NEUTRAL = 'bg-surface-raised text-ink-3 border-surface-border' const PILL_SOFT = 'bg-surface-overlay text-ink-2 border-surface-border' const PILL_MID = 'bg-ink-2 text-surface-raised border-ink-2' const PILL_STRONG = 'bg-ink text-surface-raised border-ink' +const PILL_ERROR = 'bg-danger/20 text-danger border-danger/40' + +export type VerdictDisplayInput = { + mad_code?: string + classification?: string + verdict_status?: string +} | null | undefined export function madBadgeColor(code: string): string { if (!code) return PILL_NEUTRAL @@ -22,6 +29,21 @@ export function madBadgeColor(code: string): string { return PILL_NEUTRAL } +export function isClassifierErrorVerdict(verdict: VerdictDisplayInput): boolean { + if (!verdict) return false + return verdict.verdict_status === 'error' || (verdict.classification === 'error' && !verdict.mad_code) +} + +export function verdictBadgeLabel(verdict: VerdictDisplayInput): string { + if (isClassifierErrorVerdict(verdict)) return 'Classifier error' + return verdict?.mad_code || 'No MAD code' +} + +export function verdictBadgeColor(verdict: VerdictDisplayInput): string { + if (isClassifierErrorVerdict(verdict)) return PILL_ERROR + return madBadgeColor(verdict?.mad_code || '') +} + export function classificationBadgeColor(cls: string): string { switch (cls) { case 'BLOCK': return PILL_STRONG diff --git a/sdk/python/adrian/__init__.py b/sdk/python/adrian/__init__.py index 03b7fd4..3fbeea5 100644 --- a/sdk/python/adrian/__init__.py +++ b/sdk/python/adrian/__init__.py @@ -193,10 +193,11 @@ def init( session_id: Session identifier. Falls back to ``ADRIAN_SESSION_ID``, then to a per-cwd persistent UUID. See :mod:`adrian.session_persistence`. - block_timeout: Max seconds to wait for a verdict in ``MODE_BLOCK`` - before fail-open. Ignored in ``MODE_ALERT`` (no wait) and - ``MODE_HITL`` (wait indefinitely). Falls back to - ``ADRIAN_BLOCK_TIMEOUT``. + block_timeout: Max seconds to wait for a verdict in ``MODE_BLOCK``. + Timeout handling follows the server policy's + ``fail_closed_on_classifier_error`` flag. Ignored in + ``MODE_ALERT`` (no wait) and ``MODE_HITL`` (wait indefinitely). + Falls back to ``ADRIAN_BLOCK_TIMEOUT``. on_event: Callback for every paired event. on_verdict: Callback for every verdict. on_block: Callback for BLOCK-tier verdicts (M3 / M4). Notification @@ -814,13 +815,17 @@ def _should_halt(verdict: pb.Verdict) -> bool: """Decide whether a verdict should halt tool execution. HITL resolutions override everything: ``continue_execution=False`` - means halt, ``True`` means continue. Otherwise the per-MAD policy - bool is the sole scope authority, if the verdict's tier is + means halt, ``True`` means continue. Classifier ERROR verdicts + follow ``fail_closed_on_classifier_error``. Otherwise the per-MAD + policy bool is the sole scope authority: if the verdict's tier is in-scope, halt; if not, continue. """ if verdict.HasField("hitl"): return not verdict.hitl.continue_execution + if verdict.status == pb.VERDICT_STATUS_ERROR: + return bool(verdict.policy.fail_closed_on_classifier_error) + mad_prefix = verdict.mad_code[:2] in_scope = { "M0": verdict.policy.policy_m0, @@ -861,7 +866,7 @@ def _patch_tool_node() -> None: In block mode, the async patch waits for the preceding LLM's verdict before executing tools. On BLOCK (unless overridden by ``on_block``) it returns synthetic ``ToolMessage`` responses instead of running the - tools. On timeout it fails open. + tools. On timeout it follows ``fail_closed_on_classifier_error``. """ try: from langgraph.prebuilt import ToolNode @@ -937,12 +942,25 @@ async def patched_ainvoke( # producing event_id to wait on, so let the tool run. return await original_ainvoke(self, input, config=config, **kwargs) + if tool_call_id not in ws._tool_call_id_to_event_id: # pyright: ignore[reportPrivateUsage] + # Unknown / evicted correlation: there is no producing LLM + # event to gate, so this remains fail-open even when + # classifier-error fail-closed is enabled. + return await original_ainvoke(self, input, config=config, **kwargs) + cfg = _get_config() timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) if verdict is None: + if ws._mode == pb.MODE_BLOCK and ws.fail_closed_on_classifier_error(): # pyright: ignore[reportPrivateUsage] + logger.warning( + "verdict timeout for tool_call_id=%s, fail-closed", + tool_call_id, + ) + return _build_blocked_response(tool_calls) + logger.warning( "verdict timeout for tool_call_id=%s, fail-open", tool_call_id, diff --git a/sdk/python/adrian/config.py b/sdk/python/adrian/config.py index 238b54c..d4a0895 100644 --- a/sdk/python/adrian/config.py +++ b/sdk/python/adrian/config.py @@ -18,7 +18,8 @@ """Callback invoked for every verdict received. Accepts a ``VerdictContext`` with full event metadata. May be sync or -async. Fires for every MAD code the server forwards (M0 / M2 / M3 / M4). +async. Fires for every forwarded verdict, including classifier-error +verdicts whose ``mad_code`` is empty. """ type OnBlockCallback = ( @@ -97,8 +98,9 @@ class AdrianConfig: ws_url: WebSocket URL for the Adrian server. ``None`` disables the WebSocket handler. block_timeout: Max seconds to wait for a verdict in ``MODE_BLOCK`` - before fail-open. Ignored in ``MODE_ALERT`` (no wait) and - ``MODE_HITL`` (wait indefinitely). + before applying the server's classifier-error timeout policy. + Ignored in ``MODE_ALERT`` (no wait) and ``MODE_HITL`` (wait + indefinitely). on_event: Callback for every paired event. on_verdict: Callback for every verdict. on_block: Callback for BLOCK-tier verdicts (M3 / M4). diff --git a/sdk/python/adrian/handler.py b/sdk/python/adrian/handler.py index 3289041..99aa7be 100644 --- a/sdk/python/adrian/handler.py +++ b/sdk/python/adrian/handler.py @@ -141,6 +141,7 @@ async def handle_verdict(self, verdict: pb.Verdict) -> None: run_id=record.run_id, parent_run_id=record.parent_run_id, policy=verdict.policy, + status=verdict.status, mad_code=verdict.mad_code, hitl=hitl, ) diff --git a/sdk/python/adrian/types.py b/sdk/python/adrian/types.py index 13c1506..082b1c4 100644 --- a/sdk/python/adrian/types.py +++ b/sdk/python/adrian/types.py @@ -219,9 +219,11 @@ class VerdictContext: event_data: Original event payload TypedDict. run_id: LangChain run ID. parent_run_id: Parent run ID if nested, or ``None``. - mad_code: MAD policy code the classifier returned + status: Classifier result status. ``VERDICT_STATUS_ERROR`` means + the classifier did not produce a MAD code; ``mad_code`` is empty. + mad_code: MAD policy code the classifier returned on OK verdicts (e.g. ``"M0"``, ``"M2_C"``, ``"M4_a"``). Empty string - when no code is set (benign). + means no classifier-produced MAD code exists. policy: Org's effective execution-mode policy at the moment this verdict was decided. Carries the mode (alert / block / hitl) and per-MAD-code scope booleans. @@ -238,6 +240,7 @@ class VerdictContext: run_id: str parent_run_id: str | None policy: pb.PolicySnapshot + status: int = pb.VERDICT_STATUS_UNSPECIFIED mad_code: str = "" hitl: pb.HitlResponse | None = None diff --git a/sdk/python/adrian/ws.py b/sdk/python/adrian/ws.py index 1ab5df4..e1948f9 100644 --- a/sdk/python/adrian/ws.py +++ b/sdk/python/adrian/ws.py @@ -226,9 +226,10 @@ def __init__( self._model = "" # Server-supplied execution-mode policy. Populated when the # first ServerFrame{login_ack} arrives after each (re)connect. - # ``policy_active()`` and ``block_timeout()`` read this state - # to decide whether the patched ToolNode should wait for a - # verdict and how long. + # ``policy_active()``, ``block_timeout()``, and + # ``fail_closed_on_classifier_error()`` read this state to + # decide whether the patched ToolNode should wait for a verdict + # and how to handle classifier failures/timeouts. self._mode: int = pb.MODE_UNSPECIFIED self._policy: pb.PolicySnapshot | None = None # Set the first time a ``ServerFrame{login_ack}`` is applied. @@ -259,7 +260,8 @@ def __init__( # with the matching ``Verdict`` proto. Futures survive a # disconnect: a late verdict after reconnect still resolves # the wait; if none arrives, ``wait_for_verdict``'s timeout - # produces a natural fail-open in BLOCK mode. + # returns None and the patched ToolNode applies the current + # fail-open/fail-closed classifier-error policy. self._pending_verdicts: dict[str, asyncio.Future[pb.Verdict]] = {} # Maps LLM pair run_id → event_id so a subsequent tool call can # look up the verdict by its parent_run_id (the LLM's run_id). @@ -319,8 +321,8 @@ def policy_active(self) -> bool: def block_timeout(self, kwarg_default: float) -> float | None: """Effective per-tool-call wait timeout for the active mode. - - ``MODE_BLOCK``: ``kwarg_default`` (typically 30s), fail-open - if the server doesn't classify in time. + - ``MODE_BLOCK``: ``kwarg_default`` (typically 30s). Timeout + handling follows ``fail_closed_on_classifier_error``. - ``MODE_HITL``: ``None``, wait indefinitely for human review. - ``MODE_ALERT`` / unset: ``0``, caller short-circuits before registering a future. @@ -332,6 +334,13 @@ def block_timeout(self, kwarg_default: float) -> float | None: else: return 0 + def fail_closed_on_classifier_error(self) -> bool: + """Whether classifier errors/timeouts should halt tool execution.""" + return bool( + self._policy is not None + and self._policy.fail_closed_on_classifier_error + ) + # -- EventHandler protocol -- async def on_paired_event(self, event: PairedEvent) -> None: @@ -687,12 +696,13 @@ def _on_login_ack(self, ack: pb.LoginAck) -> None: self._login_ack_received.set() logger.info( "LoginAck received: mode=%s policy_m0=%s policy_m2=%s " - "policy_m3=%s policy_m4=%s", + "policy_m3=%s policy_m4=%s fail_closed_on_classifier_error=%s", pb.Mode.Name(ack.policy.mode), ack.policy.policy_m0, ack.policy.policy_m2, ack.policy.policy_m3, ack.policy.policy_m4, + ack.policy.fail_closed_on_classifier_error, ) if self._on_login_ack_cb is not None: @@ -717,10 +727,17 @@ async def _on_verdict_frame(self, verdict: pb.Verdict) -> None: owns the cleanup: its ``finally`` pops the entry after the await returns. """ + if verdict.HasField("policy"): + # Keep the policy snapshot fresh for BLOCK-mode timeout + # decisions. Execution mode remains login-fixed for this + # release; hot-switching mode mid-session is out of scope. + self._policy = verdict.policy + logger.info( - "Verdict received: event_id=%s mad_code=%s mode=%s hitl=%s", + "Verdict received: event_id=%s mad_code=%s status=%s mode=%s hitl=%s", verdict.event_id, verdict.mad_code or "-", + pb.VerdictStatus.Name(verdict.status), pb.Mode.Name(verdict.policy.mode), verdict.HasField("hitl"), ) @@ -935,9 +952,9 @@ async def wait_for_verdict( """Wait for a verdict for ``event_id``. ``timeout`` is mode-derived (see :meth:`block_timeout`): - a positive float for ``MODE_BLOCK`` (fail-open at timeout), + a positive float for ``MODE_BLOCK`` (caller applies policy at timeout), ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the - verdict, or ``None`` on timeout (fail-open). + verdict, or ``None`` on timeout. Cleans up the ``_pending_verdicts`` entry on either path: ``_on_verdict_frame`` only resolves the future, the dict diff --git a/sdk/python/tests/test_block_mode.py b/sdk/python/tests/test_block_mode.py index 0d1c352..1234a09 100644 --- a/sdk/python/tests/test_block_mode.py +++ b/sdk/python/tests/test_block_mode.py @@ -35,6 +35,7 @@ def _apply_mode( policy_m2: bool = False, policy_m3: bool = False, policy_m4: bool = False, + fail_closed_on_classifier_error: bool = False, ) -> pb.PolicySnapshot: """Drive the mode/policy state as if a LoginAck had arrived.""" policy = pb.PolicySnapshot( @@ -43,6 +44,7 @@ def _apply_mode( policy_m2=policy_m2, policy_m3=policy_m3, policy_m4=policy_m4, + fail_closed_on_classifier_error=fail_closed_on_classifier_error, ) ws._mode = mode ws._policy = policy @@ -261,6 +263,174 @@ def _real_tool(x: str) -> str: assert captured == ["hi"] + async def test_timeout_fail_closed_blocks_tool(self, tmp_path: Path) -> None: + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Real tool stub for block-mode tests.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.05, + ) + + ws = adrian._ws_client + assert ws is not None + _apply_mode( + ws, + pb.MODE_BLOCK, + policy_m4=True, + fail_closed_on_classifier_error=True, + ) + ws._connected.set() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == [] + assert "BLOCKED" in result["messages"][0].content + + async def test_error_verdict_fail_open_runs_tool(self, tmp_path: Path) -> None: + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Real tool stub for block-mode tests.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + policy = _apply_mode(ws, pb.MODE_BLOCK, fail_closed_on_classifier_error=False) + ws._connected.set() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + + fut = ws.register_pending("llm-evt") + fut.set_result( + pb.Verdict( + event_id="llm-evt", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=policy, + ), + ) + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == ["hi"] + + async def test_error_verdict_fail_closed_blocks_tool( + self, + tmp_path: Path, + ) -> None: + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Real tool stub for block-mode tests.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + policy = _apply_mode(ws, pb.MODE_BLOCK, fail_closed_on_classifier_error=True) + ws._connected.set() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + + fut = ws.register_pending("llm-evt") + fut.set_result( + pb.Verdict( + event_id="llm-evt", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=policy, + ), + ) + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == [] + assert "BLOCKED" in result["messages"][0].content + + async def test_unknown_tool_call_stays_fail_open_when_fail_closed( + self, + tmp_path: Path, + ) -> None: + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Real tool stub for block-mode tests.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.05, + ) + + ws = adrian._ws_client + assert ws is not None + _apply_mode(ws, pb.MODE_BLOCK, fail_closed_on_classifier_error=True) + ws._connected.set() + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == ["hi"] + class TestModeAlert: async def test_alert_mode_skips_wait(self, tmp_path: Path) -> None: diff --git a/sdk/python/tests/test_exec_modes.py b/sdk/python/tests/test_exec_modes.py index 1ea8ae1..ff95700 100644 --- a/sdk/python/tests/test_exec_modes.py +++ b/sdk/python/tests/test_exec_modes.py @@ -168,6 +168,62 @@ async def test_out_of_scope_continues_without_hitl( assert captured == ["hi"] + async def test_error_review_approve_continues(self, tmp_path: Path) -> None: + """ERROR verdict + HITL approve still resumes the tool.""" + captured: list[str] = [] + ws = _init_with_ws(tmp_path) + policy = pb.PolicySnapshot( + mode=pb.MODE_HITL, + fail_closed_on_classifier_error=True, + ) + _apply_mode(ws, policy) + ws._tool_call_id_to_event_id["tc-1"] = "evt-1" + + verdict = pb.Verdict( + event_id="evt-1", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=policy, + ) + verdict.hitl.continue_execution = True + ws.register_pending("evt-1").set_result(verdict) + + result = await ToolNode([_stub_tool(captured)]).ainvoke( # pyright: ignore[reportUnknownMemberType] + _ainvoke_state(), + config=_runtime_config(), + ) + + assert captured == ["hi"] + assert "BLOCKED" not in result["messages"][0].content + + async def test_error_review_reject_halts(self, tmp_path: Path) -> None: + """ERROR verdict + HITL reject blocks the tool.""" + captured: list[str] = [] + ws = _init_with_ws(tmp_path) + policy = pb.PolicySnapshot( + mode=pb.MODE_HITL, + fail_closed_on_classifier_error=True, + ) + _apply_mode(ws, policy) + ws._tool_call_id_to_event_id["tc-1"] = "evt-1" + + verdict = pb.Verdict( + event_id="evt-1", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=policy, + ) + verdict.hitl.continue_execution = False + ws.register_pending("evt-1").set_result(verdict) + + result = await ToolNode([_stub_tool(captured)]).ainvoke( # pyright: ignore[reportUnknownMemberType] + _ainvoke_state(), + config=_runtime_config(), + ) + + assert captured == [] + assert "BLOCKED" in result["messages"][0].content + # ------------------------------------------------------------------ # Stray HITL resolution + protocol error diff --git a/sdk/python/tests/test_handler.py b/sdk/python/tests/test_handler.py index cc421ec..b5fa09e 100644 --- a/sdk/python/tests/test_handler.py +++ b/sdk/python/tests/test_handler.py @@ -12,6 +12,8 @@ from adrian.handler import AdrianCallbackHandler, extract_model_name from adrian.hooks import HookRegistry from adrian.pairing import EventPairBuffer +from adrian.proto import event_pb2 as pb +from adrian.types import EventRecord, VerdictContext from langchain_core.messages import ( AIMessage, BaseMessage, # noqa: TC002 @@ -230,3 +232,47 @@ def test_kwargs_model_name(self) -> None: def test_empty_dict(self) -> None: assert extract_model_name({}) == "unknown" + + +class TestVerdictCallbacks: + async def test_error_verdict_populates_status_without_mad_callbacks(self) -> None: + seen: list[VerdictContext] = [] + blocked: list[VerdictContext] = [] + audited: list[VerdictContext] = [] + + handler = AdrianCallbackHandler( + pair_buffer=EventPairBuffer(), + context_tracker=AgentContextTracker(), + hooks=HookRegistry(), + config=AdrianConfig( + on_verdict=seen.append, + on_block=blocked.append, + on_audit=audited.append, + ), + ) + handler._event_map["evt-error"] = EventRecord( # pyright: ignore[reportPrivateUsage] + event_type="llm", + data={ + "output": "tool call", + "tool_calls": [], + "usage": None, + }, + run_id="run-1", + parent_run_id=None, + ) + + await handler.handle_verdict( + pb.Verdict( + event_id="evt-error", + session_id="sess-1", + status=pb.VERDICT_STATUS_ERROR, + mad_code="", + policy=pb.PolicySnapshot(fail_closed_on_classifier_error=True), + ), + ) + + assert len(seen) == 1 + assert seen[0].status == pb.VERDICT_STATUS_ERROR + assert seen[0].mad_code == "" + assert blocked == [] + assert audited == [] diff --git a/sdk/python/tests/test_ws.py b/sdk/python/tests/test_ws.py index c28762e..6e79ac7 100644 --- a/sdk/python/tests/test_ws.py +++ b/sdk/python/tests/test_ws.py @@ -299,6 +299,26 @@ async def __anext__(self) -> bytes: assert resolved.event_id == "evt-1" assert resolved.mad_code == "M4_a" + async def test_verdict_frame_refreshes_policy_without_switching_mode(self) -> None: + client = WebSocketClient("ws://x", "s", api_key="k") + client._mode = pb.MODE_ALERT + client._policy = pb.PolicySnapshot( + mode=pb.MODE_ALERT, + fail_closed_on_classifier_error=False, + ) + + verdict = pb.Verdict( + event_id="evt-1", + policy=pb.PolicySnapshot( + mode=pb.MODE_BLOCK, + fail_closed_on_classifier_error=True, + ), + ) + await client._on_verdict_frame(verdict) + + assert client._mode == pb.MODE_ALERT + assert client.fail_closed_on_classifier_error() is True + # ------------------------------------------------------------------ # Block-mode primitives From 974dfcd2f87d0a75ee383890fe8a5cbd11ff074a Mon Sep 17 00:00:00 2001 From: Muhammad-usman92 Date: Mon, 15 Jun 2026 21:00:57 +0500 Subject: [PATCH 4/4] Make migration 002 recovery safe --- backend/internal/db/migrate.go | 158 ++++++++++++++++++++++++ backend/internal/db/migrate_test.go | 184 ++++++++++++++++++++++++++++ scripts/setup.py | 135 ++++++++++++++++++-- 3 files changed, 469 insertions(+), 8 deletions(-) diff --git a/backend/internal/db/migrate.go b/backend/internal/db/migrate.go index 0c03576..aa251a9 100644 --- a/backend/internal/db/migrate.go +++ b/backend/internal/db/migrate.go @@ -54,6 +54,16 @@ func applyMigrations(db *sql.DB, fsys fs.FS) ([]string, error) { if alreadyApplied { continue } + reconciled, appliedRecovery, err := reconcileMigration002(db, name) + if err != nil { + return nil, err + } + if reconciled { + if appliedRecovery { + applied = append(applied, name) + } + continue + } body, err := fs.ReadFile(fsys, name) if err != nil { @@ -103,3 +113,151 @@ func migrationApplied(db *sql.DB, name string) (bool, error) { } return false, fmt.Errorf("lookup migration %s: %w", name, err) } + +func reconcileMigration002(db *sql.DB, name string) (bool, bool, error) { + if name != "002_verdict_status_policy.sql" { + return false, false, nil + } + + hasPolicyColumn, err := tableHasColumn(db, "policies", "fail_closed_on_classifier_error") + if err != nil { + return false, false, err + } + hasVerdictStatus, err := tableHasColumn(db, "verdicts", "verdict_status") + if err != nil { + return false, false, err + } + allowsErrorClassification, err := tableSQLContains(db, "verdicts", "'error'") + if err != nil { + return false, false, err + } + + if hasPolicyColumn && hasVerdictStatus && allowsErrorClassification { + if _, err := db.Exec(`INSERT OR IGNORE INTO schema_migrations (name) VALUES (?)`, name); err != nil { + return false, false, fmt.Errorf("record recovered %s: %w", name, err) + } + return true, false, nil + } + + if hasPolicyColumn { + if _, err := db.Exec(migration002VerdictsRecoverySQL); err != nil { + _, _ = db.Exec("ROLLBACK") + _, _ = db.Exec("PRAGMA foreign_keys=ON") + return false, false, fmt.Errorf("recover %s verdicts: %w", name, err) + } + if _, err := db.Exec(`INSERT INTO schema_migrations (name) VALUES (?)`, name); err != nil { + return false, false, fmt.Errorf("record recovered %s: %w", name, err) + } + return true, true, nil + } + + if hasVerdictStatus && allowsErrorClassification { + if _, err := db.Exec(migration002PolicyColumnSQL); err != nil { + return false, false, fmt.Errorf("recover %s policy column: %w", name, err) + } + if _, err := db.Exec(`INSERT INTO schema_migrations (name) VALUES (?)`, name); err != nil { + return false, false, fmt.Errorf("record recovered %s: %w", name, err) + } + return true, true, nil + } + return false, false, nil +} + +func tableHasColumn(db *sql.DB, table, column string) (bool, error) { + rows, err := db.Query(`SELECT name FROM pragma_table_info(?)`, table) + if err != nil { + return false, fmt.Errorf("inspect %s columns: %w", table, err) + } + defer rows.Close() + + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return false, err + } + if name == column { + return true, nil + } + } + return false, rows.Err() +} + +func tableSQLContains(db *sql.DB, table, needle string) (bool, error) { + var sqlText string + err := db.QueryRow(`SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?`, table).Scan(&sqlText) + if err == sql.ErrNoRows { + return false, nil + } + if err != nil { + return false, fmt.Errorf("inspect %s schema: %w", table, err) + } + return strings.Contains(sqlText, needle), nil +} + +const migration002PolicyColumnSQL = ` +ALTER TABLE policies + ADD COLUMN fail_closed_on_classifier_error INTEGER NOT NULL DEFAULT 0 + CHECK (fail_closed_on_classifier_error IN (0,1)); +` + +const migration002VerdictsRecoverySQL = ` +PRAGMA foreign_keys=OFF; + +BEGIN; + +DROP TABLE IF EXISTS verdicts_new; + +CREATE TABLE verdicts_new ( + id TEXT PRIMARY KEY, + event_id TEXT NOT NULL REFERENCES events(id) ON DELETE CASCADE, + session_id TEXT NOT NULL, + agent_profile_id TEXT REFERENCES agent_profiles(id) ON DELETE SET NULL, + mad_code TEXT NOT NULL, + classification TEXT NOT NULL CHECK (classification IN ('benign','notify','block','error')), + verdict_status TEXT NOT NULL DEFAULT 'ok' + CHECK (verdict_status IN ('ok','error')), + reasoning TEXT, + latency_ms INTEGER, + tokens_used INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) +); + +INSERT INTO verdicts_new ( + id, + event_id, + session_id, + agent_profile_id, + mad_code, + classification, + verdict_status, + reasoning, + latency_ms, + tokens_used, + created_at +) +SELECT + id, + event_id, + session_id, + agent_profile_id, + mad_code, + classification, + 'ok', + reasoning, + latency_ms, + tokens_used, + created_at +FROM verdicts; + +DROP TABLE verdicts; +ALTER TABLE verdicts_new RENAME TO verdicts; + +CREATE INDEX IF NOT EXISTS idx_verdicts_event_id ON verdicts(event_id); +CREATE INDEX IF NOT EXISTS idx_verdicts_session_id ON verdicts(session_id); +CREATE INDEX IF NOT EXISTS idx_verdicts_created_at ON verdicts(created_at); + +COMMIT; + +PRAGMA foreign_key_check; +PRAGMA foreign_keys=ON; +` diff --git a/backend/internal/db/migrate_test.go b/backend/internal/db/migrate_test.go index b257698..4cbec9f 100644 --- a/backend/internal/db/migrate_test.go +++ b/backend/internal/db/migrate_test.go @@ -8,6 +8,8 @@ import ( "testing" "testing/fstest" + "github.com/secureagentics/Adrian/backend/migrations" + _ "modernc.org/sqlite" ) @@ -112,6 +114,100 @@ COMMIT;`), } } +func TestEmbeddedMigration002UpgradesPopulatedPre002DB(t *testing.T) { + conn := openTestDB(t) + defer conn.Close() + + applyEmbedded001Only(t, conn) + seedPre002VerdictAndReview(t, conn) + + applied, err := applyMigrations(conn, migrations.Files) + if err != nil { + t.Fatalf("apply embedded migrations: %v", err) + } + if got, want := applied, []string{"002_verdict_status_policy.sql"}; len(got) != len(want) || got[0] != want[0] { + t.Fatalf("applied migrations = %v, want %v", got, want) + } + + assert002SchemaAndData(t, conn) + + applied, err = applyMigrations(conn, migrations.Files) + if err != nil { + t.Fatalf("second embedded apply: %v", err) + } + if len(applied) != 0 { + t.Fatalf("second embedded apply = %v, want no migrations", applied) + } +} + +func TestEmbeddedMigration002RecordsCompletedSchemaAfterCrashBeforeLedger(t *testing.T) { + conn := openTestDB(t) + defer conn.Close() + + applyEmbedded001Only(t, conn) + seedPre002VerdictAndReview(t, conn) + + body, err := migrations.Files.ReadFile("002_verdict_status_policy.sql") + if err != nil { + t.Fatalf("read embedded 002 migration: %v", err) + } + if _, err := conn.Exec(string(body)); err != nil { + t.Fatalf("simulate completed 002 migration without ledger record: %v", err) + } + if migrationWasRecorded(t, conn, "002_verdict_status_policy.sql") { + t.Fatal("test setup unexpectedly recorded 002 migration") + } + + applied, err := applyMigrations(conn, migrations.Files) + if err != nil { + t.Fatalf("recover after completed 002 without ledger: %v", err) + } + if len(applied) != 0 { + t.Fatalf("recovery applied migrations = %v, want none", applied) + } + if !migrationWasRecorded(t, conn, "002_verdict_status_policy.sql") { + t.Fatal("recovery did not record completed 002 migration") + } + + assert002SchemaAndData(t, conn) +} + +func TestEmbeddedMigration002RecoversAfterPolicyColumnAddedBeforeLedger(t *testing.T) { + conn := openTestDB(t) + defer conn.Close() + + applyEmbedded001Only(t, conn) + seedPre002VerdictAndReview(t, conn) + + if _, err := conn.Exec(migration002PolicyColumnSQL); err != nil { + t.Fatalf("simulate partial 002 policy-column migration: %v", err) + } + if migrationWasRecorded(t, conn, "002_verdict_status_policy.sql") { + t.Fatal("test setup unexpectedly recorded 002 migration") + } + + applied, err := applyMigrations(conn, migrations.Files) + if err != nil { + t.Fatalf("recover after partial 002 without ledger: %v", err) + } + if got, want := applied, []string{"002_verdict_status_policy.sql"}; len(got) != len(want) || got[0] != want[0] { + t.Fatalf("recovery applied migrations = %v, want %v", got, want) + } + if !migrationWasRecorded(t, conn, "002_verdict_status_policy.sql") { + t.Fatal("recovery did not record completed 002 migration") + } + + assert002SchemaAndData(t, conn) + + applied, err = applyMigrations(conn, migrations.Files) + if err != nil { + t.Fatalf("second embedded apply after recovery: %v", err) + } + if len(applied) != 0 { + t.Fatalf("second embedded apply after recovery = %v, want no migrations", applied) + } +} + func openTestDB(t *testing.T) *sql.DB { t.Helper() conn, err := sql.Open("sqlite", "file:migratetest?mode=memory&cache=shared") @@ -121,6 +217,94 @@ func openTestDB(t *testing.T) *sql.DB { return conn } +func applyEmbedded001Only(t *testing.T, conn *sql.DB) { + t.Helper() + initialSQL, err := migrations.Files.ReadFile("001_initial_schema.sql") + if err != nil { + t.Fatalf("read embedded 001 migration: %v", err) + } + applied, err := applyMigrations(conn, fstest.MapFS{ + "001_initial_schema.sql": {Data: initialSQL}, + }) + if err != nil { + t.Fatalf("apply embedded 001 migration: %v", err) + } + if got, want := applied, []string{"001_initial_schema.sql"}; len(got) != len(want) || got[0] != want[0] { + t.Fatalf("applied initial migrations = %v, want %v", got, want) + } +} + +func seedPre002VerdictAndReview(t *testing.T, conn *sql.DB) { + t.Helper() + if _, err := conn.Exec(` +INSERT INTO events (id, session_id, event_type, payload) +VALUES ('evt-populated', 'sess-populated', 'llm', '{}'); +INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, reasoning) +VALUES ('verdict-populated', 'evt-populated', 'sess-populated', 'M4_a', 'block', 'seed'); +INSERT INTO hitl_queue (id, event_id, verdict_id, session_id, mad_code) +VALUES ('review-populated', 'evt-populated', 'verdict-populated', 'sess-populated', 'M4_a'); +`); err != nil { + t.Fatalf("seed populated pre-002 database: %v", err) + } +} + +func assert002SchemaAndData(t *testing.T, conn *sql.DB) { + t.Helper() + + var failClosed int + if err := conn.QueryRow(`SELECT fail_closed_on_classifier_error FROM policies WHERE id = 1`).Scan(&failClosed); err != nil { + t.Fatalf("query policy flag: %v", err) + } + if failClosed != 0 { + t.Fatalf("fail_closed_on_classifier_error = %d, want 0", failClosed) + } + + var madCode, classification, verdictStatus string + if err := conn.QueryRow(` +SELECT mad_code, classification, verdict_status +FROM verdicts WHERE id = 'verdict-populated' +`).Scan(&madCode, &classification, &verdictStatus); err != nil { + t.Fatalf("query upgraded verdict: %v", err) + } + if madCode != "M4_a" || classification != "block" || verdictStatus != "ok" { + t.Fatalf("upgraded verdict = (%q, %q, %q), want (M4_a, block, ok)", madCode, classification, verdictStatus) + } + + if _, err := conn.Exec(` +INSERT INTO verdicts (id, event_id, session_id, mad_code, classification, verdict_status, reasoning) +VALUES ('verdict-error', 'evt-populated', 'sess-populated', '', 'error', 'error', 'classifier failure: test'); +`); err != nil { + t.Fatalf("insert classifier-error verdict after upgrade: %v", err) + } + + var reviewVerdictID string + if err := conn.QueryRow(`SELECT verdict_id FROM hitl_queue WHERE id = 'review-populated'`).Scan(&reviewVerdictID); err != nil { + t.Fatalf("query preserved hitl_queue row: %v", err) + } + if reviewVerdictID != "verdict-populated" { + t.Fatalf("preserved hitl_queue verdict_id = %q, want verdict-populated", reviewVerdictID) + } + + for _, name := range []string{"idx_verdicts_event_id", "idx_verdicts_session_id", "idx_verdicts_created_at"} { + var seen int + if err := conn.QueryRow(`SELECT count(*) FROM sqlite_master WHERE type = 'index' AND name = ?`, name).Scan(&seen); err != nil { + t.Fatalf("query index %s: %v", name, err) + } + if seen != 1 { + t.Fatalf("index %s count = %d, want 1", name, seen) + } + } + + rows, err := conn.Query(`PRAGMA foreign_key_check`) + if err != nil { + t.Fatalf("foreign_key_check: %v", err) + } + defer rows.Close() + if rows.Next() { + t.Fatal("foreign_key_check returned violations after 002 migration") + } +} + func migrationWasRecorded(t *testing.T, conn *sql.DB, name string) bool { t.Helper() var seen int diff --git a/scripts/setup.py b/scripts/setup.py index 86bf47e..3d2d2a5 100644 --- a/scripts/setup.py +++ b/scripts/setup.py @@ -132,6 +132,11 @@ def apply_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> list[str for path in sql_files: if migration_applied(conn, path.name): continue + reconciled, applied_recovery = reconcile_migration_002(conn, path.name) + if reconciled: + if applied_recovery: + applied.append(path.name) + continue sql = path.read_text(encoding="utf-8") if NO_TRANSACTION_MARKER in sql: try: @@ -143,15 +148,10 @@ def apply_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> list[str conn.execute("PRAGMA foreign_keys=ON") raise else: - quoted_name = path.name.replace("'", "''") try: - conn.executescript( - "BEGIN;\n" - f"{sql}\n" - "INSERT INTO schema_migrations (name) " - f"VALUES ('{quoted_name}');\n" - "COMMIT;\n", - ) + conn.executescript("BEGIN;\n" + sql + "\n") + conn.execute("INSERT INTO schema_migrations (name) VALUES (?)", (path.name,)) + conn.commit() except sqlite3.Error: conn.rollback() raise @@ -168,6 +168,125 @@ def migration_applied(conn: sqlite3.Connection, name: str) -> bool: return row is not None +def reconcile_migration_002(conn: sqlite3.Connection, name: str) -> tuple[bool, bool]: + """Recover 002 if it completed or stopped after adding the policy column.""" + if name != "002_verdict_status_policy.sql": + return False, False + + has_policy_column = table_has_column(conn, "policies", "fail_closed_on_classifier_error") + has_verdict_status = table_has_column(conn, "verdicts", "verdict_status") + allows_error_classification = table_sql_contains(conn, "verdicts", "'error'") + + if has_policy_column and has_verdict_status and allows_error_classification: + conn.execute("INSERT OR IGNORE INTO schema_migrations (name) VALUES (?)", (name,)) + conn.commit() + return True, False + + if has_policy_column: + try: + conn.executescript(MIGRATION_002_VERDICTS_RECOVERY_SQL) + conn.execute("INSERT INTO schema_migrations (name) VALUES (?)", (name,)) + conn.commit() + return True, True + except sqlite3.Error: + conn.rollback() + conn.execute("PRAGMA foreign_keys=ON") + raise + + if has_verdict_status and allows_error_classification: + conn.execute(MIGRATION_002_POLICY_COLUMN_SQL) + conn.execute("INSERT INTO schema_migrations (name) VALUES (?)", (name,)) + conn.commit() + return True, True + + return False, False + + +def table_has_column(conn: sqlite3.Connection, table: str, column: str) -> bool: + return any( + row[0] == column + for row in conn.execute("SELECT name FROM pragma_table_info(?)", (table,)) + ) + + +def table_sql_contains(conn: sqlite3.Connection, table: str, needle: str) -> bool: + row = conn.execute( + "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", + (table,), + ).fetchone() + return row is not None and needle in row[0] + + +MIGRATION_002_POLICY_COLUMN_SQL = """ +ALTER TABLE policies + ADD COLUMN fail_closed_on_classifier_error INTEGER NOT NULL DEFAULT 0 + CHECK (fail_closed_on_classifier_error IN (0,1)) +""" + + +MIGRATION_002_VERDICTS_RECOVERY_SQL = """ +PRAGMA foreign_keys=OFF; + +BEGIN; + +DROP TABLE IF EXISTS verdicts_new; + +CREATE TABLE verdicts_new ( + id TEXT PRIMARY KEY, + event_id TEXT NOT NULL REFERENCES events(id) ON DELETE CASCADE, + session_id TEXT NOT NULL, + agent_profile_id TEXT REFERENCES agent_profiles(id) ON DELETE SET NULL, + mad_code TEXT NOT NULL, + classification TEXT NOT NULL CHECK (classification IN ('benign','notify','block','error')), + verdict_status TEXT NOT NULL DEFAULT 'ok' + CHECK (verdict_status IN ('ok','error')), + reasoning TEXT, + latency_ms INTEGER, + tokens_used INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) +); + +INSERT INTO verdicts_new ( + id, + event_id, + session_id, + agent_profile_id, + mad_code, + classification, + verdict_status, + reasoning, + latency_ms, + tokens_used, + created_at +) +SELECT + id, + event_id, + session_id, + agent_profile_id, + mad_code, + classification, + 'ok', + reasoning, + latency_ms, + tokens_used, + created_at +FROM verdicts; + +DROP TABLE verdicts; +ALTER TABLE verdicts_new RENAME TO verdicts; + +CREATE INDEX IF NOT EXISTS idx_verdicts_event_id ON verdicts(event_id); +CREATE INDEX IF NOT EXISTS idx_verdicts_session_id ON verdicts(session_id); +CREATE INDEX IF NOT EXISTS idx_verdicts_created_at ON verdicts(created_at); + +COMMIT; + +PRAGMA foreign_key_check; +PRAGMA foreign_keys=ON; +""" + + def read_env(env_path: Path) -> dict[str, str]: """Parse a `KEY=VALUE` env file. Comments and blanks ignored. Quoted values are stripped of surrounding double quotes."""