diff --git a/docs/adapters/frameworks-agentforce.md b/docs/adapters/frameworks-agentforce.md new file mode 100644 index 0000000..4b7aceb --- /dev/null +++ b/docs/adapters/frameworks-agentforce.md @@ -0,0 +1,267 @@ +# Salesforce Agentforce framework adapter + +`layerlens.instrument.adapters.frameworks.agentforce.AgentForceAdapter` +imports Salesforce Agentforce session traces from Data Cloud DMOs and +emits them as LayerLens canonical events. The adapter package also ships +companion modules for the Agent API REST surface, the Pub/Sub Platform +Events stream, the Einstein Trust Layer policy importer, and an +LLM evaluator that runs LayerLens graders against captured sessions. + +This adapter is **import-mode** rather than runtime monkey-patching: it +authenticates against a Salesforce org via OAuth 2.0 JWT Bearer and runs +SOQL queries against the AgentForce DMO objects to backfill trace data. +Salesforce Agentforce itself is a remote multi-tenant service, not a +Python library, so there is no framework SDK to instrument in-process. + +## Install + +```bash +pip install 'layerlens[agentforce]' +``` + +The `[agentforce]` extra pulls `requests>=2.28` (used by the JWT Bearer +flow, the SOQL HTTP transport, the Agent API REST client, and the CometD +Pub/Sub fallback). The Salesforce credentials must be provisioned +out-of-band (Connected App + private key + permitted user — see +[OAuth setup](#oauth-setup) below). + +## Quick start + +```python +from layerlens.instrument.adapters.frameworks.agentforce import ( + AgentForceAdapter, + SalesforceCredentials, +) +from layerlens.instrument.transport.sink_http import HttpEventSink + +credentials = SalesforceCredentials( + client_id="3MVG9...", + username="agent-importer@example.com", + private_key="env:SALESFORCE_PRIVATE_KEY", # or file path or raw PEM + instance_url="https://example.my.salesforce.com", +) + +sink = HttpEventSink(adapter_name="salesforce_agentforce") +adapter = AgentForceAdapter(credentials=credentials) +adapter.add_sink(sink) +adapter.connect() # JWT flow runs here + +result = adapter.import_sessions( + start_date="2026-04-01", + end_date="2026-04-25", + limit=100, +) +print( + f"Imported {result.events_generated} events " + f"from {result.sessions_imported} sessions" +) + +adapter.disconnect() +sink.close() +``` + +A fully runnable, mocked end-to-end sample lives in +[`samples/instrument/agentforce/`](../../samples/instrument/agentforce/). + +## What's wrapped + +This adapter does not monkey-patch anything in process. It calls SOQL +against the following Data Cloud DMO objects: + +| DMO object | Purpose | +|----------------------------------|------------------------------------------| +| `AIAgentSession` | Top-level session record | +| `AIAgentSessionParticipant` | Agents + users in the session | +| `AIAgentInteraction` | Turns within a session | +| `AIAgentInteractionStep` | Individual steps inside an interaction | +| `AIAgentInteractionMessage` | Raw input / output messages | + +Each row is normalized via `AgentForceNormalizer` and emitted through +the adapter's `emit_dict_event` pipeline (which honors the +`CaptureConfig` filter and circuit-breaker state). + +Companion modules in the same package: + +| Module | What it does | +|---------------------------|-------------------------------------------------------| +| `auth.py` | OAuth 2.0 JWT Bearer flow + SOQL HTTP client | +| `client.py` | Agent API REST client (real-time session capture) | +| `events.py` | Platform Events subscriber (gRPC + CometD fallback) | +| `mapper.py` | Agent API session → LayerLens event mapper | +| `trust_layer.py` | Einstein Trust Layer policy import / YAML emission | +| `llm_eval.py` | `EinsteinEvaluator` — A/B prompt + model comparison | + +## Events emitted + +| Event | Layer | When | +|----------------------|--------|--------------------------------------------------------| +| `agent.lifecycle` | L1 | Per `AIAgentSession` start / end. | +| `agent.identity` | L1 | Per `AIAgentSessionParticipant`. | +| `agent.interaction` | L1 | Per `AIAgentInteraction`. | +| `agent.input` | L1 | Per `AIAgentInteractionMessage` with role=user. | +| `agent.output` | L1 | Per `AIAgentInteractionMessage` with role=agent. | +| `model.invoke` | L3 | Per `LLMExecutionStep` from `AIAgentInteractionStep`. | +| `tool.call` | L5a | Per `ActionInvocationStep` / `FunctionStep`. | +| `environment.config` | L4a | Per topic classification (Agent API path). | +| `agent.state.change` | L1 | Per Agent API session start / end (live mapper). | +| `policy.violation` | cross | Per Einstein Trust Layer policy hit. | +| `agent.handoff` | L4a | Per escalation (Agent API mapper). | + +Each emitted event from the importer path includes `_identity` (the +Salesforce record `Id`) and `_timestamp` (record `LastModifiedDate`) for +re-import idempotency. + +## OAuth setup + +The adapter authenticates with Salesforce via the +[OAuth 2.0 JWT Bearer flow][oauth-jwt]. This is the supported +server-to-server flow for backfill agents — no interactive user login +or refresh-token rotation is needed. + +[oauth-jwt]: https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_jwt_flow.htm&type=5 + +### 1. Create a Connected App in Salesforce + +In your Salesforce org: **Setup → App Manager → New Connected App**. +Configure: + +- **Connected App Name**: `LayerLens AgentForce Importer` +- **API (Enable OAuth Settings)**: ✅ +- **Use digital signatures**: ✅ — upload your public-key X.509 certificate +- **Selected OAuth Scopes**: + - `Manage user data via APIs (api)` + - `Perform requests at any time (refresh_token, offline_access)` + - `Access Agentforce Service APIs (agentforce_api)` (if available in + your edition; otherwise `api` is sufficient for SOQL DMO reads) +- **Require Secret for Web Server Flow**: ✅ +- **Callback URL**: any placeholder (e.g. `https://login.salesforce.com/`) + — JWT Bearer flow does not actually use this. + +Save and copy the **Consumer Key** — that's your `client_id`. + +### 2. Generate a key pair + +```bash +openssl req -x509 -nodes -newkey rsa:2048 \ + -keyout layerlens-agentforce.key \ + -out layerlens-agentforce.crt \ + -days 365 -subj "/CN=layerlens-agentforce" +``` + +Upload the `.crt` to the Connected App. Keep the `.key` secret. + +### 3. Pre-authorize the integration user + +**Setup → Connected Apps → Manage → Edit Policies**: + +- **Permitted Users**: `Admin approved users are pre-authorized` +- Add a profile or permission set that includes the integration user. + The integration user must have read access to the AgentForce DMOs + (`AIAgentSession*`). + +### 4. Configure the SDK + +Pass the credentials via `SalesforceCredentials`. The `private_key` +field accepts three forms: + +| Form | Example | +|-----------------------|--------------------------------------| +| `env:NAME` reference | `env:SF_PRIVATE_KEY_PEM` | +| Filesystem path | `/etc/secrets/layerlens-agentforce.key` | +| Inline PEM string | `-----BEGIN PRIVATE KEY-----\n...\n` | + +```python +from layerlens.instrument.adapters.frameworks.agentforce import ( + SalesforceCredentials, +) + +credentials = SalesforceCredentials( + client_id="3MVG9...", # Connected App Consumer Key + username="layerlens-agentforce@example.com", + private_key="env:SF_PRIVATE_KEY_PEM", + instance_url="https://example.my.salesforce.com", +) +``` + +The `SalesforceConnection.authenticate()` call constructs and signs the +JWT with `RS256` and exchanges it at +`https://${instance_url}/services/oauth2/token` for an access token. +Tokens are cached in-memory for ~1 hour and refreshed automatically. + +## Salesforce specifics + +- **Token lifetime**: ~2 hours, treated as 1 hour to leave room for + clock drift. The adapter re-authenticates automatically when the + cached token expires before the next operation. +- **Rate limits**: a warning is logged when the API daily limit + consumption passes 80%. Salesforce returns the consumption in the + `Sforce-Limit-Info` response header. +- **Incremental sync**: pass `last_import_timestamp` to + `import_sessions(...)` to fetch only records modified since a + watermark. +- **Batch size**: configurable via the `batch_size` constructor arg + (default 200; the SOQL `IN` clause maximum is 2000). +- **SOQL injection**: every parent ID interpolated into the `WHERE … IN + (…)` clause is validated against the `^[a-zA-Z0-9]{15}(?:[a-zA-Z0-9]{3})?$` + Salesforce ID regex before splicing. Date / timestamp parameters are + validated against ISO 8601 regexes. + +## Capture config + +```python +from layerlens.instrument.adapters._base import CaptureConfig + +# Recommended for compliance backfills. +adapter = AgentForceAdapter( + credentials=credentials, + capture_config=CaptureConfig.standard(), +) + +# Strip raw message bodies, keep only structural events. +adapter = AgentForceAdapter( + credentials=credentials, + capture_config=CaptureConfig( + l1_agent_io=True, + l4a_environment_config=True, + capture_content=False, + ), +) +``` + +## BYOK + +Salesforce manages its own model keys (Einstein Trust Layer abstracts +the provider). The adapter does not own model API keys. The Salesforce +credentials themselves are intended to live in atlas-app's +`byok_credentials` table once M1.B ships — see `docs/adapters/byok.md`. + +## Trust Layer round-trip + +`TrustLayerImporter` exports the org's Einstein Trust Layer policy as +LayerLens YAML so the same guardrails can be re-evaluated outside the +Salesforce control plane: + +```python +from layerlens.instrument.adapters.frameworks.agentforce import ( + SalesforceConnection, + TrustLayerImporter, +) + +connection = SalesforceConnection(credentials=credentials) +connection.authenticate() +config, yaml_str = TrustLayerImporter(connection).import_and_convert( + policy_name="agentforce_trust_layer", +) +print(yaml_str) +``` + +The legacy alias `to_stratix_policy(...)` is retained for compatibility +with the original `stratix.*` adapter package and emits a +`DeprecationWarning`; new code should call `to_layerlens_policy(...)` +directly. + +## Replay + +`adapter.serialize_for_replay()` returns a `ReplayableTrace` with all +events captured during the current `import_sessions` call. Replay is a +re-emit operation: the adapter does not re-query Salesforce. diff --git a/pyproject.toml b/pyproject.toml index ae6d1dc..9440440 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,17 @@ classifiers = [ [project.optional-dependencies] cli = ["click>=8.0.0"] +# --- Instrument layer: framework adapters --- +# Adding any extra below MUST keep the default `pip install layerlens` +# install set unchanged. The Salesforce Agentforce adapter is import-mode +# only (it talks to a remote REST surface, not an in-process Python SDK) +# so the extra resolves to the HTTP transport plus the JWT signing +# library used by the OAuth 2.0 JWT Bearer flow. +agentforce = [ + "requests>=2.28", + "PyJWT[crypto]>=2.8", +] + [project.urls] Homepage = "https://github.com/LayerLens/stratix-python" Repository = "https://github.com/LayerLens/stratix-python" diff --git a/samples/instrument/agentforce/README.md b/samples/instrument/agentforce/README.md new file mode 100644 index 0000000..eb2a6a2 --- /dev/null +++ b/samples/instrument/agentforce/README.md @@ -0,0 +1,58 @@ +# Salesforce Agentforce sample + +Runnable end-to-end sample for the +`layerlens.instrument.adapters.frameworks.agentforce` adapter. + +The sample is **fully mocked** — it makes no network calls to either +Salesforce or LayerLens. It exists to demonstrate the API surface and act +as a smoke test that the `[agentforce]` extra installs cleanly. + +## Install + +```bash +pip install 'layerlens[agentforce]' +``` + +The `[agentforce]` extra pulls in `requests>=2.28` (the JWT Bearer flow +and SOQL HTTP transport). + +## Run + +```bash +python -m samples.instrument.agentforce.main +``` + +You should see four labeled flows print to stdout: + +* `[backfill]` — SOQL session backfill via the Data Cloud DMO importer. +* `[live]` — Synchronous Agent API request / response capture. +* `[trust-layer]` — Einstein Trust Layer export to LayerLens YAML policy. +* `[evaluator]` — Einstein evaluator offline behavior (logs the + zero-score fallback when no LayerLens API key is configured). + +The sample exits 0 on success. + +## Live Salesforce auth (optional) + +If you have a Salesforce Connected App with the JWT Bearer flow +configured, set these environment variables before running and the +sample will additionally exercise a live `connect()` against the org: + +```bash +export SALESFORCE_CLIENT_ID="3MVG9..." +export SALESFORCE_USERNAME="agent-importer@example.com" +export SALESFORCE_PRIVATE_KEY="env:SF_PRIVATE_KEY_PEM" # or a file path / raw PEM +export SALESFORCE_INSTANCE_URL="https://example.my.salesforce.com" +``` + +`SALESFORCE_PRIVATE_KEY` accepts three forms: + +| Form | Example | +|------|---------| +| `env:NAME` reference | `env:SF_PRIVATE_KEY_PEM` | +| Filesystem path | `/etc/secrets/sf-jwt.pem` | +| Inline PEM string | `-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----` | + +See `docs/adapters/frameworks-agentforce.md` for the OAuth Connected +App setup, the Trust Layer policy round-trip, and the full event taxonomy +the adapter emits. diff --git a/samples/instrument/agentforce/__init__.py b/samples/instrument/agentforce/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/samples/instrument/agentforce/main.py b/samples/instrument/agentforce/main.py new file mode 100644 index 0000000..29e78de --- /dev/null +++ b/samples/instrument/agentforce/main.py @@ -0,0 +1,333 @@ +"""Runnable sample: drive the Salesforce Agentforce adapter end-to-end. + +This sample is **fully mocked** — both the Salesforce REST surface and the +LayerLens telemetry sink are stubbed in-process. It demonstrates:: + + 1. Adapter construction with explicit ``CaptureConfig``. + 2. Three import-shaped flows: + - SOQL session backfill (Data Cloud DMOs). + - Agent API live capture (synchronous request / response). + - Einstein Trust Layer policy export. + 3. Event routing through ``BaseAdapter`` → recording sink. + 4. Clean shutdown with summary. + +Run:: + + pip install 'layerlens[agentforce]' + python -m samples.instrument.agentforce.main + +If the optional ``SALESFORCE_*`` environment variables are present, the +sample additionally exercises a single ``connect()`` call against the live +Salesforce org via the JWT Bearer flow. Otherwise the sample stays in +mock-only mode and exits with code 0. + +Required environment for the smoke run: + +* (none — the sample exits cleanly without any env vars) + +Optional environment for the live auth check: + +* ``SALESFORCE_CLIENT_ID`` — Connected App consumer key. +* ``SALESFORCE_USERNAME`` — Salesforce user the JWT is issued for. +* ``SALESFORCE_PRIVATE_KEY`` — PEM-encoded private key (or + ``env:VARNAME`` reference, or a filesystem path). +* ``SALESFORCE_INSTANCE_URL`` — your org's My Domain URL + (e.g. ``https://example.my.salesforce.com``). +""" + +from __future__ import annotations + +import os +import sys +from typing import Any +from unittest import mock + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.adapters.frameworks.agentforce import ( + AgentApiClient, + AgentForceAdapter, + EinsteinEvaluator, + SalesforceAuthError, + SalesforceConnection, + SalesforceCredentials, + TrustLayerImporter, +) +from layerlens.instrument.adapters.frameworks.agentforce.models import ( + TrustLayerConfig, + TrustLayerGuardrail, +) + + +class _RecordingSink: + """Stand-in for an HTTP / OTLP sink — records every event in-process.""" + + def __init__(self) -> None: + self.events: list[tuple[str, dict[str, Any]]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + if len(args) == 2 and isinstance(args[0], str): + self.events.append((args[0], args[1])) + + +def _have_salesforce_env() -> bool: + return all( + os.environ.get(name) + for name in ( + "SALESFORCE_CLIENT_ID", + "SALESFORCE_USERNAME", + "SALESFORCE_PRIVATE_KEY", + ) + ) + + +def _mock_credentials() -> SalesforceCredentials: + """Build credentials that do NOT require a real Salesforce org.""" + creds = SalesforceCredentials( + client_id="3MVG9SampleConnectedAppKey0000000", + username="sample-importer@example.com", + private_key="-----BEGIN PRIVATE KEY-----\nMIISample\n-----END PRIVATE KEY-----\n", + instance_url="https://example.my.salesforce.com", + ) + creds.access_token = "00DSAMPLE!AQ.TOKEN" + creds.token_expiry = 9_999_999_999.0 # not expired + return creds + + +def _mock_connection() -> SalesforceConnection: + conn = SalesforceConnection(credentials=_mock_credentials()) + conn.instance_url = "https://example.my.salesforce.com" + return conn + + +# --------------------------------------------------------------------------- +# Flow 1 — SOQL session backfill (Data Cloud DMO import) +# --------------------------------------------------------------------------- + + +def _flow_session_backfill(sink: _RecordingSink) -> int: + """Import a synthetic AgentForce session via the SOQL importer path.""" + adapter = AgentForceAdapter( + stratix=sink, + connection=_mock_connection(), + capture_config=CaptureConfig.full(), + ) + adapter.connect() + + # Replace the connection.query with fixture rows to simulate the SOQL + # responses the importer would receive from a real Salesforce org. + session_row = { + "Id": "0XxSAMPLE00001A", + "StartTimestamp": "2026-04-25T10:00:00Z", + "EndTimestamp": "2026-04-25T10:00:30Z", + "AiAgentChannelTypeId": "Web", + "AiAgentSessionEndType": "Completed", + "VoiceCallId": None, + "MessagingSessionId": None, + "PreviousSessionId": None, + } + participant_row = { + "Id": "1XxSAMPLE00001B", + "AiAgentSessionId": session_row["Id"], + "AiAgentTypeId": "EinsteinServiceAgent", + "AiAgentApiName": "Service_Agent", + "AiAgentVersionApiName": "v1", + "ParticipantId": "user-1", + "AiAgentSessionParticipantRoleId": "Agent", + } + interaction_row = { + "Id": "2XxSAMPLE00001C", + "AiAgentSessionId": session_row["Id"], + "AiAgentInteractionTypeId": "Conversation", + "TelemetryTraceId": "trace-1", + "TelemetryTraceSpanId": "span-1", + "TopicApiName": "Order_Status", + "AttributeText": '{"intent":"check_order"}', + "PrevInteractionId": None, + } + step_row = { + "Id": "3XxSAMPLE00001D", + "AiAgentInteractionId": interaction_row["Id"], + "AiAgentInteractionStepTypeId": "ActionInvocationStep", + "InputValueText": '{"order_id":"O-123"}', + "OutputValueText": '{"status":"shipped"}', + "ErrorMessageText": None, + "GenerationId": None, + "GenAiGatewayRequestId": None, + "GenAiGatewayResponseId": None, + "Name": "lookup_order", + "TelemetryTraceSpanId": "span-2", + } + + fixture_responses = [ + [session_row], + [participant_row], + [interaction_row], + [step_row], + [], # no AIAgentInteractionMessage rows + ] + with mock.patch.object( + adapter._importer._connection, # type: ignore[union-attr] + "query", + side_effect=fixture_responses, + ): + result = adapter.import_sessions(start_date="2026-04-25") + + print( + f"[backfill] imported {result.sessions_imported} session, " + f"{result.events_generated} events emitted" + ) + adapter.disconnect() + return 0 + + +# --------------------------------------------------------------------------- +# Flow 2 — Agent API live capture (request / response) +# --------------------------------------------------------------------------- + + +def _flow_live_capture() -> int: + """Drive a live Agent API session through the mocked REST surface.""" + + class _R: + status_code = 200 + headers: dict[str, str] = {} + + def __init__(self, payload: dict[str, Any]) -> None: + self._payload = payload + + def json(self) -> dict[str, Any]: + return self._payload + + def raise_for_status(self) -> None: + return None + + create_resp = _R({"sessionId": "session-1", "createdAt": "2026-04-25T10:00:00Z"}) + send_resp = _R( + { + "messages": [ + {"id": "m1", "text": "Your order shipped on 2026-04-24."}, + ], + "topic": "Order_Status", + "actions": [ + {"name": "lookup_order", "parameters": {"id": "O-123"}, "result": "shipped"}, + ], + "guardrailResults": [ + {"name": "toxicity", "triggered": False, "message": "clean"}, + ], + } + ) + end_resp = _R({}) + + client = AgentApiClient(connection=_mock_connection()) + with mock.patch("requests.post", side_effect=[create_resp, send_resp]), mock.patch( + "requests.delete", return_value=end_resp + ): + session = client.create_session(agent_name="Service_Agent") + message = client.send_message(session.session_id, "Where is my order?") + client.end_session(session.session_id) + + print(f"[live] session={session.session_id} agent_response={message!r}") + return 0 + + +# --------------------------------------------------------------------------- +# Flow 3 — Einstein Trust Layer policy export +# --------------------------------------------------------------------------- + + +def _flow_trust_layer_export() -> int: + """Convert a Trust Layer config into LayerLens YAML policy.""" + importer = TrustLayerImporter(connection=_mock_connection()) + cfg = TrustLayerConfig( + guardrails=[ + TrustLayerGuardrail(name="toxicity_detection", type="toxicity"), + TrustLayerGuardrail(name="pii_detection", type="pii", threshold=0.9), + ], + zero_data_retention=True, + audit_trail_enabled=True, + ) + yaml_str = importer.to_layerlens_policy(cfg, policy_name="sample_policy") + first_lines = "\n".join(yaml_str.splitlines()[:6]) + print("[trust-layer] generated policy YAML (first 6 lines):") + print(first_lines) + return 0 + + +# --------------------------------------------------------------------------- +# Flow 4 — Einstein evaluator (graceful offline fallback) +# --------------------------------------------------------------------------- + + +def _flow_evaluator_offline() -> int: + """Show the offline behavior of the evaluator (no LayerLens client).""" + evaluator = EinsteinEvaluator() + results = evaluator.evaluate_completions( + session_ids=["0XxSAMPLE00001A"], + graders=["relevance", "faithfulness", "safety"], + ) + for r in results: + print( + f"[evaluator] session={r.session_id} composite={r.composite_score} " + f"scores={r.scores}" + ) + return 0 + + +# --------------------------------------------------------------------------- +# Optional: live JWT auth check (only if SALESFORCE_* env vars present) +# --------------------------------------------------------------------------- + + +def _flow_live_auth_check() -> int: + creds = SalesforceCredentials( + client_id=os.environ["SALESFORCE_CLIENT_ID"], + username=os.environ["SALESFORCE_USERNAME"], + private_key=os.environ["SALESFORCE_PRIVATE_KEY"], + instance_url=os.environ.get( + "SALESFORCE_INSTANCE_URL", + "https://login.salesforce.com", + ), + ) + adapter = AgentForceAdapter(credentials=creds, capture_config=CaptureConfig.standard()) + try: + adapter.connect() + print("[live-auth] AgentForce adapter authenticated against Salesforce.") + except SalesforceAuthError as exc: + print(f"[live-auth] Salesforce auth failed: {exc}", file=sys.stderr) + return 1 + finally: + adapter.disconnect() + return 0 + + +def main() -> int: + sink = _RecordingSink() + + rc = _flow_session_backfill(sink) + if rc: + return rc + + rc = _flow_live_capture() + if rc: + return rc + + rc = _flow_trust_layer_export() + if rc: + return rc + + rc = _flow_evaluator_offline() + if rc: + return rc + + print(f"[summary] sink recorded {len(sink.events)} events across the backfill flow") + + if _have_salesforce_env(): + rc = _flow_live_auth_check() + if rc: + return rc + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/layerlens/instrument/adapters/frameworks/__init__.py b/src/layerlens/instrument/adapters/frameworks/__init__.py new file mode 100644 index 0000000..3718d80 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/__init__.py @@ -0,0 +1,26 @@ +"""Framework adapters for the LayerLens Instrument layer. + +Each framework adapter wraps an agent / chain framework's lifecycle to +intercept agent runs, model invocations, tool calls, state changes, and +handoffs, emitting events through the LayerLens telemetry pipeline. + +Adapters available (loaded on demand — importing this package does NOT +import any framework SDK): + +* ``agentforce`` — Salesforce Agentforce (auth, client, event mapping) + +Usage:: + + # Lazy import — does not pull in framework dependencies until used. + from layerlens.instrument.adapters.frameworks.agentforce import ( + AgentForceAdapter, + SalesforceCredentials, + ) + +The package is intentionally empty so that ``import +layerlens.instrument.adapters.frameworks`` never fails because of an +absent framework SDK. Each per-framework subpackage handles its own +optional dependency surface. +""" + +from __future__ import annotations diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/__init__.py b/src/layerlens/instrument/adapters/frameworks/agentforce/__init__.py new file mode 100644 index 0000000..658507e --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/__init__.py @@ -0,0 +1,68 @@ +""" +LayerLens Salesforce Agentforce Adapter. + +Full-featured adapter for Salesforce Agentforce agent evaluation: + +- Session trace import via Data Cloud SOQL (batch / incremental) +- Agent API REST client (real-time session capture) +- Platform Events subscriber (gRPC Pub/Sub for near-real-time) +- Einstein Trust Layer policy import +- LLM evaluation scenarios (completions, A/B testing, model comparison) + +DMO Objects (Data Cloud): + +- ``AIAgentSession`` +- ``AIAgentSessionParticipant`` +- ``AIAgentInteraction`` +- ``AIAgentInteractionStep`` +- ``AIAgentInteractionMessage`` + +Install:: + + pip install 'layerlens[agentforce]' +""" + +from __future__ import annotations + +from layerlens.instrument.adapters.frameworks.agentforce.auth import ( + NormalizationError, + SalesforceAuthError, + SalesforceConnection, + SalesforceQueryError, + SalesforceCredentials, +) +from layerlens.instrument.adapters.frameworks.agentforce.client import AgentApiClient +from layerlens.instrument.adapters.frameworks.agentforce.events import PlatformEventSubscriber +from layerlens.instrument.adapters.frameworks.agentforce.mapper import AgentApiMapper +from layerlens.instrument.adapters.frameworks.agentforce.adapter import AgentForceAdapter +from layerlens.instrument.adapters.frameworks.agentforce.importer import ImportResult, AgentForceImporter +from layerlens.instrument.adapters.frameworks.agentforce.llm_eval import EinsteinEvaluator +from layerlens.instrument.adapters.frameworks.agentforce.normalizer import AgentForceNormalizer +from layerlens.instrument.adapters.frameworks.agentforce.trust_layer import TrustLayerImporter + +__all__ = [ + # Core adapter + "AgentForceAdapter", + # Auth + "SalesforceAuthError", + "SalesforceConnection", + "SalesforceCredentials", + "SalesforceQueryError", + "NormalizationError", + # Import + "AgentForceImporter", + "AgentForceNormalizer", + "ImportResult", + # Agent API + "AgentApiClient", + "AgentApiMapper", + # Trust Layer + "TrustLayerImporter", + # Platform Events + "PlatformEventSubscriber", + # Evaluation + "EinsteinEvaluator", +] + +# Registry lazy-loading convention +ADAPTER_CLASS = AgentForceAdapter diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/adapter.py b/src/layerlens/instrument/adapters/frameworks/agentforce/adapter.py new file mode 100644 index 0000000..6dcebc2 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/adapter.py @@ -0,0 +1,190 @@ +""" +AgentForce Adapter + +BaseAdapter-compliant wrapper for AgentForce trace import. +Provides lifecycle management, circuit breaker protection, +CaptureConfig filtering, and health reporting. +""" + +from __future__ import annotations + +import uuid +import logging +from typing import Any + +from layerlens.instrument.adapters._base.adapter import ( + AdapterInfo, + BaseAdapter, + AdapterHealth, + AdapterStatus, + ReplayableTrace, + AdapterCapability, +) +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat +from layerlens.instrument.adapters.frameworks.agentforce.auth import ( + SalesforceAuthError, + SalesforceConnection, + SalesforceCredentials, +) +from layerlens.instrument.adapters.frameworks.agentforce.importer import ImportResult, AgentForceImporter +from layerlens.instrument.adapters.frameworks.agentforce.normalizer import AgentForceNormalizer + +logger = logging.getLogger(__name__) + + +class AgentForceAdapter(BaseAdapter): + """ + BaseAdapter wrapper for AgentForce trace import. + + Provides the standard LayerLens adapter lifecycle + (connect / disconnect / health_check) around the AgentForce importer, + routing imported events through the BaseAdapter circuit breaker and + CaptureConfig pipeline. + + Usage:: + + adapter = AgentForceAdapter(stratix=stratix, credentials=credentials) + adapter.connect() + result = adapter.import_sessions(start_date="2026-02-21") + adapter.disconnect() + """ + + FRAMEWORK = "salesforce_agentforce" + VERSION = "0.1.0" + # ``frameworks/agentforce/models.py`` line 17 imports + # ``from pydantic import Field, BaseModel`` only — both names exist + # identically under v1 and v2. No v2-only decorators + # (field_validator/model_validator) appear anywhere in the + # agentforce subpackage. Salesforce Agentforce itself is a remote + # REST API, not a Python library, so there is no framework-side + # Pydantic dependency to constrain. + requires_pydantic = PydanticCompat.V1_OR_V2 + + def __init__( + self, + stratix: Any | None = None, + capture_config: CaptureConfig | None = None, + credentials: SalesforceCredentials | None = None, + connection: SalesforceConnection | None = None, + batch_size: int = 200, + ) -> None: + super().__init__(stratix=stratix, capture_config=capture_config) + self._credentials = credentials + self._connection = connection + self._normalizer = AgentForceNormalizer() + self._importer: AgentForceImporter | None = None + self._batch_size = batch_size + + def connect(self) -> None: + """Authenticate with Salesforce and prepare the importer.""" + if self._connection is None: + if self._credentials is None: + raise SalesforceAuthError("Either 'credentials' or 'connection' must be provided") + self._connection = SalesforceConnection(credentials=self._credentials) + + if self._credentials and self._credentials.is_expired: + self._connection.authenticate() + + self._importer = AgentForceImporter( + connection=self._connection, + normalizer=self._normalizer, + batch_size=self._batch_size, + ) + + self._connected = True + self._status = AdapterStatus.HEALTHY + logger.info("AgentForce adapter connected") + + def disconnect(self) -> None: + """Disconnect and release resources.""" + self._importer = None + self._connected = False + self._status = AdapterStatus.DISCONNECTED + logger.info("AgentForce adapter disconnected") + + def health_check(self) -> AdapterHealth: + """Return adapter health, including Salesforce connection status.""" + message = None + if self._connection and self._credentials and self._credentials.is_expired: + message = "Salesforce token expired, will re-authenticate on next operation" + + return AdapterHealth( + status=self._status, + framework_name=self.FRAMEWORK, + adapter_version=self.VERSION, + message=message, + error_count=self._error_count, + circuit_open=self._circuit_open, + ) + + def get_adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name="AgentForceAdapter", + version=self.VERSION, + framework=self.FRAMEWORK, + capabilities=[ + AdapterCapability.TRACE_MODELS, + AdapterCapability.TRACE_TOOLS, + ], + description="LayerLens adapter for Salesforce AgentForce trace import", + ) + + def serialize_for_replay(self) -> ReplayableTrace: + return ReplayableTrace( + adapter_name="AgentForceAdapter", + framework=self.FRAMEWORK, + trace_id=str(uuid.uuid4()), + events=list(self._trace_events), + state_snapshots=[], + config={ + "capture_config": self._capture_config.model_dump(), + }, + ) + + def import_sessions( + self, + start_date: str | None = None, + end_date: str | None = None, + agent_type: str | None = None, + channel_type: str | None = None, + limit: int | None = None, + last_import_timestamp: str | None = None, + ) -> ImportResult: + """ + Import AgentForce sessions and emit events through the adapter pipeline. + + Events are routed through ``emit_dict_event()`` for circuit breaker + and CaptureConfig protection. + + Returns: + ImportResult summary. + """ + if not self._connected or not self._importer: + raise RuntimeError("Adapter not connected. Call connect() first.") + + events, result = self._importer.import_sessions( + start_date=start_date, + end_date=end_date, + agent_type=agent_type, + channel_type=channel_type, + limit=limit, + last_import_timestamp=last_import_timestamp, + ) + + # Route each event through BaseAdapter pipeline + emitted = 0 + for event in events: + event_type = event.get("event_type", "") + payload = event.get("payload", {}) + # Add identity and timestamp to payload for downstream consumers + if "identity" in event: + payload["_identity"] = event["identity"] + if "timestamp" in event: + payload["_timestamp"] = event["timestamp"] + + self.emit_dict_event(event_type, payload) + emitted += 1 + + result.events_generated = emitted + return result diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/auth.py b/src/layerlens/instrument/adapters/frameworks/agentforce/auth.py new file mode 100644 index 0000000..cbebae1 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/auth.py @@ -0,0 +1,328 @@ +""" +Salesforce OAuth 2.0 JWT Bearer Authentication + +Implements the JWT Bearer flow for server-to-server authentication +with Salesforce Data Cloud. Includes retry with exponential backoff, +timeouts, and credential masking. +""" + +from __future__ import annotations + +import os +import time +import logging +from typing import Any +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + +# Timeout defaults (seconds) +_AUTH_TIMEOUT = 30 +_QUERY_TIMEOUT = 60 + +# Retry defaults +_MAX_RETRIES = 3 +_RETRY_BASE_DELAY = 1.0 # seconds +_RETRY_MAX_DELAY = 30.0 # seconds + +# Salesforce access token lifetime (conservative; actual is ~2 hours) +_TOKEN_LIFETIME_S = 3600 + +# Rate limit warning threshold (percentage of API limit consumed) +_RATE_LIMIT_WARN_THRESHOLD = 0.8 + + +class SalesforceAuthError(Exception): + """Raised when Salesforce authentication fails.""" + + def __init__(self, message: str, status_code: int | None = None, endpoint: str = "") -> None: + self.status_code = status_code + self.endpoint = endpoint + super().__init__(message) + + +class SalesforceQueryError(Exception): + """Raised when a SOQL query fails.""" + + def __init__(self, message: str, status_code: int | None = None, soql: str = "") -> None: + self.status_code = status_code + self.soql = soql + super().__init__(message) + + +class NormalizationError(Exception): + """Raised when normalization of AgentForce records fails.""" + + pass + + +@dataclass +class SalesforceCredentials: + """Salesforce connection credentials.""" + + client_id: str + username: str + private_key: str # PEM-encoded private key or env var name + instance_url: str = "https://login.salesforce.com" + access_token: str | None = None + token_expiry: float = 0.0 + + @property + def is_expired(self) -> bool: + return time.time() >= self.token_expiry + + def resolve_private_key(self) -> str: + """Resolve the private key from env var, file path, or raw PEM string.""" + key = self.private_key + # Check env var reference + if key.startswith("$") or key.startswith("env:"): + env_name = key.lstrip("$").removeprefix("env:") + resolved = os.environ.get(env_name, "") + if not resolved: + raise SalesforceAuthError( + f"Environment variable '{env_name}' not set for private key" + ) + return resolved + # Check file path + if os.path.isfile(key): + with open(key) as f: + return f.read() + # Assume raw PEM + return key + + def __repr__(self) -> str: + return ( + f"SalesforceCredentials(" + f"client_id='{self.client_id[:8]}...', " + f"username='{self.username}', " + f"instance_url='{self.instance_url}', " + f"private_key='***REDACTED***', " + f"access_token={'***REDACTED***' if self.access_token else 'None'}, " + f"is_expired={self.is_expired})" + ) + + +@dataclass +class SalesforceConnection: + """Active Salesforce connection with retry and timeout support.""" + + credentials: SalesforceCredentials + instance_url: str = "" + api_version: str = "v60.0" + auth_timeout: int = _AUTH_TIMEOUT + query_timeout: int = _QUERY_TIMEOUT + max_retries: int = _MAX_RETRIES + + def authenticate(self) -> None: + """Authenticate using JWT Bearer flow with retry.""" + import jwt + import requests # type: ignore[import-untyped,unused-ignore] + + resolved_key = self.credentials.resolve_private_key() + + # Build JWT + now = int(time.time()) + payload = { + "iss": self.credentials.client_id, + "sub": self.credentials.username, + "aud": self.credentials.instance_url, + "exp": now + 300, + } + token = jwt.encode(payload, resolved_key, algorithm="RS256") + + endpoint = f"{self.credentials.instance_url}/services/oauth2/token" + last_error: Exception | None = None + + for attempt in range(self.max_retries): + try: + response = requests.post( + endpoint, + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": token, + }, + timeout=self.auth_timeout, + ) + response.raise_for_status() + data = response.json() + + self.credentials.access_token = data["access_token"] + self.instance_url = data["instance_url"] + self.credentials.token_expiry = now + _TOKEN_LIFETIME_S + logger.info("Authenticated with Salesforce: %s", self.instance_url) + return + except requests.exceptions.Timeout as e: + last_error = e + logger.warning( + "Salesforce auth timeout (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + except requests.exceptions.HTTPError as e: + status = e.response.status_code if e.response is not None else None + # Don't retry 4xx (client errors) except 429 (rate limit) + if status is not None and 400 <= status < 500 and status != 429: + raise SalesforceAuthError( + f"Salesforce authentication failed (HTTP {status}). " + f"Check credentials and re-authenticate using `stratix agentforce connect`." + f" " + f"Endpoint: {endpoint}", + status_code=status, + endpoint=endpoint, + ) from e + last_error = e + logger.warning( + "Salesforce auth HTTP error (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + except requests.exceptions.RequestException as e: + last_error = e + logger.warning( + "Salesforce auth request error (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + + # Exponential backoff + if attempt < self.max_retries - 1: + delay = min( + _RETRY_BASE_DELAY * (2**attempt), + _RETRY_MAX_DELAY, + ) + time.sleep(delay) + + raise SalesforceAuthError( + f"Salesforce authentication failed after {self.max_retries} attempts. " + f"Last error: {last_error}. " + f"Re-authenticate using `stratix agentforce connect`. " + f"Endpoint: {endpoint}", + endpoint=endpoint, + ) + + @staticmethod + def _check_rate_limit(response_headers: dict[str, Any]) -> None: + """Parse Sforce-Limit-Info header and warn if approaching limits. + + Salesforce returns ``Sforce-Limit-Info: api-usage=25/15000`` on every + API response. We log a warning when usage exceeds the configured + threshold so operators can react before hitting hard limits. + """ + limit_info = response_headers.get("Sforce-Limit-Info", "") + if not limit_info: + return + try: + # Format: "api-usage=USED/LIMIT" + usage_part = limit_info.split("=", 1)[1] if "=" in limit_info else "" + if "/" in usage_part: + used_str, total_str = usage_part.split("/", 1) + used, total = int(used_str), int(total_str) + if total > 0 and used / total >= _RATE_LIMIT_WARN_THRESHOLD: + logger.warning( + "Salesforce API rate limit warning: %d/%d (%.0f%%) consumed", + used, + total, + (used / total) * 100, + ) + except (ValueError, IndexError): + # Malformed header — ignore silently + pass + + def query(self, soql: str) -> list[dict[str, Any]]: + """Execute a SOQL query with retry, timeout, and pagination.""" + if self.credentials.is_expired: + self.authenticate() + + import requests + + url = f"{self.instance_url}/services/data/{self.api_version}/query" + headers = { + "Authorization": f"Bearer {self.credentials.access_token}", + "Content-Type": "application/json", + } + + records: list[dict[str, Any]] = [] + params: dict[str, str] | None = {"q": soql} + + while True: + last_error: Exception | None = None + success = False + + for attempt in range(self.max_retries): + try: + response = requests.get( + url, + headers=headers, + params=params, + timeout=self.query_timeout, + ) + response.raise_for_status() + + # Check Salesforce API rate limits + self._check_rate_limit(response.headers) + + data = response.json() + + records.extend(data.get("records", [])) + + # Handle pagination + next_url = data.get("nextRecordsUrl") + if next_url: + url = f"{self.instance_url}{next_url}" + params = None # Pagination URL includes query params + success = True + break + + except requests.exceptions.Timeout as e: + last_error = e + logger.warning( + "Salesforce query timeout (attempt %d/%d)", + attempt + 1, + self.max_retries, + ) + except requests.exceptions.HTTPError as e: + status = e.response.status_code if e.response is not None else None + if status is not None and 400 <= status < 500 and status != 429: + raise SalesforceQueryError( + f"SOQL query failed (HTTP {status})", + status_code=status, + soql=soql[:200], + ) from e + last_error = e + logger.warning( + "Salesforce query HTTP error (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + except requests.exceptions.RequestException as e: + last_error = e + logger.warning( + "Salesforce query request error (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + + if attempt < self.max_retries - 1: + delay = min( + _RETRY_BASE_DELAY * (2**attempt), + _RETRY_MAX_DELAY, + ) + time.sleep(delay) + + if not success: + raise SalesforceQueryError( + f"SOQL query failed after {self.max_retries} attempts. " + f"Last error: {last_error}", + soql=soql[:200], + ) + + # If no next page, we're done + if not data.get("nextRecordsUrl"): + break + + return records diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/client.py b/src/layerlens/instrument/adapters/frameworks/agentforce/client.py new file mode 100644 index 0000000..b2ab2b5 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/client.py @@ -0,0 +1,334 @@ +""" +Salesforce Agent API REST Client + +Provides a typed client for the Salesforce Agent API: +- Session creation and lifecycle management +- Synchronous and streaming message exchange +- Response parsing with action and guardrail extraction + +Reference: https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api.html +""" + +from __future__ import annotations + +import json +import time +import logging +from typing import Any +from collections.abc import Generator + +from layerlens.instrument.adapters.frameworks.agentforce.auth import ( + SalesforceConnection, + SalesforceQueryError, +) +from layerlens.instrument.adapters.frameworks.agentforce.models import ( + AgentApiMessage, + AgentApiSession, +) + +logger = logging.getLogger(__name__) + +# Agent API path prefix +_AGENT_API_PREFIX = "/services/data/{version}/agent" + +# Default timeout for Agent API calls (seconds) +_API_TIMEOUT = 30 + +# Maximum response text length to capture (prevent memory bloat) +_MAX_RESPONSE_LENGTH = 50_000 + + +class AgentApiClient: + """ + REST client for the Salesforce Agent API. + + Wraps session creation, message exchange, and response parsing. + All methods use the authenticated ``SalesforceConnection`` for + token management and retry logic. + + Usage: + client = AgentApiClient(connection=connection) + session = client.create_session(agent_name="Service_Agent") + response = client.send_message(session.session_id, "How do I reset my password?") + client.end_session(session.session_id) + """ + + def __init__( + self, + connection: SalesforceConnection, + api_timeout: int = _API_TIMEOUT, + ) -> None: + self._connection = connection + self._api_timeout = api_timeout + self._base_url = "" + + @property + def base_url(self) -> str: + """Build the Agent API base URL from the connection.""" + if not self._base_url: + instance = self._connection.instance_url + version = self._connection.api_version + self._base_url = f"{instance}{_AGENT_API_PREFIX.format(version=version)}" + return self._base_url + + def create_session( + self, + agent_name: str, + context: dict[str, Any] | None = None, + ) -> AgentApiSession: + """ + Create a new Agent API session. + + Args: + agent_name: Name of the Agentforce agent to connect to. + context: Optional context variables for the session. + + Returns: + AgentApiSession with the session ID and initial state. + + Raises: + SalesforceQueryError: If the API call fails. + """ + import requests # type: ignore[import-untyped,unused-ignore] + + if not agent_name or not agent_name.strip(): + raise ValueError("agent_name must be a non-empty string") + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + url = f"{self.base_url}/sessions" + headers = { + "Authorization": f"Bearer {self._connection.credentials.access_token}", + "Content-Type": "application/json", + } + body: dict[str, Any] = {"agentName": agent_name} + if context: + body["context"] = context + + try: + response = requests.post( + url, + headers=headers, + json=body, + timeout=self._api_timeout, + ) + response.raise_for_status() + data = response.json() + + return AgentApiSession( + session_id=data.get("sessionId", ""), + agent_name=agent_name, + status="active", + created_at=data.get("createdAt"), + ) + except requests.exceptions.RequestException as e: + raise SalesforceQueryError( + f"Failed to create Agent API session: {e}", + status_code=getattr(getattr(e, "response", None), "status_code", None), + ) from e + + def send_message( + self, + session_id: str, + message: str, + stream: bool = False, + ) -> AgentApiMessage | Generator[str, None, None]: + """ + Send a message to an active Agent API session. + + Args: + session_id: The session ID from ``create_session()``. + message: User message text to send. + stream: If True, return a generator of streaming response chunks. + + Returns: + AgentApiMessage with the agent response, or a generator if streaming. + + Raises: + SalesforceQueryError: If the API call fails. + """ + if not session_id or not session_id.strip(): + raise ValueError("session_id must be a non-empty string") + if not message or not message.strip(): + raise ValueError("message must be a non-empty string") + + import requests + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + url = f"{self.base_url}/sessions/{session_id}/messages" + headers = { + "Authorization": f"Bearer {self._connection.credentials.access_token}", + "Content-Type": "application/json", + } + if stream: + headers["Accept"] = "text/event-stream" + + body = {"message": {"text": message}} + + try: + response = requests.post( + url, + headers=headers, + json=body, + timeout=self._api_timeout, + stream=stream, + ) + response.raise_for_status() + + if stream: + return self._stream_response(response) + + return self._parse_message_response(response.json()) + + except requests.exceptions.RequestException as e: + raise SalesforceQueryError( + f"Failed to send Agent API message: {e}", + status_code=getattr(getattr(e, "response", None), "status_code", None), + ) from e + + def end_session(self, session_id: str) -> None: + """ + End an active Agent API session. + + Args: + session_id: The session ID to end. + + Raises: + SalesforceQueryError: If the API call fails. + """ + if not session_id or not session_id.strip(): + raise ValueError("session_id must be a non-empty string") + + import requests + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + url = f"{self.base_url}/sessions/{session_id}" + headers = { + "Authorization": f"Bearer {self._connection.credentials.access_token}", + "Content-Type": "application/json", + } + + try: + response = requests.delete( + url, + headers=headers, + timeout=self._api_timeout, + ) + response.raise_for_status() + except requests.exceptions.RequestException as e: + raise SalesforceQueryError( + f"Failed to end Agent API session: {e}", + status_code=getattr(getattr(e, "response", None), "status_code", None), + ) from e + + def capture_session( + self, + agent_name: str, + messages: list[str], + context: dict[str, Any] | None = None, + ) -> AgentApiSession: + """ + Convenience method: create session, send all messages, end session. + + Returns an ``AgentApiSession`` with all messages and responses. + + Args: + agent_name: Agentforce agent name. + messages: List of user messages to send sequentially. + context: Optional session context. + + Returns: + Complete AgentApiSession with all exchanged messages. + """ + session = self.create_session(agent_name, context) + all_messages: list[AgentApiMessage] = [] + + for msg_text in messages: + # Record user message + all_messages.append(AgentApiMessage(role="user", content=msg_text)) + + # Send and capture response + response = self.send_message(session.session_id, msg_text) + if isinstance(response, AgentApiMessage): + all_messages.append(response) + + self.end_session(session.session_id) + + session.messages = all_messages + session.status = "ended" + session.ended_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + return session + + # --- Internal helpers --- + + @staticmethod + def _parse_message_response(data: dict[str, Any]) -> AgentApiMessage: + """Parse a synchronous Agent API message response.""" + messages = data.get("messages", []) + if not messages: + return AgentApiMessage( + role="agent", + content=data.get("text", ""), + timestamp=data.get("timestamp"), + ) + + # Take the last agent message + last = messages[-1] + actions = [] + guardrails = [] + + # Extract actions if present + for action in data.get("actions", []): + actions.append( + { + "name": action.get("name", "unknown"), + "parameters": action.get("parameters", {}), + "result": action.get("result"), + } + ) + + # Extract guardrail results if present + for gr in data.get("guardrailResults", []): + guardrails.append( + { + "name": gr.get("name", "unknown"), + "triggered": gr.get("triggered", False), + "message": gr.get("message"), + } + ) + + return AgentApiMessage( + id=last.get("id"), + role="agent", + content=str(last.get("text", ""))[:_MAX_RESPONSE_LENGTH], + timestamp=last.get("timestamp"), + topic=data.get("topic"), + actions=actions, + guardrail_results=guardrails, + ) + + @staticmethod + def _stream_response(response: Any) -> Generator[str, None, None]: + """Parse a streaming Agent API response (SSE format).""" + try: + for line in response.iter_lines(decode_unicode=True): + if not line: + continue + if line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + return + try: + chunk = json.loads(data_str) + text = chunk.get("text", "") + if text: + yield text + except json.JSONDecodeError: + logger.debug("Failed to parse SSE chunk: %s", data_str[:100]) + finally: + response.close() diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/events.py b/src/layerlens/instrument/adapters/frameworks/agentforce/events.py new file mode 100644 index 0000000..94bccbe --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/events.py @@ -0,0 +1,330 @@ +""" +Salesforce Platform Events Subscriber + +Subscribes to Salesforce Platform Events via the gRPC Pub/Sub API +for near-real-time Agentforce session capture. + +Supports: +- gRPC Pub/Sub API subscription (AgentSession__e) +- Automatic reconnection with exponential backoff +- Event replay from a specific replay ID +- Graceful shutdown with pending event flush + +Reference: https://developer.salesforce.com/docs/platform/pub-sub-api/overview +""" + +from __future__ import annotations + +import time +import logging +import threading +from typing import Any +from collections.abc import Callable + +from layerlens.instrument.adapters.frameworks.agentforce.auth import SalesforceConnection +from layerlens.instrument.adapters.frameworks.agentforce.models import AgentSessionEvent + +logger = logging.getLogger(__name__) + +# Default Platform Event channel +_DEFAULT_CHANNEL = "/event/AgentSession__e" + +# Reconnection backoff constants +_RECONNECT_BASE_DELAY = 1.0 +_RECONNECT_MAX_DELAY = 60.0 +_MAX_RECONNECT_ATTEMPTS = 10 + +# Subscriber batch size +_BATCH_SIZE = 100 + + +class PlatformEventSubscriber: + """ + Subscribe to Salesforce Platform Events for real-time Agentforce capture. + + Uses the Salesforce gRPC Pub/Sub API to receive events as they occur, + with automatic reconnection and replay support. + + Usage: + subscriber = PlatformEventSubscriber( + connection=connection, + on_event=handle_event, + ) + subscriber.start() + # ... later ... + subscriber.stop() + """ + + def __init__( + self, + connection: SalesforceConnection, + on_event: Callable[[AgentSessionEvent], None] | None = None, + channel: str = _DEFAULT_CHANNEL, + replay_id: str | None = None, + ) -> None: + """ + Initialize the Platform Events subscriber. + + Args: + connection: Authenticated Salesforce connection. + on_event: Callback invoked for each received event. + channel: Platform Event channel to subscribe to. + replay_id: Optional replay ID to resume from. + """ + self._connection = connection + self._on_event = on_event + self._channel = channel + self._replay_id = replay_id + self._running = False + self._thread: threading.Thread | None = None + self._reconnect_attempts = 0 + self._events_received = 0 + self._last_replay_id: str | None = replay_id + + @property + def is_running(self) -> bool: + """Whether the subscriber is actively listening.""" + return self._running + + @property + def events_received(self) -> int: + """Total events received since start.""" + return self._events_received + + @property + def last_replay_id(self) -> str | None: + """Last processed replay ID (for resume on restart).""" + return self._last_replay_id + + def start(self) -> None: + """ + Start the Platform Events subscriber in a background thread. + + The subscriber will attempt to connect and begin receiving events. + On connection failure, it retries with exponential backoff. + """ + if self._running: + logger.warning("Platform Events subscriber already running") + return + + self._running = True + self._thread = threading.Thread( + target=self._subscribe_loop, + name="stratix-sf-events", + daemon=True, + ) + self._thread.start() + logger.info( + "Platform Events subscriber started on channel: %s", + self._channel, + ) + + def stop(self) -> None: + """ + Stop the Platform Events subscriber. + + Signals the background thread to stop and waits for graceful shutdown. + """ + self._running = False + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=5.0) + self._thread = None + logger.info( + "Platform Events subscriber stopped. Events received: %d", + self._events_received, + ) + + def _subscribe_loop(self) -> None: + """Main subscription loop with reconnection logic.""" + while self._running: + try: + self._subscribe() + except Exception as e: + # ``self._running`` can flip concurrently from ``stop()`` — + # mypy can't see the cross-thread mutation, so it thinks the + # break is unreachable inside ``while self._running:``. It's + # not. + if not self._running: + break # type: ignore[unreachable] + self._reconnect_attempts += 1 + if self._reconnect_attempts > _MAX_RECONNECT_ATTEMPTS: + logger.error( + "Platform Events subscriber exceeded max reconnect attempts (%d). Stopping.", # noqa: E501 + _MAX_RECONNECT_ATTEMPTS, + ) + self._running = False + break + + delay = min( + _RECONNECT_BASE_DELAY * (2 ** (self._reconnect_attempts - 1)), + _RECONNECT_MAX_DELAY, + ) + logger.warning( + "Platform Events connection lost (attempt %d/%d): %s. Retrying in %.1fs.", + self._reconnect_attempts, + _MAX_RECONNECT_ATTEMPTS, + str(e)[:200], + delay, + ) + time.sleep(delay) + + def _subscribe(self) -> None: + """ + Subscribe to the Platform Event channel. + + This method uses HTTP long-polling as a fallback when the gRPC + Pub/Sub API client is not available. For production use with + high-volume events, the gRPC client is recommended. + """ + # Attempt gRPC Pub/Sub API first + try: + self._subscribe_grpc() + return + except ImportError: + logger.info("gRPC Pub/Sub client not available. Falling back to CometD polling.") + except Exception as e: + logger.warning("gRPC subscription failed: %s. Falling back.", e) + + # Fallback: CometD / HTTP long-polling + self._subscribe_cometd() + + def _subscribe_grpc(self) -> None: + """ + Subscribe using the Salesforce gRPC Pub/Sub API. + + Requires the ``grpcio`` and ``avro`` packages. + """ + # Import gRPC dependencies (optional) + import grpc # type: ignore[import-not-found,import-untyped,unused-ignore] # noqa: F401 + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + # gRPC Pub/Sub API endpoint + pubsub_endpoint = self._connection.instance_url.replace("https://", "") + ":443" + + logger.info("Connecting to gRPC Pub/Sub API: %s", pubsub_endpoint) + + # NOTE: Full gRPC stub implementation requires the Salesforce + # pub-sub proto definitions. This is a structural placeholder + # that demonstrates the connection pattern. Production code + # should use the salesforce-pubsub package. + raise NotImplementedError( + "Full gRPC Pub/Sub implementation requires salesforce-pubsub package. " + "Install: pip install salesforce-pubsub" + ) + + def _subscribe_cometd(self) -> None: + """ + Subscribe using CometD long-polling (fallback). + + Uses the Streaming API (/cometd) endpoint for Platform Events. + Lower throughput than gRPC but works without additional dependencies. + """ + import requests # type: ignore[import-untyped,unused-ignore] + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + base_url = self._connection.instance_url + api_version = self._connection.api_version + cometd_url = f"{base_url}/cometd/{api_version.lstrip('v')}" + + headers = { + "Authorization": f"Bearer {self._connection.credentials.access_token}", + "Content-Type": "application/json", + } + + # CometD handshake + handshake_payload = [ + { + "channel": "/meta/handshake", + "version": "1.0", + "supportedConnectionTypes": ["long-polling"], + "minimumVersion": "1.0", + } + ] + + try: + resp = requests.post( + cometd_url, + headers=headers, + json=handshake_payload, + timeout=30, + ) + resp.raise_for_status() + handshake_data = resp.json() + client_id = handshake_data[0].get("clientId") + if not client_id: + raise RuntimeError("CometD handshake failed: no clientId") + + # Subscribe to channel + subscribe_payload = [ + { + "channel": "/meta/subscribe", + "clientId": client_id, + "subscription": self._channel, + } + ] + if self._replay_id: + subscribe_payload[0]["ext"] = { + "replay": {self._channel: self._replay_id}, + } + + resp = requests.post( + cometd_url, + headers=headers, + json=subscribe_payload, + timeout=30, + ) + resp.raise_for_status() + + # Reset reconnect attempts on successful connection + self._reconnect_attempts = 0 + + # Long-polling loop + while self._running: + connect_payload = [ + { + "channel": "/meta/connect", + "clientId": client_id, + "connectionType": "long-polling", + } + ] + resp = requests.post( + cometd_url, + headers=headers, + json=connect_payload, + timeout=120, + ) + resp.raise_for_status() + + for msg in resp.json(): + channel = msg.get("channel", "") + if channel == self._channel: + self._handle_event(msg.get("data", {})) + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"CometD connection error: {e}") from e + + def _handle_event(self, data: dict[str, Any]) -> None: + """Process a received Platform Event.""" + try: + event = AgentSessionEvent( + session_id=data.get("SessionId__c", ""), + agent_name=data.get("AgentName__c"), + topic_name=data.get("TopicName__c"), + actions_taken=data.get("ActionsTaken__c"), + response_text=data.get("ResponseText__c"), + trust_layer_flags=data.get("TrustLayerFlags__c"), + replay_id=str(data.get("event", {}).get("replayId", "")), + ) + + self._events_received += 1 + self._last_replay_id = event.replay_id + + if self._on_event: + self._on_event(event) + + except Exception as e: + logger.warning("Failed to process Platform Event: %s", e) diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/importer.py b/src/layerlens/instrument/adapters/frameworks/agentforce/importer.py new file mode 100644 index 0000000..d5fe941 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/importer.py @@ -0,0 +1,268 @@ +""" +AgentForce Trace Importer. + +Imports AgentForce Session Tracing data from Salesforce Data Cloud +and normalizes it to LayerLens canonical events. + +Supports: +- Batch import (date range filter) +- Incremental import (timestamp-based) +- Session, participant, interaction, step, and message extraction +""" + +from __future__ import annotations + +import re +import logging +from typing import Any +from dataclasses import field, dataclass + +from layerlens.instrument.adapters.frameworks.agentforce.auth import SalesforceConnection +from layerlens.instrument.adapters.frameworks.agentforce.normalizer import AgentForceNormalizer + +# Regex for validating ISO 8601 date strings (YYYY-MM-DD) +_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") +# Regex for validating ISO 8601 timestamp strings +_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") +# Regex for Salesforce record IDs (exactly 15 or 18 char alphanumeric) +_SF_ID_RE = re.compile(r"^[a-zA-Z0-9]{15}(?:[a-zA-Z0-9]{3})?$") + +logger = logging.getLogger(__name__) + + +@dataclass +class ImportResult: + """Result of an AgentForce import operation.""" + + sessions_imported: int = 0 + participants_imported: int = 0 + interactions_imported: int = 0 + steps_imported: int = 0 + messages_imported: int = 0 + events_generated: int = 0 + errors: list[str] = field(default_factory=list) + + @property + def total_records(self) -> int: + return ( + self.sessions_imported + + self.participants_imported + + self.interactions_imported + + self.steps_imported + + self.messages_imported + ) + + +class AgentForceImporter: + """ + Import AgentForce traces from Salesforce Data Cloud. + + Usage: + connection = SalesforceConnection(credentials) + connection.authenticate() + importer = AgentForceImporter(connection) + events, result = importer.import_sessions( + start_date="2026-02-21", + end_date="2026-02-28", + ) + """ + + def __init__( + self, + connection: SalesforceConnection, + normalizer: AgentForceNormalizer | None = None, + batch_size: int = 200, + ) -> None: + self._connection = connection + self._normalizer = normalizer or AgentForceNormalizer() + self._batch_size = batch_size + + def import_sessions( + self, + start_date: str | None = None, + end_date: str | None = None, + agent_type: str | None = None, # noqa: ARG002 — reserved for SOQL filter wiring + channel_type: str | None = None, # noqa: ARG002 — reserved for SOQL filter wiring + limit: int | None = None, + last_import_timestamp: str | None = None, + ) -> tuple[list[dict[str, Any]], ImportResult]: + """ + Import AgentForce sessions and all related records. + + Args: + start_date: Import sessions starting from this date (ISO 8601) + end_date: Import sessions up to this date (ISO 8601) + agent_type: Filter by agent type (Employee, EinsteinSDR, EinsteinServiceAgent) + channel_type: Filter by channel type + limit: Maximum sessions to import + last_import_timestamp: For incremental sync, only import after this timestamp + + Returns: + Tuple of (list of LayerLens events, ImportResult summary) + """ + result = ImportResult() + all_events: list[dict[str, Any]] = [] + + # Build session query with validated parameters + conditions = [] + if start_date: + self._validate_date(start_date) + conditions.append(f"StartTimestamp >= {start_date}T00:00:00Z") + if end_date: + self._validate_date(end_date) + conditions.append(f"StartTimestamp <= {end_date}T23:59:59Z") + if last_import_timestamp: + self._validate_timestamp(last_import_timestamp) + conditions.append(f"StartTimestamp > {last_import_timestamp}") + + where = f" WHERE {' AND '.join(conditions)}" if conditions else "" + limit_clause = f" LIMIT {limit}" if limit else f" LIMIT {self._batch_size}" + + soql = ( + "SELECT Id, StartTimestamp, EndTimestamp, AiAgentChannelTypeId, " + "AiAgentSessionEndType, VoiceCallId, MessagingSessionId, PreviousSessionId " + f"FROM AIAgentSession{where} ORDER BY StartTimestamp ASC{limit_clause}" + ) + + try: + sessions = self._connection.query(soql) + except Exception as e: + result.errors.append(f"Session query failed: {e}") + return all_events, result + + if not sessions: + return all_events, result + + session_ids = [s["Id"] for s in sessions] + result.sessions_imported = len(sessions) + + # Normalize sessions + for session in sessions: + events = self._normalizer.normalize_session(session) + all_events.extend(events) + + # Import participants + participants = self._query_related( + "AIAgentSessionParticipant", + "AiAgentSessionId", + session_ids, + "Id, AiAgentSessionId, AiAgentTypeId, AiAgentApiName, " + "AiAgentVersionApiName, ParticipantId, AiAgentSessionParticipantRoleId", + result=result, + ) + result.participants_imported = len(participants) + for p in participants: + all_events.append(self._normalizer.normalize_participant(p)) + + # Import interactions + interactions = self._query_related( + "AIAgentInteraction", + "AiAgentSessionId", + session_ids, + "Id, AiAgentSessionId, AiAgentInteractionTypeId, " + "TelemetryTraceId, TelemetryTraceSpanId, TopicApiName, " + "AttributeText, PrevInteractionId", + order_by="Id ASC", + result=result, + ) + result.interactions_imported = len(interactions) + for i in interactions: + all_events.append(self._normalizer.normalize_interaction(i)) + + if interactions: + interaction_ids = [i["Id"] for i in interactions] + + # Import steps + steps = self._query_related( + "AIAgentInteractionStep", + "AiAgentInteractionId", + interaction_ids, + "Id, AiAgentInteractionId, AiAgentInteractionStepTypeId, " + "InputValueText, OutputValueText, ErrorMessageText, " + "GenerationId, GenAiGatewayRequestId, GenAiGatewayResponseId, " + "Name, TelemetryTraceSpanId", + order_by="Id ASC", + result=result, + ) + result.steps_imported = len(steps) + for s in steps: + all_events.append(self._normalizer.normalize_step(s)) + + # Import messages + messages = self._query_related( + "AIAgentInteractionMessage", + "AiAgentInteractionId", + interaction_ids, + "Id, AiAgentInteractionId, AiAgentInteractionMessageTypeId, " + "ContentText, AiAgentInteractionMsgContentTypeId, " + "MessageSentTimestamp, ParentMessageId", + order_by="MessageSentTimestamp ASC", + result=result, + ) + result.messages_imported = len(messages) + for m in messages: + all_events.append(self._normalizer.normalize_message(m)) + + result.events_generated = len(all_events) + logger.info( + "AgentForce import complete: %d sessions, %d events generated", + result.sessions_imported, + result.events_generated, + ) + return all_events, result + + def _query_related( + self, + object_name: str, + foreign_key: str, + parent_ids: list[str], + fields: str, + order_by: str | None = None, + result: ImportResult | None = None, + ) -> list[dict[str, Any]]: + """Query related records in batches to respect SOQL limits.""" + all_records: list[dict[str, Any]] = [] + + # Batch parent IDs to avoid SOQL IN clause limits + for i in range(0, len(parent_ids), self._batch_size): + batch = parent_ids[i : i + self._batch_size] + # Validate IDs to prevent SOQL injection + safe_ids = [self._validate_sf_id(pid) for pid in batch] + ids_str = "', '".join(safe_ids) + soql = f"SELECT {fields} FROM {object_name} WHERE {foreign_key} IN ('{ids_str}')" + if order_by: + soql += f" ORDER BY {order_by}" + + try: + records = self._connection.query(soql) + all_records.extend(records) + except Exception as e: + error_msg = f"Failed to query {object_name}: {e}" + logger.error(error_msg) + if result is not None: + result.errors.append(error_msg) + + return all_records + + @staticmethod + def _validate_date(value: str) -> None: + """Validate an ISO 8601 date string (YYYY-MM-DD).""" + if not _DATE_RE.match(value): + raise ValueError(f"Invalid date format: '{value}'. Expected YYYY-MM-DD.") + + @staticmethod + def _validate_timestamp(value: str) -> None: + """Validate an ISO 8601 timestamp string.""" + if not _TIMESTAMP_RE.match(value): + raise ValueError(f"Invalid timestamp format: '{value}'. Expected ISO 8601.") + + @staticmethod + def _validate_sf_id(value: str) -> str: + """Validate a Salesforce ID format (15 or 18 char alphanumeric). + + Raises: + ValueError: If the value does not match the Salesforce ID format. + """ + if not _SF_ID_RE.match(value): + raise ValueError(f"Invalid Salesforce ID format: {value!r}") + return value diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/llm_eval.py b/src/layerlens/instrument/adapters/frameworks/agentforce/llm_eval.py new file mode 100644 index 0000000..b838756 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/llm_eval.py @@ -0,0 +1,440 @@ +""" +Agentforce LLM Evaluation Scenarios. + +Provides evaluation capabilities beyond agent tracing: + +- Einstein completions evaluation (grading LLM responses) +- Prompt template A/B testing for Agentforce topics +- Model comparison (GPT vs Claude vs Gemini via Atlas) +- CRM outcome ground truth correlation + +These scenarios use imported Agentforce session data as input +and run LayerLens graders to produce evaluation scores. +""" + +from __future__ import annotations + +import os +import logging +from typing import Any +from dataclasses import field, dataclass + +from layerlens.instrument.adapters.frameworks.agentforce.models import EvaluationResult + +logger = logging.getLogger(__name__) + + +def _get_stratix_client() -> Any | None: + """Lazily create a Stratix API client from environment variables.""" + api_url = os.environ.get("LAYERLENS_API_URL") + api_key = os.environ.get("LAYERLENS_API_KEY") + if not api_url or not api_key: + return None + try: + from layerlens import Stratix as StratixClient + + return StratixClient(base_url=api_url, api_key=api_key) + except Exception as exc: + logger.debug("Could not create Stratix client: %s", exc) + return None + + +# Default graders for Agentforce evaluation +_DEFAULT_GRADERS = ["relevance", "faithfulness", "coherence", "safety"] + +# Composite score weights (aligned with Section 4.3 of integration doc) +_DEFAULT_WEIGHTS = { + "topic_accuracy": 0.20, + "action_correctness": 0.25, + "response_quality": 0.20, + "safety_compliance": 0.20, + "crm_outcome": 0.15, +} + + +@dataclass +class ABTestResult: + """Result of an A/B test between two prompt variants.""" + + variant_a_scores: dict[str, float] = field(default_factory=dict) + variant_b_scores: dict[str, float] = field(default_factory=dict) + winner: str = "" + significance: float = 0.0 + sample_size: int = 0 + + +@dataclass +class ModelComparisonResult: + """Result of comparing multiple models on the same test cases.""" + + model_scores: dict[str, dict[str, float]] = field(default_factory=dict) + best_model: str = "" + test_cases_evaluated: int = 0 + + +class EinsteinEvaluator: + """ + Evaluate Agentforce LLM responses using LayerLens graders. + + Operates on imported session data (from + :py:meth:`AgentForceAdapter.import_sessions`) and applies graders to + LLM execution steps, action sequences, and agent responses. + + Usage:: + + evaluator = EinsteinEvaluator(adapter=adapter) + results = evaluator.evaluate_completions( + session_ids=["0Xx..."], + graders=["relevance", "faithfulness"], + ) + """ + + def __init__( + self, + adapter: Any = None, + connection: Any = None, + ) -> None: + """ + Initialize the evaluator. + + Args: + adapter: AgentForceAdapter instance (for session import). + connection: SalesforceConnection (for ground truth queries). + """ + self._adapter = adapter + self._connection = connection + self._client = _get_stratix_client() + + def evaluate_completions( + self, + session_ids: list[str], + graders: list[str] | None = None, + ) -> list[EvaluationResult]: + """ + Evaluate LLM completions from imported Agentforce sessions. + + Extracts LLM execution steps from the session data and runs + the specified graders on each completion. + + Args: + session_ids: Salesforce session IDs to evaluate. + graders: List of grader names (defaults to relevance + faithfulness). + + Returns: + List of EvaluationResult, one per session. + """ + if not session_ids: + return [] + + grader_names = graders or _DEFAULT_GRADERS + results: list[EvaluationResult] = [] + + for session_id in session_ids: + try: + scores = self._evaluate_session(session_id, grader_names) + composite = self._compute_composite_score(scores) + results.append( + EvaluationResult( + session_id=session_id, + scores=scores, + composite_score=composite, + ) + ) + except Exception as e: + logger.warning("Failed to evaluate session %s: %s", session_id, e) + results.append( + EvaluationResult( + session_id=session_id, + errors=[str(e)], + ) + ) + + return results + + def evaluate_topic( + self, + topic: str, # noqa: ARG002 — reserved for topic-filtered import wiring + graders: list[str] | None = None, + limit: int = 100, + ) -> list[EvaluationResult]: + """ + Convenience method: import sessions for a topic and evaluate. + + Combines session import + grading in one call. + + Args: + topic: Agentforce topic name to evaluate. + graders: Grader names to run. + limit: Maximum sessions to evaluate. + + Returns: + List of EvaluationResult for the topic. + """ + if not self._adapter: + raise RuntimeError("Adapter required for evaluate_topic()") + + # Import sessions that match the topic + events, result = self._adapter._importer.import_sessions(limit=limit) + + # Extract session IDs from imported events + session_ids: list[str] = [] + for event in events: + payload = event.get("payload", {}) + sid = payload.get("session_id") + if sid and sid not in session_ids: + session_ids.append(sid) + + return self.evaluate_completions(session_ids[:limit], graders) + + def ab_test_prompts( + self, + topic: str, # noqa: ARG002 — annotates which Agentforce topic is under test + variant_a: str, + variant_b: str, + test_cases: list[dict[str, str]] | None = None, + graders: list[str] | None = None, + ) -> ABTestResult: + """ + A/B test two prompt variants for an Agentforce topic. + + Args: + topic: The Agentforce topic being tested. + variant_a: First prompt instruction text. + variant_b: Second prompt instruction text. + test_cases: List of test inputs (dicts with "input" key). + graders: Grader names to use for scoring. + + Returns: + ABTestResult with per-variant scores and winner. + """ + grader_names = graders or ["relevance", "trajectory_accuracy"] + cases = test_cases or [] + sample_size = len(cases) + + # Score each variant + a_scores = self._score_variant(variant_a, cases, grader_names) + b_scores = self._score_variant(variant_b, cases, grader_names) + + # Determine winner by average score across graders + a_avg = sum(a_scores.values()) / max(len(a_scores), 1) + b_avg = sum(b_scores.values()) / max(len(b_scores), 1) + winner = "variant_a" if a_avg >= b_avg else "variant_b" + + return ABTestResult( + variant_a_scores=a_scores, + variant_b_scores=b_scores, + winner=winner, + significance=abs(a_avg - b_avg), + sample_size=sample_size, + ) + + def compare_models( + self, + topic: str, + models: list[str], + test_cases: list[dict[str, str]] | None = None, + graders: list[str] | None = None, + ) -> ModelComparisonResult: + """ + Compare multiple LLM models for an Agentforce topic. + + Args: + topic: The Agentforce topic to evaluate. + models: List of model names (e.g., ["gpt-5.3", "claude-opus-4-6"]). + test_cases: Test inputs for evaluation. + graders: Grader names to use. + + Returns: + ModelComparisonResult with per-model scores and best model. + """ + grader_names = graders or _DEFAULT_GRADERS + cases = test_cases or [] + model_scores: dict[str, dict[str, float]] = {} + + for model in models: + scores = self._score_model(model, topic, cases, grader_names) + model_scores[model] = scores + + # Determine best model by highest average score + best_model = "" + best_avg = -1.0 + for model, scores in model_scores.items(): + avg = sum(scores.values()) / max(len(scores), 1) + if avg > best_avg: + best_avg = avg + best_model = model + + return ModelComparisonResult( + model_scores=model_scores, + best_model=best_model, + test_cases_evaluated=len(cases), + ) + + def correlate_outcomes( + self, + session_ids: list[str], + outcome_query: str, + evaluation_dimensions: list[str] | None = None, + ) -> list[EvaluationResult]: + """ + Correlate evaluation scores with CRM business outcomes. + + Args: + session_ids: Session IDs to evaluate and correlate. + outcome_query: SOQL query to fetch business outcomes. + evaluation_dimensions: Grader dimensions to include. + + Returns: + EvaluationResult list with ground_truth populated. + """ + dimensions = evaluation_dimensions or _DEFAULT_GRADERS + + # Evaluate sessions + results = self.evaluate_completions(session_ids, dimensions) + + # Fetch ground truth from Salesforce + if self._connection: + try: + outcomes = self._connection.query(outcome_query) + outcome_map = {r.get("CaseId", r.get("Id", "")): r for r in outcomes} + for result in results: + gt = outcome_map.get(result.session_id, {}) + if gt: + result.ground_truth = gt + except Exception as e: + logger.warning("Failed to fetch ground truth: %s", e) + + return results + + # --- Internal helpers --- + + def _evaluate_session( + self, + session_id: str, + grader_names: list[str], + ) -> dict[str, float]: + """Run graders on a single session. Returns grader->score mapping.""" + if self._client: + try: + result = self._client.evaluations.create( + trace_id=session_id, + grader_ids=grader_names, + ) + # result may be a dict or model; normalise to dict + result_dict = result if isinstance(result, dict) else result.model_dump() + return {g: result_dict.get("scores", {}).get(g, 0.0) for g in grader_names} + except Exception as exc: + logger.warning( + "Grader invocation failed for session %s: %s", + session_id, + exc, + ) + + logger.warning( + "No LayerLens client configured — returning 0.0 for session %s. " + "Set LAYERLENS_API_URL and LAYERLENS_API_KEY environment variables.", + session_id, + ) + return dict.fromkeys(grader_names, 0.0) + + def _compute_composite_score( + self, + scores: dict[str, float], + ) -> float | None: + """Compute a weighted composite score from individual grader scores.""" + if not scores: + return None + + total_weight = 0.0 + weighted_sum = 0.0 + + # Map grader names to weight categories + grader_to_category = { + "topic_accuracy": "topic_accuracy", + "tool_correctness": "action_correctness", + "tool_adherence": "action_correctness", + "relevance": "response_quality", + "faithfulness": "response_quality", + "coherence": "response_quality", + "safety": "safety_compliance", + "hallucination": "safety_compliance", + "pii_detection": "safety_compliance", + } + + for grader, score in scores.items(): + category = grader_to_category.get(grader, "response_quality") + weight = _DEFAULT_WEIGHTS.get(category, 0.1) + weighted_sum += score * weight + total_weight += weight + + return weighted_sum / total_weight if total_weight > 0 else None + + def _score_variant( + self, + prompt: str, + test_cases: list[dict[str, str]], + grader_names: list[str], + ) -> dict[str, float]: + """Score a prompt variant across test cases.""" + if not test_cases: + logger.warning("No test cases provided for variant scoring — returning 0.0.") + return dict.fromkeys(grader_names, 0.0) + + if self._client: + try: + aggregated: dict[str, float] = dict.fromkeys(grader_names, 0.0) + for case in test_cases: + result = self._client.evaluations.create( + trace_id=case.get("trace_id", ""), + grader_ids=grader_names, + config={"prompt_override": prompt}, + ) + result_dict = result if isinstance(result, dict) else result.model_dump() + for g in grader_names: + aggregated[g] += result_dict.get("scores", {}).get(g, 0.0) + n = len(test_cases) + return {g: aggregated[g] / n for g in grader_names} + except Exception as exc: + logger.warning("Variant scoring failed: %s", exc) + + logger.warning( + "No LayerLens client configured — returning 0.0 for variant scoring. " + "Set LAYERLENS_API_URL and LAYERLENS_API_KEY environment variables." + ) + return dict.fromkeys(grader_names, 0.0) + + def _score_model( + self, + model: str, + topic: str, # noqa: ARG002 — annotates which Agentforce topic is being scored + test_cases: list[dict[str, str]], + grader_names: list[str], + ) -> dict[str, float]: + """Score a model on test cases.""" + if not test_cases: + logger.warning("No test cases provided for model %s — returning 0.0.", model) + return dict.fromkeys(grader_names, 0.0) + + if self._client: + try: + aggregated: dict[str, float] = dict.fromkeys(grader_names, 0.0) + for case in test_cases: + result = self._client.evaluations.create( + trace_id=case.get("trace_id", ""), + grader_ids=grader_names, + config={"model_override": model}, + ) + result_dict = result if isinstance(result, dict) else result.model_dump() + for g in grader_names: + aggregated[g] += result_dict.get("scores", {}).get(g, 0.0) + n = len(test_cases) + return {g: aggregated[g] / n for g in grader_names} + except Exception as exc: + logger.warning("Model %s scoring failed: %s", model, exc) + + logger.warning( + "No LayerLens client configured — returning 0.0 for model %s. " + "Set LAYERLENS_API_URL and LAYERLENS_API_KEY environment variables.", + model, + ) + return dict.fromkeys(grader_names, 0.0) diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/mapper.py b/src/layerlens/instrument/adapters/frameworks/agentforce/mapper.py new file mode 100644 index 0000000..fa55577 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/mapper.py @@ -0,0 +1,251 @@ +""" +Agent API Session to Stratix Trace Event Mapper + +Maps Agent API session data (from ``client.py``) to Stratix canonical +event types. This is distinct from ``normalizer.py`` which handles +Data Cloud DMO records from SOQL queries. + +Mapping: +- Session creation -> agent.state.change (trace_start) +- User message -> agent.input (L1) +- Agent response -> agent.output (L1) +- Topic classification-> environment.config (L4a) +- Action invocation -> tool.call (L5a) +- Guardrail check -> policy.violation (Cross) +- Escalation -> agent.handoff (Cross) +- Session end -> agent.state.change (trace_end) +""" + +from __future__ import annotations + +import time +import logging +from typing import Any + +from layerlens.instrument.adapters.frameworks.agentforce.models import ( + AgentApiMessage, + AgentApiSession, +) + +logger = logging.getLogger(__name__) + + +class AgentApiMapper: + """ + Maps Agent API sessions to Stratix trace events. + + Each public method returns a list of event dicts compatible with + ``BaseAdapter.emit_dict_event(event_type, payload)``. + """ + + def map_session(self, session: AgentApiSession) -> list[dict[str, Any]]: + """ + Map a complete Agent API session to a sequence of Stratix events. + + Args: + session: Complete AgentApiSession with messages. + + Returns: + Ordered list of ``{event_type, payload}`` dicts. + """ + events: list[dict[str, Any]] = [] + + # Session start + events.append(self.map_session_start(session)) + + # Process each message + seen_topics: set[str] = set() + for msg in session.messages: + if msg.role == "user": + events.append(self.map_user_message(msg, session.session_id)) + elif msg.role == "agent": + events.append(self.map_agent_response(msg, session.session_id)) + + # Topic classification (emit once per topic) + if msg.topic and msg.topic not in seen_topics: + events.append( + self.map_topic_classification( + msg.topic, + session.agent_name or "unknown", + session.session_id, + ) + ) + seen_topics.add(msg.topic) + + # Action invocations + for action in msg.actions: + events.append( + self.map_action_invocation( + action, + session.session_id, + ) + ) + + # Guardrail checks + for gr in msg.guardrail_results: + events.append( + self.map_guardrail_check( + gr, + session.session_id, + ) + ) + + # Session end + events.append(self.map_session_end(session)) + + return events + + def map_session_start(self, session: AgentApiSession) -> dict[str, Any]: + """Map session creation to agent.state.change (trace_start).""" + return { + "event_type": "agent.state.change", + "payload": { + "framework": "salesforce_agentforce", + "event_subtype": "trace_start", + "session_id": session.session_id, + "agent_name": session.agent_name, + "timestamp_ns": _ts_to_ns(session.created_at), + }, + } + + def map_session_end(self, session: AgentApiSession) -> dict[str, Any]: + """Map session end to agent.state.change (trace_end).""" + start_ns = _ts_to_ns(session.created_at) + end_ns = _ts_to_ns(session.ended_at) + duration_ns = end_ns - start_ns if start_ns and end_ns else 0 + + return { + "event_type": "agent.state.change", + "payload": { + "framework": "salesforce_agentforce", + "event_subtype": "trace_end", + "session_id": session.session_id, + "agent_name": session.agent_name, + "duration_ns": duration_ns, + "message_count": len(session.messages), + }, + } + + @staticmethod + def map_user_message( + msg: AgentApiMessage, + session_id: str, + ) -> dict[str, Any]: + """Map a user message to agent.input (L1).""" + return { + "event_type": "agent.input", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + "content": { + "role": "human", + "message": msg.content, + }, + "timestamp_ns": _ts_to_ns(msg.timestamp), + }, + } + + @staticmethod + def map_agent_response( + msg: AgentApiMessage, + session_id: str, + ) -> dict[str, Any]: + """Map an agent response to agent.output (L1).""" + return { + "event_type": "agent.output", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + "content": { + "role": "agent", + "message": msg.content, + }, + "timestamp_ns": _ts_to_ns(msg.timestamp), + }, + } + + @staticmethod + def map_topic_classification( + topic: str, + agent_name: str, + session_id: str, + ) -> dict[str, Any]: + """Map topic classification to environment.config (L4a).""" + return { + "event_type": "environment.config", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + "agent_name": agent_name, + "topic": topic, + "config_type": "topic_classification", + }, + } + + @staticmethod + def map_action_invocation( + action: dict[str, Any], + session_id: str, + ) -> dict[str, Any]: + """Map an Agentforce action to tool.call (L5a).""" + return { + "event_type": "tool.call", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + "tool_name": action.get("name", "unknown"), + "tool_input": action.get("parameters", {}), + "tool_output": action.get("result"), + "tool_type": "salesforce_action", + }, + } + + @staticmethod + def map_guardrail_check( + guardrail: dict[str, Any], + session_id: str, + ) -> dict[str, Any]: + """Map a guardrail check to policy.violation (Cross-cutting).""" + return { + "event_type": "policy.violation", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + "guardrail_name": guardrail.get("name", "unknown"), + "triggered": guardrail.get("triggered", False), + "message": guardrail.get("message"), + "source": "einstein_trust_layer", + }, + } + + @staticmethod + def map_escalation( + session_id: str, + from_agent: str, + to_agent: str = "human", + reason: str = "escalation", + ) -> dict[str, Any]: + """Map an escalation to agent.handoff (Cross-cutting).""" + return { + "event_type": "agent.handoff", + "payload": { + "from_agent": from_agent, + "to_agent": to_agent, + "reason": reason, + "framework": "salesforce_agentforce", + "session_id": session_id, + }, + } + + +def _ts_to_ns(ts: str | None) -> int: + """Convert an ISO 8601 timestamp string to nanoseconds since epoch.""" + if not ts: + return time.time_ns() + try: + from datetime import datetime + + dt = datetime.fromisoformat(ts.replace("Z", "+00:00")) + return int(dt.timestamp() * 1_000_000_000) + except (ValueError, TypeError): + return time.time_ns() diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/models.py b/src/layerlens/instrument/adapters/frameworks/agentforce/models.py new file mode 100644 index 0000000..dab4205 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/models.py @@ -0,0 +1,322 @@ +""" +Pydantic models for Salesforce Agentforce data structures. + +Provides type-safe representations of: +- Salesforce DMO records (AIAgentSession, AIAgentInteractionStep, etc.) +- Agent API request/response payloads +- Platform Event payloads +- Trust Layer configuration +- LLM evaluation inputs/outputs +""" + +from __future__ import annotations + +from enum import Enum # Python 3.11+ has StrEnum; using `(str, Enum)` for 3.9/3.10 compat. +from typing import Any, Optional + +from pydantic import Field, BaseModel + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class AgentChannelType(str, Enum): + """Agentforce session channel types.""" + + WEB = "Web" + MESSAGING = "Messaging" + VOICE = "Voice" + SLACK = "Slack" + API = "Api" + + +class SessionEndType(str, Enum): + """How an Agentforce session ended.""" + + COMPLETED = "Completed" + ESCALATED = "Escalated" + TIMED_OUT = "TimedOut" + ERROR = "Error" + ABANDONED = "Abandoned" + + +class StepType(str, Enum): + """Agentforce interaction step types from DMO.""" + + USER_INPUT = "UserInputStep" + LLM_EXECUTION = "LLMExecutionStep" + FUNCTION = "FunctionStep" + ACTION_INVOCATION = "ActionInvocationStep" + + +class ParticipantType(str, Enum): + """Participant roles in an Agentforce session.""" + + AI = "ai" + HUMAN = "human" + + +class AuthFlow(str, Enum): + """Supported Salesforce authentication flows.""" + + JWT_BEARER = "jwt_bearer" + CLIENT_CREDENTIALS = "client_credentials" + NAMED_CREDENTIAL = "named_credential" + + +class CaptureMode(str, Enum): + """Agentforce capture modes.""" + + POLLING = "polling" + REALTIME = "realtime" + HYBRID = "hybrid" + + +# --------------------------------------------------------------------------- +# DMO Record Models +# --------------------------------------------------------------------------- + + +class AgentSession(BaseModel): + """AIAgentSession DMO record.""" + + id: str = Field(alias="Id", description="Salesforce record ID") + start_timestamp: Optional[str] = Field( + default=None, + alias="StartTimestamp", + description="ISO 8601 session start time", + ) + end_timestamp: Optional[str] = Field( + default=None, + alias="EndTimestamp", + description="ISO 8601 session end time", + ) + channel_type: Optional[str] = Field( + default=None, + alias="AiAgentChannelTypeId", + description="Session channel (Web, Messaging, Voice, etc.)", + ) + session_end_type: Optional[str] = Field( + default=None, + alias="AiAgentSessionEndType", + description="How the session ended", + ) + voice_call_id: Optional[str] = Field(default=None, alias="VoiceCallId") + messaging_session_id: Optional[str] = Field(default=None, alias="MessagingSessionId") + previous_session_id: Optional[str] = Field(default=None, alias="PreviousSessionId") + + model_config = {"populate_by_name": True} + + +class AgentParticipant(BaseModel): + """AIAgentSessionParticipant DMO record.""" + + id: str = Field(alias="Id") + session_id: str = Field(alias="AiAgentSessionId") + agent_type: Optional[str] = Field(default=None, alias="AiAgentTypeId") + agent_api_name: Optional[str] = Field(default=None, alias="AiAgentApiName") + agent_version: Optional[str] = Field(default=None, alias="AiAgentVersionApiName") + participant_id: Optional[str] = Field(default=None, alias="ParticipantId") + role: Optional[str] = Field(default=None, alias="AiAgentSessionParticipantRoleId") + + model_config = {"populate_by_name": True} + + +class AgentInteraction(BaseModel): + """AIAgentInteraction DMO record.""" + + id: str = Field(alias="Id") + session_id: str = Field(alias="AiAgentSessionId") + interaction_type: Optional[str] = Field(default=None, alias="AiAgentInteractionTypeId") + telemetry_trace_id: Optional[str] = Field(default=None, alias="TelemetryTraceId") + telemetry_span_id: Optional[str] = Field(default=None, alias="TelemetryTraceSpanId") + topic_api_name: Optional[str] = Field(default=None, alias="TopicApiName") + attribute_text: Optional[str] = Field(default=None, alias="AttributeText") + prev_interaction_id: Optional[str] = Field(default=None, alias="PrevInteractionId") + + model_config = {"populate_by_name": True} + + +class AgentInteractionStep(BaseModel): + """AIAgentInteractionStep DMO record.""" + + id: str = Field(alias="Id") + interaction_id: str = Field(alias="AiAgentInteractionId") + step_type: Optional[str] = Field(default=None, alias="AiAgentInteractionStepTypeId") + input_value: Optional[str] = Field(default=None, alias="InputValueText") + output_value: Optional[str] = Field(default=None, alias="OutputValueText") + error_message: Optional[str] = Field(default=None, alias="ErrorMessageText") + generation_id: Optional[str] = Field(default=None, alias="GenerationId") + gateway_request_id: Optional[str] = Field(default=None, alias="GenAiGatewayRequestId") + gateway_response_id: Optional[str] = Field(default=None, alias="GenAiGatewayResponseId") + name: Optional[str] = Field(default=None, alias="Name") + telemetry_span_id: Optional[str] = Field(default=None, alias="TelemetryTraceSpanId") + start_timestamp: Optional[str] = Field(default=None, alias="StartTimestamp") + end_timestamp: Optional[str] = Field(default=None, alias="EndTimestamp") + + model_config = {"populate_by_name": True} + + +class AgentInteractionMessage(BaseModel): + """AIAgentInteractionMessage DMO record.""" + + id: str = Field(alias="Id") + interaction_id: str = Field(alias="AiAgentInteractionId") + message_type: Optional[str] = Field(default=None, alias="AiAgentInteractionMessageTypeId") + content_text: Optional[str] = Field(default=None, alias="ContentText") + content_type: Optional[str] = Field(default=None, alias="AiAgentInteractionMsgContentTypeId") + sent_timestamp: Optional[str] = Field(default=None, alias="MessageSentTimestamp") + parent_message_id: Optional[str] = Field(default=None, alias="ParentMessageId") + + model_config = {"populate_by_name": True} + + +# --------------------------------------------------------------------------- +# Agent API Models +# --------------------------------------------------------------------------- + + +class AgentApiMessage(BaseModel): + """A message in an Agent API session.""" + + id: Optional[str] = Field(default=None, description="Message ID") + role: str = Field(description="Message role (user, agent, system)") + content: str = Field(description="Message content text") + timestamp: Optional[str] = Field(default=None, description="ISO 8601 timestamp") + topic: Optional[str] = Field(default=None, description="Classified topic name") + actions: list[dict[str, Any]] = Field( + default_factory=list, + description="Actions taken by the agent", + ) + guardrail_results: list[dict[str, Any]] = Field( + default_factory=list, + description="Trust Layer guardrail check results", + ) + + +class AgentApiSession(BaseModel): + """Represents an Agent API session.""" + + session_id: str = Field(description="Salesforce session ID") + agent_name: Optional[str] = Field(default=None, description="Agentforce agent name") + status: str = Field(default="active", description="Session status") + messages: list[AgentApiMessage] = Field( + default_factory=list, + description="Session messages in order", + ) + created_at: Optional[str] = Field(default=None, description="Session creation timestamp") + ended_at: Optional[str] = Field(default=None, description="Session end timestamp") + + +# --------------------------------------------------------------------------- +# Trust Layer Models +# --------------------------------------------------------------------------- + + +class TrustLayerGuardrail(BaseModel): + """Einstein Trust Layer guardrail configuration.""" + + name: str = Field(description="Guardrail name") + type: str = Field(description="Guardrail type (toxicity, pii, custom)") + enabled: bool = Field(default=True, description="Whether the guardrail is active") + threshold: Optional[float] = Field( + default=None, + description="Detection threshold (0.0-1.0)", + ) + action: str = Field( + default="block", + description="Action on violation (block, warn, log)", + ) + + +class TrustLayerConfig(BaseModel): + """Complete Einstein Trust Layer configuration.""" + + guardrails: list[TrustLayerGuardrail] = Field( + default_factory=list, + description="Configured guardrails", + ) + data_masking_enabled: bool = Field( + default=False, + description="Whether PII/PCI masking is active", + ) + zero_data_retention: bool = Field( + default=True, + description="Whether zero data retention is enabled for LLM calls", + ) + audit_trail_enabled: bool = Field( + default=True, + description="Whether audit trail logging is active", + ) + + +# --------------------------------------------------------------------------- +# Platform Event Models +# --------------------------------------------------------------------------- + + +class AgentSessionEvent(BaseModel): + """Platform Event payload for AgentSession__e.""" + + session_id: str = Field(description="Agentforce session ID") + agent_name: Optional[str] = Field(default=None, description="Agent name") + topic_name: Optional[str] = Field(default=None, description="Classified topic") + actions_taken: Optional[str] = Field( + default=None, + description="JSON-encoded actions list", + ) + response_text: Optional[str] = Field( + default=None, + description="Agent response text", + ) + trust_layer_flags: Optional[str] = Field( + default=None, + description="JSON-encoded Trust Layer results", + ) + replay_id: Optional[str] = Field( + default=None, + description="Platform Event replay ID for redelivery", + ) + + +# --------------------------------------------------------------------------- +# Evaluation Models +# --------------------------------------------------------------------------- + + +class EvaluationRequest(BaseModel): + """Request to evaluate Agentforce sessions.""" + + session_ids: list[str] = Field(description="Salesforce session IDs to evaluate") + graders: list[str] = Field( + default_factory=lambda: ["relevance", "faithfulness"], + description="Grader names to run", + ) + include_ground_truth: bool = Field( + default=False, + description="Whether to fetch CRM outcome ground truth", + ) + ground_truth_query: Optional[str] = Field( + default=None, + description="SOQL query for ground truth data", + ) + + +class EvaluationResult(BaseModel): + """Result of evaluating Agentforce sessions.""" + + session_id: str = Field(description="Evaluated session ID") + scores: dict[str, float] = Field( + default_factory=dict, + description="Grader name -> score mapping", + ) + composite_score: Optional[float] = Field( + default=None, + description="Weighted composite quality score", + ) + ground_truth: dict[str, Any] = Field( + default_factory=dict, + description="CRM outcome data if fetched", + ) + errors: list[str] = Field(default_factory=list, description="Evaluation errors") diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/normalizer.py b/src/layerlens/instrument/adapters/frameworks/agentforce/normalizer.py new file mode 100644 index 0000000..a5553c8 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/normalizer.py @@ -0,0 +1,251 @@ +""" +AgentForce DMO to LayerLens Event Normalizer + +Maps AgentForce Data Model Objects to LayerLens canonical event types: +- AIAgentSession → agent.lifecycle (start/end) +- AIAgentSessionParticipant → agent.identity +- AIAgentInteraction → agent.input / agent.output +- AIAgentInteractionStep (UserInputStep) → agent.input (L1) +- AIAgentInteractionStep (LLMExecutionStep) → model.invoke (L3) +- AIAgentInteractionStep (FunctionStep / ActionInvocationStep) → tool.call (L5) +- AIAgentInteractionMessage (Input) → agent.input +- AIAgentInteractionMessage (Output) → agent.output +""" + +from __future__ import annotations + +import json +import logging +from typing import Any +from datetime import datetime + +logger = logging.getLogger(__name__) + +# Step type to LayerLens event type mapping +_STEP_TYPE_MAP = { + "UserInputStep": "agent.input", + "LLMExecutionStep": "model.invoke", + "FunctionStep": "tool.call", + "ActionInvocationStep": "tool.call", +} + + +class AgentForceNormalizer: + """Normalize AgentForce DMO records to LayerLens events.""" + + def normalize_session( + self, + session: dict[str, Any], + ) -> list[dict[str, Any]]: + """Normalize an AIAgentSession to agent.lifecycle start/end events.""" + events = [] + + sf_meta = { + "sf.session.id": session.get("Id"), + "sf.session.channel": session.get("AiAgentChannelTypeId"), + "sf.session.end_type": session.get("AiAgentSessionEndType"), + } + + # Start event + events.append( + { + "event_type": "agent.lifecycle", + "payload": { + "lifecycle_action": "start", + "session_id": session.get("Id"), + "channel_type": session.get("AiAgentChannelTypeId"), + "previous_session_id": session.get("PreviousSessionId"), + "voice_call_id": session.get("VoiceCallId"), + "messaging_session_id": session.get("MessagingSessionId"), + }, + "metadata": sf_meta, + "timestamp": session.get("StartTimestamp"), + } + ) + + # End event (if session has ended) + end_ts = session.get("EndTimestamp") + if end_ts: + events.append( + { + "event_type": "agent.lifecycle", + "payload": { + "lifecycle_action": "end", + "session_id": session.get("Id"), + "session_end_type": session.get("AiAgentSessionEndType"), + "channel_type": session.get("AiAgentChannelTypeId"), + }, + "metadata": sf_meta, + "timestamp": end_ts, + } + ) + + return events + + def normalize_participant( + self, + participant: dict[str, Any], + ) -> dict[str, Any]: + """Normalize an AIAgentSessionParticipant to agent identity metadata.""" + agent_type = participant.get("AiAgentTypeId", "") + is_human = agent_type == "Employee" + + return { + "event_type": "agent.identity", + "payload": { + "participant_type": "human" if is_human else "ai", + "agent_type": agent_type, + "agent_api_name": participant.get("AiAgentApiName"), + "agent_version": participant.get("AiAgentVersionApiName"), + "participant_id": participant.get("ParticipantId"), + "role": participant.get("AiAgentSessionParticipantRoleId"), + "session_id": participant.get("AiAgentSessionId"), + }, + } + + def normalize_interaction( + self, + interaction: dict[str, Any], + ) -> dict[str, Any]: + """Normalize an AIAgentInteraction to a trace span.""" + # Parse AttributeText as JSON if present + attr_text = interaction.get("AttributeText") + attributes = {} + if attr_text: + try: + attributes = json.loads(attr_text) + except (json.JSONDecodeError, TypeError): + attributes = {"raw": attr_text} + + return { + "event_type": "agent.interaction", + "identity": { + "trace_id": interaction.get("TelemetryTraceId"), + "span_id": interaction.get("TelemetryTraceSpanId"), + }, + "payload": { + "interaction_id": interaction.get("Id"), + "interaction_type": interaction.get("AiAgentInteractionTypeId"), + "topic": interaction.get("TopicApiName"), + "attributes": attributes, + "prev_interaction_id": interaction.get("PrevInteractionId"), + "session_id": interaction.get("AiAgentSessionId"), + }, + "metadata": { + "sf.topic.name": interaction.get("TopicApiName"), + "sf.session.id": interaction.get("AiAgentSessionId"), + }, + } + + def normalize_step( + self, + step: dict[str, Any], + ) -> dict[str, Any]: + """Normalize an AIAgentInteractionStep to the appropriate LayerLens event.""" + step_type = step.get("AiAgentInteractionStepTypeId", "") + event_type = _STEP_TYPE_MAP.get(step_type, "tool.call") + + base: dict[str, Any] = { + "event_type": event_type, + "identity": { + "span_id": step.get("TelemetryTraceSpanId"), + }, + } + + # Salesforce metadata passthrough + base["metadata"] = { + "sf.step.name": step.get("Name"), + "sf.step.id": step.get("Id"), + "sf.generation.id": step.get("GenerationId"), + } + + # Extract timing if available + start_ts = step.get("StartTimestamp") + end_ts = step.get("EndTimestamp") + if start_ts: + base["timestamp"] = start_ts + if start_ts and end_ts: + try: + start_dt = datetime.fromisoformat(str(start_ts).replace("Z", "+00:00")) + end_dt = datetime.fromisoformat(str(end_ts).replace("Z", "+00:00")) + base["duration_ms"] = (end_dt - start_dt).total_seconds() * 1000 + except (ValueError, TypeError): + pass + + if event_type == "model.invoke": + base["payload"] = { + "model": { + "provider": "salesforce", + "name": step.get("Name", "unknown"), + "version": "unavailable", + "parameters": {}, + }, + "input_messages": [{"role": "user", "content": step.get("InputValueText", "")}], + "output_message": {"role": "assistant", "content": step.get("OutputValueText", "")}, + "error": step.get("ErrorMessageText"), + "metadata": { + "generation_id": step.get("GenerationId"), + "gateway_request_id": step.get("GenAiGatewayRequestId"), + "gateway_response_id": step.get("GenAiGatewayResponseId"), + }, + } + + elif event_type == "tool.call": + input_text = step.get("InputValueText", "") + output_text = step.get("OutputValueText") + + base["payload"] = { + "tool": { + "name": step.get("Name", "unknown"), + "version": "unavailable", + "integration": "salesforce_agentforce", + }, + "input": _try_parse_json(input_text), + "output": _try_parse_json(output_text) if output_text else None, + "error": step.get("ErrorMessageText"), + } + + else: # agent.input + base["payload"] = { + "content": { + "role": "human", + "message": step.get("InputValueText", ""), + }, + } + + return base + + def normalize_message( + self, + message: dict[str, Any], + ) -> dict[str, Any]: + """Normalize an AIAgentInteractionMessage to agent.input or agent.output.""" + msg_type = message.get("AiAgentInteractionMessageTypeId", "") + event_type = "agent.output" if msg_type == "Output" else "agent.input" + role = "agent" if msg_type == "Output" else "human" + + return { + "event_type": event_type, + "payload": { + "content": { + "role": role, + "message": message.get("ContentText", ""), + "metadata": { + "content_type": message.get("AiAgentInteractionMsgContentTypeId"), + "parent_message_id": message.get("ParentMessageId"), + }, + }, + }, + "timestamp": message.get("MessageSentTimestamp"), + } + + +def _try_parse_json(text: str) -> dict[str, Any]: + """Try to parse text as JSON, falling back to raw string wrapper.""" + if not text: + return {} + try: + result = json.loads(text) + return result if isinstance(result, dict) else {"raw": text} + except (json.JSONDecodeError, TypeError): + return {"raw": text} diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/trust_layer.py b/src/layerlens/instrument/adapters/frameworks/agentforce/trust_layer.py new file mode 100644 index 0000000..5cc654c --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/trust_layer.py @@ -0,0 +1,228 @@ +""" +Einstein Trust Layer Policy Importer. + +Imports Einstein Trust Layer guardrail configuration from Salesforce +and converts it to LayerLens policy-as-code YAML format. + +Supports: + +- Guardrail rules extraction via Metadata API +- Data masking policy import +- Conversion to LayerLens YAML policy DSL +- Round-trip: import, evaluate, export updated policies + +Reference: https://developer.salesforce.com/docs/einstein/genai/guide/trust.html +""" + +from __future__ import annotations + +import logging +import warnings + +from layerlens.instrument.adapters.frameworks.agentforce.auth import SalesforceConnection +from layerlens.instrument.adapters.frameworks.agentforce.models import ( + TrustLayerConfig, + TrustLayerGuardrail, +) + +logger = logging.getLogger(__name__) + +# Metadata API types for Trust Layer config +_TRUST_LAYER_METADATA_TYPES = [ + "GenAiPlugin", + "GenAiFunction", +] + +# Default guardrail names in Einstein Trust Layer +_DEFAULT_GUARDRAILS = [ + "toxicity_detection", + "pii_detection", + "prompt_injection", + "hallucination_detection", +] + + +class TrustLayerImporter: + """ + Import Einstein Trust Layer configuration and convert to Stratix policy. + + Usage:: + + importer = TrustLayerImporter(connection=connection) + config = importer.fetch_config() + yaml_str = importer.to_layerlens_policy(config) + """ + + def __init__(self, connection: SalesforceConnection) -> None: + self._connection = connection + + def fetch_config(self) -> TrustLayerConfig: + """ + Fetch Einstein Trust Layer configuration from Salesforce. + + Queries the Setup metadata to extract guardrail rules, + data masking settings, and audit configuration. + + Returns: + TrustLayerConfig with all detected guardrails and settings. + """ + guardrails: list[TrustLayerGuardrail] = [] + + # Query for GenAI guardrail metadata + try: + records = self._connection.query( + "SELECT DeveloperName, Language, MasterLabel " + "FROM GenAiPlugin " + "WHERE IsDeleted = false " + "ORDER BY DeveloperName ASC" + ) + for record in records: + name = record.get("DeveloperName", "") + if name: + guardrails.append( + TrustLayerGuardrail( + name=name, + type=self._classify_guardrail(name), + enabled=True, + ) + ) + except Exception as e: + logger.warning("Failed to query GenAiPlugin metadata: %s", e) + + # Add default guardrails if none found (Trust Layer has built-in ones) + if not guardrails: + for name in _DEFAULT_GUARDRAILS: + guardrails.append( + TrustLayerGuardrail( + name=name, + type=self._classify_guardrail(name), + enabled=True, + ) + ) + + return TrustLayerConfig( + guardrails=guardrails, + data_masking_enabled=False, # Disabled for agents per SF docs + zero_data_retention=True, + audit_trail_enabled=True, + ) + + def to_layerlens_policy( + self, + config: TrustLayerConfig, + policy_name: str = "agentforce_trust_layer", + policy_version: str = "1.0.0", + ) -> str: + """ + Convert a TrustLayerConfig to LayerLens policy-as-code YAML. + + Args: + config: The Trust Layer configuration to convert. + policy_name: Name for the generated policy. + policy_version: Version string for the policy. + + Returns: + YAML string representing a LayerLens policy document. + """ + rules: list[str] = [] + + for guardrail in config.guardrails: + action = "block" if guardrail.action == "block" else "warn" + threshold = guardrail.threshold if guardrail.threshold is not None else 0.8 + + rule_yaml = ( + f" - name: {guardrail.name}\n" + f' description: "Imported from Einstein Trust Layer: {guardrail.type}"\n' + f" type: {guardrail.type}\n" + f" enabled: {str(guardrail.enabled).lower()}\n" + f" threshold: {threshold}\n" + f" action: {action}\n" + f" source: einstein_trust_layer" + ) + rules.append(rule_yaml) + + rules_block = "\n".join(rules) if rules else " []" + + yaml_output = ( + "# LayerLens Policy - Imported from Einstein Trust Layer\n" + "# Generated by: layerlens.instrument.adapters.frameworks.agentforce.trust_layer\n" + "# Source: Salesforce Einstein Trust Layer\n" + "\n" + "policy:\n" + f" name: {policy_name}\n" + f' version: "{policy_version}"\n' + ' description: "Policy imported from Salesforce Einstein Trust Layer"\n' + " source: salesforce_agentforce\n" + "\n" + "settings:\n" + f" data_masking: {str(config.data_masking_enabled).lower()}\n" + f" zero_data_retention: {str(config.zero_data_retention).lower()}\n" + f" audit_trail: {str(config.audit_trail_enabled).lower()}\n" + "\n" + "rules:\n" + f"{rules_block}\n" + ) + + return yaml_output + + def to_stratix_policy( + self, + config: TrustLayerConfig, + policy_name: str = "agentforce_trust_layer", + policy_version: str = "1.0.0", + ) -> str: + """Deprecated alias for :meth:`to_layerlens_policy`. + + Retained for compatibility with the legacy ``stratix.*`` adapter + package. New code should call :meth:`to_layerlens_policy` directly. + + Args: + config: The Trust Layer configuration to convert. + policy_name: Name for the generated policy. + policy_version: Version string for the policy. + + Returns: + YAML string representing a LayerLens policy document. + """ + warnings.warn( + "TrustLayerImporter.to_stratix_policy is deprecated; " + "use to_layerlens_policy instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.to_layerlens_policy( + config, + policy_name=policy_name, + policy_version=policy_version, + ) + + def import_and_convert( + self, + policy_name: str = "agentforce_trust_layer", + ) -> tuple[TrustLayerConfig, str]: + """ + Convenience method: fetch config and convert to YAML in one call. + + Args: + policy_name: Name for the generated policy. + + Returns: + Tuple of (TrustLayerConfig, YAML string). + """ + config = self.fetch_config() + yaml_str = self.to_layerlens_policy(config, policy_name=policy_name) + return config, yaml_str + + @staticmethod + def _classify_guardrail(name: str) -> str: + """Classify a guardrail name into a guardrail type.""" + name_lower = name.lower() + if "toxic" in name_lower or "harm" in name_lower: + return "toxicity" + if "pii" in name_lower or "mask" in name_lower or "privacy" in name_lower: + return "pii" + if "injection" in name_lower or "jailbreak" in name_lower: + return "prompt_injection" + if "hallucin" in name_lower or "ground" in name_lower: + return "hallucination" + return "custom" diff --git a/tests/instrument/adapters/__init__.py b/tests/instrument/adapters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/adapters/frameworks/__init__.py b/tests/instrument/adapters/frameworks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/adapters/frameworks/test_agentforce.py b/tests/instrument/adapters/frameworks/test_agentforce.py new file mode 100644 index 0000000..6941cee --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_agentforce.py @@ -0,0 +1,761 @@ +"""Unit tests for the Salesforce Agentforce framework adapter. + +Mocked at the SDK shape level — no real Salesforce API or ``requests`` +network call is made. Each test patches ``requests`` (the only +third-party SDK touched at the module boundary) so the adapter, importer, +mapper, normalizer, client, events, evaluator, and trust-layer importer +are all exercised end-to-end against fixture data. + +Coverage: + +* lifecycle (connect / disconnect / health_check / serialize_for_replay) +* SOQL importer with paginated query results + JSON-injection guard +* DMO normalizer for every record type (session, participant, interaction, + step×3 step-types, message) +* Agent API client (create / send / end / capture) +* Agent API mapper (start, user / agent message, topic, action, guardrail, + escalation, end) +* Trust Layer importer (config fetch, YAML emission, deprecation alias) +* Platform Events subscriber (handle_event + reconnect bookkeeping) +* Einstein evaluator (composite score weights + offline-without-client + behavior) +* Lazy-import + default-install guard (importing the package does NOT + pull in ``requests``) +""" + +from __future__ import annotations + +import sys +from typing import Any +from unittest import mock + +import pytest + +from layerlens.instrument.adapters._base import AdapterStatus, CaptureConfig +from layerlens.instrument.adapters.frameworks.agentforce import ( + ADAPTER_CLASS, + ImportResult, + AgentApiClient, + AgentApiMapper, + AgentForceAdapter, + EinsteinEvaluator, + AgentForceImporter, + NormalizationError, + TrustLayerImporter, + SalesforceAuthError, + AgentForceNormalizer, + SalesforceConnection, + SalesforceQueryError, + SalesforceCredentials, + PlatformEventSubscriber, +) +from layerlens.instrument.adapters.frameworks.agentforce.models import ( + AgentApiMessage, + AgentApiSession, + TrustLayerConfig, + AgentSessionEvent, + TrustLayerGuardrail, +) + +# --------------------------------------------------------------------------- +# Test fixtures +# --------------------------------------------------------------------------- + + +class _RecordingStratix: + """Minimal stub that records every event emission.""" + + def __init__(self) -> None: + self.events: list[dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + + +class _FakeResponse: + """``requests.Response`` shape with the bits the adapter touches.""" + + def __init__( + self, + json_data: Any = None, + status_code: int = 200, + headers: dict[str, str] | None = None, + text_lines: list[str] | None = None, + ) -> None: + self._json = json_data if json_data is not None else {} + self.status_code = status_code + self.headers = headers or {} + self._lines = text_lines or [] + self._closed = False + + def json(self) -> Any: + return self._json + + def raise_for_status(self) -> None: + if self.status_code >= 400: + import requests # type: ignore[import-untyped] + + err = requests.exceptions.HTTPError(f"HTTP {self.status_code}") + err.response = self + raise err + + def iter_lines(self, decode_unicode: bool = False) -> Any: # noqa: ARG002 + yield from self._lines + + def close(self) -> None: + self._closed = True + + +def _credentials() -> SalesforceCredentials: + creds = SalesforceCredentials( + client_id="3MVG9TestConnectedAppKey00000000", + username="agent-importer@example.com", + private_key="-----BEGIN PRIVATE KEY-----\nMIITestKey\n-----END PRIVATE KEY-----\n", + instance_url="https://example.my.salesforce.com", + ) + creds.access_token = "00DTEST!AQ.TOKEN" + creds.token_expiry = 9_999_999_999.0 # not expired + return creds + + +def _connection() -> SalesforceConnection: + conn = SalesforceConnection(credentials=_credentials()) + conn.instance_url = "https://example.my.salesforce.com" + return conn + + +# --------------------------------------------------------------------------- +# Lazy-import + package surface +# --------------------------------------------------------------------------- + + +def test_adapter_class_export_matches() -> None: + assert ADAPTER_CLASS is AgentForceAdapter + + +def test_package_reexports_full_public_api() -> None: + """All symbols in ``__all__`` are importable from the package root.""" + import layerlens.instrument.adapters.frameworks.agentforce as af + + for name in af.__all__: + assert hasattr(af, name), f"{name!r} declared in __all__ but missing" + + +def test_package_does_not_eagerly_import_requests() -> None: + """Importing the adapter package must not pull in ``requests``.""" + # Drop any prior import so the assertion measures the package itself. + for mod in list(sys.modules): + if mod == "requests" or mod.startswith("requests."): + del sys.modules[mod] + + # Re-import the package fresh. + for mod in list(sys.modules): + if mod.startswith("layerlens.instrument.adapters.frameworks.agentforce"): + del sys.modules[mod] + + import layerlens.instrument.adapters.frameworks.agentforce # noqa: F401 + + assert "requests" not in sys.modules, ( + "agentforce adapter must not import requests at module load time" + ) + + +# --------------------------------------------------------------------------- +# Adapter lifecycle +# --------------------------------------------------------------------------- + + +def test_connect_without_credentials_or_connection_raises() -> None: + adapter = AgentForceAdapter() + with pytest.raises(SalesforceAuthError): + adapter.connect() + + +def test_lifecycle_with_prebuilt_connection() -> None: + adapter = AgentForceAdapter(connection=_connection()) + adapter.connect() + assert adapter.is_connected is True + assert adapter.status == AdapterStatus.HEALTHY + + info = adapter.get_adapter_info() + assert info.framework == "salesforce_agentforce" + assert info.version == AgentForceAdapter.VERSION + + health = adapter.health_check() + assert health.framework_name == "salesforce_agentforce" + assert health.error_count == 0 + + rt = adapter.serialize_for_replay() + assert rt.framework == "salesforce_agentforce" + assert "capture_config" in rt.config + + adapter.disconnect() + assert adapter.is_connected is False + assert adapter.status == AdapterStatus.DISCONNECTED + + +def test_health_message_warns_when_token_expired() -> None: + creds = _credentials() + creds.token_expiry = 0.0 # expired + conn = SalesforceConnection(credentials=creds) + conn.instance_url = "https://example.my.salesforce.com" + + adapter = AgentForceAdapter(credentials=creds, connection=conn) + # Skip authenticate() by pre-populating connection. + adapter._importer = mock.MagicMock() + adapter._connected = True + adapter._status = AdapterStatus.HEALTHY + + health = adapter.health_check() + assert health.message is not None + assert "expired" in health.message.lower() + + +def test_import_sessions_before_connect_raises() -> None: + adapter = AgentForceAdapter() + with pytest.raises(RuntimeError, match="not connected"): + adapter.import_sessions(start_date="2026-04-01") + + +def test_import_sessions_routes_events_through_pipeline() -> None: + stratix = _RecordingStratix() + adapter = AgentForceAdapter( + stratix=stratix, + connection=_connection(), + capture_config=CaptureConfig.full(), + ) + adapter.connect() + + # Replace the importer with a fixture-returning fake. + fake_events = [ + { + "event_type": "agent.lifecycle", + "payload": {"lifecycle_action": "start", "session_id": "0XxAAA"}, + "identity": {"trace_id": "trace-1"}, + "timestamp": "2026-04-01T00:00:00Z", + }, + { + "event_type": "agent.input", + "payload": {"content": {"role": "human", "message": "hi"}}, + }, + ] + fake_result = ImportResult(sessions_imported=1) + adapter._importer = mock.MagicMock() + adapter._importer.import_sessions = mock.MagicMock( + return_value=(fake_events, fake_result), + ) + + result = adapter.import_sessions(start_date="2026-04-01") + assert result.sessions_imported == 1 + assert result.events_generated == 2 + + types = [e["event_type"] for e in stratix.events] + assert types == ["agent.lifecycle", "agent.input"] + # Identity + timestamp passthrough into payload. + assert stratix.events[0]["payload"]["_identity"] == {"trace_id": "trace-1"} + assert stratix.events[0]["payload"]["_timestamp"] == "2026-04-01T00:00:00Z" + + +# --------------------------------------------------------------------------- +# Importer (SOQL → events) +# --------------------------------------------------------------------------- + + +def test_importer_validates_date_format() -> None: + importer = AgentForceImporter(connection=_connection()) + with pytest.raises(ValueError, match="Invalid date format"): + importer.import_sessions(start_date="04/01/2026") + + +def test_importer_validates_timestamp_format() -> None: + importer = AgentForceImporter(connection=_connection()) + with pytest.raises(ValueError, match="Invalid timestamp format"): + importer.import_sessions(last_import_timestamp="2026/04/01 00:00:00") + + +def test_importer_rejects_malformed_salesforce_id() -> None: + importer = AgentForceImporter(connection=_connection()) + with pytest.raises(ValueError, match="Invalid Salesforce ID"): + importer._validate_sf_id("not a real id; DROP TABLE--") + + +def test_importer_runs_full_query_and_normalizes_records() -> None: + importer = AgentForceImporter(connection=_connection(), batch_size=50) + + session_row = { + "Id": "0XxAAAAAAAAAAA1", + "StartTimestamp": "2026-04-01T10:00:00Z", + "EndTimestamp": "2026-04-01T10:05:00Z", + "AiAgentChannelTypeId": "Web", + "AiAgentSessionEndType": "Completed", + "VoiceCallId": None, + "MessagingSessionId": None, + "PreviousSessionId": None, + } + participant_row = { + "Id": "1XxAAAAAAAAAAA1", + "AiAgentSessionId": session_row["Id"], + "AiAgentTypeId": "EinsteinSDR", + "AiAgentApiName": "Sales_Agent", + "AiAgentVersionApiName": "v1", + "ParticipantId": "user-1", + "AiAgentSessionParticipantRoleId": "Agent", + } + interaction_row = { + "Id": "2XxAAAAAAAAAAA1", + "AiAgentSessionId": session_row["Id"], + "AiAgentInteractionTypeId": "Conversation", + "TelemetryTraceId": "trace-1", + "TelemetryTraceSpanId": "span-1", + "TopicApiName": "Lead_Qualification", + "AttributeText": '{"intent":"qualify"}', + "PrevInteractionId": None, + } + step_row = { + "Id": "3XxAAAAAAAAAAA1", + "AiAgentInteractionId": interaction_row["Id"], + "AiAgentInteractionStepTypeId": "LLMExecutionStep", + "InputValueText": "what is the lead source?", + "OutputValueText": "the lead came from a webinar", + "ErrorMessageText": None, + "GenerationId": "gen-1", + "GenAiGatewayRequestId": "req-1", + "GenAiGatewayResponseId": "resp-1", + "Name": "lead_source_step", + "TelemetryTraceSpanId": "span-2", + } + message_row = { + "Id": "4XxAAAAAAAAAAA1", + "AiAgentInteractionId": interaction_row["Id"], + "AiAgentInteractionMessageTypeId": "Output", + "ContentText": "Got it, thanks!", + "AiAgentInteractionMsgContentTypeId": "Text", + "MessageSentTimestamp": "2026-04-01T10:01:00Z", + "ParentMessageId": None, + } + + query_responses = [ + [session_row], + [participant_row], + [interaction_row], + [step_row], + [message_row], + ] + + with mock.patch.object( + importer._connection, + "query", + side_effect=query_responses, + ): + events, result = importer.import_sessions( + start_date="2026-04-01", + end_date="2026-04-02", + ) + + assert result.sessions_imported == 1 + assert result.participants_imported == 1 + assert result.interactions_imported == 1 + assert result.steps_imported == 1 + assert result.messages_imported == 1 + # 2 lifecycle events (start + end) + participant + interaction + step + message + assert result.events_generated == 6 + assert len(events) == 6 + + +def test_importer_records_query_failure_in_result() -> None: + importer = AgentForceImporter(connection=_connection()) + + with mock.patch.object( + importer._connection, + "query", + side_effect=SalesforceQueryError("session query failed", soql=""), + ): + events, result = importer.import_sessions(start_date="2026-04-01") + + assert events == [] + assert result.sessions_imported == 0 + assert result.errors # at least one entry + + +# --------------------------------------------------------------------------- +# Normalizer (DMO → canonical events) +# --------------------------------------------------------------------------- + + +def test_normalizer_session_emits_start_and_end() -> None: + n = AgentForceNormalizer() + events = n.normalize_session( + { + "Id": "0Xx1", + "StartTimestamp": "2026-04-01T00:00:00Z", + "EndTimestamp": "2026-04-01T00:05:00Z", + "AiAgentChannelTypeId": "Voice", + "AiAgentSessionEndType": "Completed", + } + ) + assert [e["payload"]["lifecycle_action"] for e in events] == ["start", "end"] + assert events[0]["event_type"] == "agent.lifecycle" + + +def test_normalizer_participant_marks_human_for_employee() -> None: + n = AgentForceNormalizer() + evt = n.normalize_participant({"AiAgentTypeId": "Employee"}) + assert evt["payload"]["participant_type"] == "human" + + evt = n.normalize_participant({"AiAgentTypeId": "EinsteinSDR"}) + assert evt["payload"]["participant_type"] == "ai" + + +def test_normalizer_step_routes_by_type() -> None: + n = AgentForceNormalizer() + + llm_step = n.normalize_step( + { + "AiAgentInteractionStepTypeId": "LLMExecutionStep", + "Name": "summarize", + "InputValueText": "summarize x", + "OutputValueText": "x summarized", + "StartTimestamp": "2026-04-01T10:00:00Z", + "EndTimestamp": "2026-04-01T10:00:01Z", + } + ) + assert llm_step["event_type"] == "model.invoke" + assert llm_step["payload"]["model"]["provider"] == "salesforce" + assert llm_step["duration_ms"] == pytest.approx(1000.0) + + tool_step = n.normalize_step( + { + "AiAgentInteractionStepTypeId": "ActionInvocationStep", + "Name": "create_case", + "InputValueText": '{"subject":"hi"}', + "OutputValueText": '{"id":"500x"}', + } + ) + assert tool_step["event_type"] == "tool.call" + assert tool_step["payload"]["tool"]["name"] == "create_case" + # JSON parsed. + assert tool_step["payload"]["input"] == {"subject": "hi"} + + user_step = n.normalize_step( + { + "AiAgentInteractionStepTypeId": "UserInputStep", + "InputValueText": "hello", + } + ) + assert user_step["event_type"] == "agent.input" + assert user_step["payload"]["content"]["message"] == "hello" + + +def test_normalizer_interaction_handles_invalid_attribute_json() -> None: + n = AgentForceNormalizer() + evt = n.normalize_interaction( + { + "Id": "2Xx1", + "AiAgentSessionId": "0Xx1", + "AttributeText": "not json {", + } + ) + # Falls back to raw wrapper rather than crashing. + assert evt["payload"]["attributes"] == {"raw": "not json {"} + + +def test_normalizer_message_routes_role_by_type() -> None: + n = AgentForceNormalizer() + + out = n.normalize_message( + { + "AiAgentInteractionMessageTypeId": "Output", + "ContentText": "hi", + } + ) + assert out["event_type"] == "agent.output" + assert out["payload"]["content"]["role"] == "agent" + + inp = n.normalize_message( + { + "AiAgentInteractionMessageTypeId": "Input", + "ContentText": "hi", + } + ) + assert inp["event_type"] == "agent.input" + assert inp["payload"]["content"]["role"] == "human" + + +# --------------------------------------------------------------------------- +# Agent API client + mapper +# --------------------------------------------------------------------------- + + +def test_client_create_session_validates_inputs() -> None: + client = AgentApiClient(connection=_connection()) + with pytest.raises(ValueError): + client.create_session(agent_name="") + + +def test_client_create_send_end_session_round_trip() -> None: + client = AgentApiClient(connection=_connection()) + + create_resp = _FakeResponse( + json_data={"sessionId": "session-1", "createdAt": "2026-04-01T10:00:00Z"}, + ) + send_resp = _FakeResponse( + json_data={ + "messages": [{"id": "m1", "text": "hello back", "timestamp": "2026-04-01T10:00:01Z"}], + "topic": "Greeting", + "actions": [{"name": "noop", "parameters": {}, "result": "ok"}], + "guardrailResults": [{"name": "toxicity", "triggered": False, "message": "clean"}], + }, + ) + end_resp = _FakeResponse(json_data={}) + + with mock.patch("requests.post", side_effect=[create_resp, send_resp]), mock.patch( + "requests.delete", return_value=end_resp + ): + session = client.create_session(agent_name="ServiceAgent") + assert session.session_id == "session-1" + + message = client.send_message(session.session_id, "hi") + assert isinstance(message, AgentApiMessage) + assert message.content == "hello back" + assert message.topic == "Greeting" + assert message.actions[0]["name"] == "noop" + assert message.guardrail_results[0]["name"] == "toxicity" + + client.end_session(session.session_id) + + +def test_client_send_message_validates_inputs() -> None: + client = AgentApiClient(connection=_connection()) + with pytest.raises(ValueError): + client.send_message(session_id="", message="x") + with pytest.raises(ValueError): + client.send_message(session_id="s", message="") + + +def test_client_capture_session_records_full_transcript() -> None: + client = AgentApiClient(connection=_connection()) + + create_resp = _FakeResponse( + json_data={"sessionId": "s-1", "createdAt": "2026-04-01T10:00:00Z"}, + ) + msg_resp_1 = _FakeResponse(json_data={"messages": [{"id": "m1", "text": "hi"}]}) + msg_resp_2 = _FakeResponse(json_data={"messages": [{"id": "m2", "text": "bye"}]}) + end_resp = _FakeResponse(json_data={}) + + with mock.patch( + "requests.post", side_effect=[create_resp, msg_resp_1, msg_resp_2] + ), mock.patch("requests.delete", return_value=end_resp): + session = client.capture_session( + agent_name="ServiceAgent", + messages=["hello", "goodbye"], + ) + + assert session.status == "ended" + # 2 user + 2 agent = 4 messages. + assert len(session.messages) == 4 + assert [m.role for m in session.messages] == ["user", "agent", "user", "agent"] + + +def test_mapper_emits_full_session_event_sequence() -> None: + mapper = AgentApiMapper() + session = AgentApiSession( + session_id="s-1", + agent_name="ServiceAgent", + created_at="2026-04-01T10:00:00Z", + ended_at="2026-04-01T10:00:05Z", + messages=[ + AgentApiMessage(role="user", content="hello"), + AgentApiMessage( + role="agent", + content="hi", + topic="Greeting", + actions=[{"name": "noop", "parameters": {}, "result": "ok"}], + guardrail_results=[ + {"name": "toxicity", "triggered": False, "message": ""}, + ], + ), + ], + ) + events = mapper.map_session(session) + types = [e["event_type"] for e in events] + assert types == [ + "agent.state.change", # session start + "agent.input", # user + "agent.output", # agent + "environment.config", # topic + "tool.call", # action + "policy.violation", # guardrail + "agent.state.change", # session end + ] + + +def test_mapper_session_end_computes_duration() -> None: + mapper = AgentApiMapper() + session = AgentApiSession( + session_id="s", + created_at="2026-04-01T10:00:00Z", + ended_at="2026-04-01T10:00:02Z", + ) + end_event = mapper.map_session_end(session) + # 2 seconds → 2_000_000_000 nanoseconds. + assert end_event["payload"]["duration_ns"] == 2_000_000_000 + + +def test_mapper_escalation() -> None: + evt = AgentApiMapper.map_escalation( + session_id="s-1", + from_agent="bot", + to_agent="human", + reason="user requested", + ) + assert evt["event_type"] == "agent.handoff" + assert evt["payload"]["from_agent"] == "bot" + + +# --------------------------------------------------------------------------- +# Trust Layer +# --------------------------------------------------------------------------- + + +def test_trust_layer_to_layerlens_policy_emits_well_formed_yaml() -> None: + importer = TrustLayerImporter(connection=_connection()) + cfg = TrustLayerConfig( + guardrails=[ + TrustLayerGuardrail(name="toxicity_detection", type="toxicity"), + TrustLayerGuardrail(name="pii_detection", type="pii", threshold=0.9), + ], + ) + yaml_str = importer.to_layerlens_policy(cfg, policy_name="my_policy") + assert "policy:" in yaml_str + assert "name: my_policy" in yaml_str + assert "toxicity_detection" in yaml_str + assert "pii_detection" in yaml_str + assert "threshold: 0.9" in yaml_str + assert "LayerLens Policy" in yaml_str + assert "stratix.sdk" not in yaml_str + + +def test_trust_layer_deprecation_alias_warns_and_returns_same() -> None: + importer = TrustLayerImporter(connection=_connection()) + cfg = TrustLayerConfig(guardrails=[TrustLayerGuardrail(name="x", type="custom")]) + + with pytest.warns(DeprecationWarning, match="to_layerlens_policy"): + legacy = importer.to_stratix_policy(cfg) + canonical = importer.to_layerlens_policy(cfg) + assert legacy == canonical + + +def test_trust_layer_classify_guardrail_buckets_known_names() -> None: + classify = TrustLayerImporter._classify_guardrail + assert classify("toxicity_detection") == "toxicity" + assert classify("pii_mask") == "pii" + assert classify("prompt_injection_guard") == "prompt_injection" + assert classify("hallucination_check") == "hallucination" + assert classify("custom_guard") == "custom" + + +def test_trust_layer_fetch_config_falls_back_to_defaults_on_query_fail() -> None: + importer = TrustLayerImporter(connection=_connection()) + + with mock.patch.object( + importer._connection, + "query", + side_effect=SalesforceQueryError("no perms", soql=""), + ): + cfg = importer.fetch_config() + + # Default guardrails populated when nothing came back. + names = {g.name for g in cfg.guardrails} + assert "toxicity_detection" in names + assert "pii_detection" in names + + +# --------------------------------------------------------------------------- +# Platform Events subscriber +# --------------------------------------------------------------------------- + + +def test_platform_events_handle_event_invokes_callback_and_records_replay_id() -> None: + received: list[AgentSessionEvent] = [] + sub = PlatformEventSubscriber( + connection=_connection(), + on_event=received.append, + channel="/event/AgentSession__e", + ) + + sub._handle_event( + { + "SessionId__c": "0Xx1", + "AgentName__c": "ServiceAgent", + "TopicName__c": "Greeting", + "ActionsTaken__c": "[]", + "ResponseText__c": "hi", + "TrustLayerFlags__c": "{}", + "event": {"replayId": "42"}, + } + ) + + assert len(received) == 1 + assert received[0].session_id == "0Xx1" + assert sub.events_received == 1 + assert sub.last_replay_id == "42" + + +def test_platform_events_default_channel_and_state_flags() -> None: + sub = PlatformEventSubscriber(connection=_connection()) + assert sub.is_running is False + assert sub.events_received == 0 + + +# --------------------------------------------------------------------------- +# Einstein evaluator +# --------------------------------------------------------------------------- + + +def test_evaluator_returns_zero_scores_without_layerlens_client() -> None: + evaluator = EinsteinEvaluator() + # No client configured => graders default to 0.0 (logged). + results = evaluator.evaluate_completions( + session_ids=["0Xx1"], + graders=["relevance", "faithfulness"], + ) + assert len(results) == 1 + assert results[0].scores == {"relevance": 0.0, "faithfulness": 0.0} + assert results[0].composite_score == 0.0 + + +def test_evaluator_composite_score_uses_weight_categories() -> None: + evaluator = EinsteinEvaluator() + composite = evaluator._compute_composite_score( + { + "relevance": 1.0, + "faithfulness": 1.0, + "safety": 1.0, + } + ) + # Three perfect scores collapse to 1.0 regardless of weight choice. + assert composite == pytest.approx(1.0) + + composite_zero = evaluator._compute_composite_score({}) + assert composite_zero is None + + +def test_evaluator_returns_empty_when_no_session_ids() -> None: + assert EinsteinEvaluator().evaluate_completions(session_ids=[]) == [] + + +def test_evaluator_evaluate_topic_requires_adapter() -> None: + with pytest.raises(RuntimeError, match="Adapter required"): + EinsteinEvaluator().evaluate_topic(topic="Lead_Qualification") + + +# --------------------------------------------------------------------------- +# Smoke: NormalizationError surfaces for callers that re-export it +# --------------------------------------------------------------------------- + + +def test_normalization_error_is_distinct_exception() -> None: + err = NormalizationError("bad row") + assert isinstance(err, Exception) + assert "bad row" in str(err)