Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ def post(
on_page_context = validated_data.get("on_page_context")

try:
client = SeerExplorerClient(organization, request.user, is_interactive=True)
client = SeerExplorerClient(
organization,
request.user,
is_interactive=True,
enable_coding=True,
)
if run_id:
# Continue existing conversation
result_run_id = client.continue_run(
Expand Down
93 changes: 93 additions & 0 deletions src/sentry/seer/explorer/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import time
from typing import Any, Literal

import orjson
Expand Down Expand Up @@ -117,6 +118,28 @@ def execute(cls, organization: Organization, run_id: int) -> None:
on_completion=NotifyOnComplete
)
run_id = client.start_run("Analyze this issue")

# WITH CODE EDITING AND PR CREATION
client = SeerExplorerClient(
organization,
user,
enable_coding=True, # Enable code editing tools
)

run_id = client.start_run("Fix the null pointer exception in auth.py")
state = client.get_run(run_id, blocking=True)

# Check if agent made code changes and if they need to be pushed
has_changes, is_synced = state.has_code_changes()
if has_changes and not is_synced:
# Push changes to PR (creates new PR or updates existing)
state = client.push_changes(run_id)

# Get PR info for each repo
for repo_name in state.get_file_patches_by_repo().keys():
pr_state = state.get_pr_state(repo_name)
if pr_state and pr_state.pr_url:
print(f"PR created: {pr_state.pr_url}")
```

Args:
Expand All @@ -128,6 +151,7 @@ def execute(cls, organization: Organization, run_id: int) -> None:
on_completion_hook: Optional `ExplorerOnCompletionHook` class to call when the agent completes. The hook's execute() method receives the organization and run ID. This is called whether or not the agent was successful. Hook classes must be module-level (not nested classes).
intelligence_level: Optionally set the intelligence level of the agent. Higher intelligence gives better result quality at the cost of significantly higher latency and cost.
is_interactive: Enable full interactive, human-like features of the agent. Only enable if you support *all* available interactions in Seer. An example use of this is the explorer chat in Sentry UI.
enable_coding: Enable code editing tools. When disabled, the agent cannot make code changes. Default is False.
"""

def __init__(
Expand All @@ -140,6 +164,7 @@ def __init__(
on_completion_hook: type[ExplorerOnCompletionHook] | None = None,
intelligence_level: Literal["low", "medium", "high"] = "medium",
is_interactive: bool = False,
enable_coding: bool = False,
):
self.organization = organization
self.user = user
Expand All @@ -149,6 +174,7 @@ def __init__(
self.category_key = category_key
self.category_value = category_value
self.is_interactive = is_interactive
self.enable_coding = enable_coding

# Validate that category_key and category_value are provided together
if category_key == "" or category_value == "":
Expand Down Expand Up @@ -198,6 +224,7 @@ def start_run(
"user_org_context": collect_user_org_context(self.user, self.organization),
"intelligence_level": self.intelligence_level,
"is_interactive": self.is_interactive,
"enable_coding": self.enable_coding,
}

# Add artifact key and schema if provided
Expand Down Expand Up @@ -273,6 +300,7 @@ def continue_run(
"insert_index": insert_index,
"on_page_context": on_page_context,
"is_interactive": self.is_interactive,
"enable_coding": self.enable_coding,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is set in init so maybe we don't need this?
Do we expect users to update client.enable_coding to disable/enable coding after the run starts? If so, maybe this function could take an optional parameter for passing the new value of the flag?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This follows the same pattern as is_interactive: if users change the flag on the client, future iterations in the run pick up the new value. Otherwise, it has no impact.

}

# Add artifact key and schema if provided
Expand Down Expand Up @@ -384,3 +412,68 @@ def get_runs(

runs = [ExplorerRun(**run) for run in result.get("data", [])]
return runs

def push_changes(
    self,
    run_id: int,
    repo_name: str | None = None,
    poll_interval: float = 2.0,
    poll_timeout: float = 120.0,
) -> SeerRunState:
    """
    Push code changes to PR(s) and wait for completion.

    Creates new PRs or updates existing ones with current file patches.
    Sends a ``create_pr`` update for the run to the Seer explorer update
    endpoint, then repeatedly fetches the run state until no entry in
    ``repo_pr_states`` has ``pr_creation_status == "creating"``.

    Args:
        run_id: The run ID
        repo_name: Specific repo to push, or None for all repos with changes
        poll_interval: Seconds to sleep between polls
        poll_timeout: Maximum seconds to keep polling before giving up

    Returns:
        SeerRunState: The first fetched state in which no PR is still being
        created. The final status is not inspected here, so callers should
        check each repo's ``pr_creation_status`` / ``pr_creation_error``.

    Raises:
        TimeoutError: If polling exceeds timeout
        requests.HTTPError: If the Seer API request fails
    """
    # Trigger PR creation
    path = "/v1/automation/explorer/update"
    payload = {
        "run_id": run_id,
        "payload": {
            "type": "create_pr",
            "repo_name": repo_name,
        },
    }
    # The exact bytes that are sent are also the bytes that get signed below.
    body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS)
    # NOTE(review): no timeout= is passed here, so poll_timeout does not bound
    # this request — confirm that is intended.
    response = requests.post(
        f"{settings.SEER_AUTOFIX_URL}{path}",
        data=body,
        headers={
            "content-type": "application/json;charset=utf-8",
            **sign_with_seer_secret(body),
        },
    )
    response.raise_for_status()

    # Poll until PR creation completes
    # NOTE(review): time.time() is wall-clock time; time.monotonic() would be
    # unaffected by system clock adjustments.
    start_time = time.time()

    while True:
        state = fetch_run_status(run_id, self.organization)

        # Check if any PRs are still being created
        # any() over an empty mapping is False, so a state without any PR
        # entries is treated as "nothing in progress" and returned right away.
        any_creating = any(
            pr.pr_creation_status == "creating" for pr in state.repo_pr_states.values()
        )

        if not any_creating:
            return state
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Polling loop exits early if PR state not yet initialized

The push_changes polling loop will exit immediately if repo_pr_states is empty or doesn't contain any PR with pr_creation_status == "creating". Since any() on an empty iterable returns False, if the server hasn't yet populated repo_pr_states with the new PR being created by the time the first fetch_run_status call returns, the method will return prematurely with stale data instead of waiting for the PR creation to complete.

Fix in Cursor Fix in Web

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wrong, the update endpoint call won't return until the state is updated already with "creating"


# Give up once the time elapsed since polling started exceeds poll_timeout.
if time.time() - start_time > poll_timeout:
    raise TimeoutError(f"PR creation timed out after {poll_timeout}s")

# Wait before fetching the run state again.
time.sleep(poll_interval)
91 changes: 91 additions & 0 deletions src/sentry/seer/explorer/client_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,46 @@ class Config:
extra = "allow"


class FilePatch(BaseModel):
    """A file patch from code editing.

    Describes one changed file: its path, the kind of change, and two
    integer counters.
    """

    # Path of the affected file (presumably relative to the repo root — confirm).
    path: str
    # Kind of change applied to the file.
    type: Literal["A", "M", "D"]  # A=add, M=modify, D=delete
    # presumably the number of lines added / removed by the patch — confirm
    # against the Seer schema.
    added: int
    removed: int

    class Config:
        # Accept fields that are not declared above instead of rejecting them.
        extra = "allow"


class ExplorerFilePatch(BaseModel):
    """A file patch associated with a repository.

    Pairs a ``FilePatch`` with the name of the repository it applies to.
    """

    # Name of the repository the patch belongs to.
    repo_name: str
    # The file-level change itself.
    patch: FilePatch

    class Config:
        # Accept fields that are not declared above instead of rejecting them.
        extra = "allow"


class RepoPRState(BaseModel):
    """PR state for a single repository.

    Only ``repo_name`` is required; every other field defaults to None.
    """

    # Name of the repository this PR state describes.
    repo_name: str
    branch_name: str | None = None
    pr_number: int | None = None
    pr_url: str | None = None
    pr_id: int | None = None
    # Compared with a block's pr_commit_shas entry by
    # SeerRunState._is_repo_synced to decide whether the PR is up to date.
    commit_sha: str | None = None
    # "creating" is the value SeerExplorerClient.push_changes polls on;
    # None presumably means no PR operation has been started — confirm.
    pr_creation_status: Literal["creating", "completed", "error"] | None = None
    # presumably populated when pr_creation_status is "error" — confirm.
    pr_creation_error: str | None = None
    title: str | None = None
    description: str | None = None

    class Config:
        # Accept fields that are not declared above instead of rejecting them.
        extra = "allow"


class MemoryBlock(BaseModel):
"""A block in the Explorer agent's conversation/memory."""

Expand All @@ -52,6 +92,10 @@ class MemoryBlock(BaseModel):
timestamp: str
loading: bool = False
artifacts: list[Artifact] = []
file_patches: list[ExplorerFilePatch] = []
pr_commit_shas: dict[str, str] | None = (
None # repository name -> commit SHA. Used to track which commit was associated with each repo's PR at the time this block was created.
)

class Config:
extra = "allow"
Expand All @@ -76,6 +120,7 @@ class SeerRunState(BaseModel):
status: Literal["processing", "completed", "error", "awaiting_user_input"]
updated_at: str
pending_user_input: PendingUserInput | None = None
repo_pr_states: dict[str, RepoPRState] = {}

class Config:
extra = "allow"
Expand Down Expand Up @@ -110,6 +155,52 @@ def get_artifact(self, key: str, schema: type[T]) -> T | None:
return None
return schema.parse_obj(artifact.data)

def get_file_patches_by_repo(self) -> dict[str, list[ExplorerFilePatch]]:
    """Collect every file patch in the run, keyed by repository name.

    Blocks are walked in order, so repositories appear in first-seen order
    and each repository's list keeps the order in which its patches occur.
    """
    grouped: dict[str, list[ExplorerFilePatch]] = {}
    every_patch = (patch for blk in self.blocks for patch in blk.file_patches)
    for patch in every_patch:
        # setdefault creates the repository's list the first time it is seen.
        grouped.setdefault(patch.repo_name, []).append(patch)
    return grouped

def get_pr_state(self, repo_name: str) -> RepoPRState | None:
    """Return the PR state recorded for ``repo_name``, or None if there is none."""
    known_states = self.repo_pr_states
    return known_states.get(repo_name)

def _is_repo_synced(self, repo_name: str) -> bool:
"""Check if PR for a repo is in sync with latest changes."""
pr_state = self.repo_pr_states.get(repo_name)
if not pr_state or not pr_state.commit_sha:
return False # No PR yet = not synced

# Find last block with patches for this repo
for block in reversed(self.blocks):
if any(fp.repo_name == repo_name for fp in block.file_patches):
block_sha = (block.pr_commit_shas or {}).get(repo_name)
return block_sha == pr_state.commit_sha
return True # No patches found = synced

def has_code_changes(self) -> tuple[bool, bool]:
    """
    Report whether the run contains code changes and whether they are pushed.

    Returns:
        (has_changes, all_changes_pushed):
        - has_changes: True if at least one file patch exists.
        - all_changes_pushed: True if every repository with patches has a PR
          that is in sync with its latest changes (also True when there are
          no changes at all).
    """
    repos_with_patches = self.get_file_patches_by_repo()

    # Nothing changed: trivially "all pushed".
    if not repos_with_patches:
        return (False, True)

    # Stop at the first repository whose PR is behind its patches.
    for repo in repos_with_patches:
        if not self._is_repo_synced(repo):
            return (True, False)
    return (True, True)


class CustomToolDefinition(BaseModel):
"""Definition of a custom tool to be sent to Seer."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def test_post_new_conversation_calls_client(self, mock_client_class: MagicMock):
assert response.data == {"run_id": 456}

# Verify client was called correctly
mock_client_class.assert_called_once_with(self.organization, ANY, is_interactive=True)
mock_client_class.assert_called_once_with(
self.organization, ANY, is_interactive=True, enable_coding=True
)
mock_client.start_run.assert_called_once_with(
prompt="What is this error about?", on_page_context=None
)
Expand All @@ -90,7 +92,9 @@ def test_post_continue_conversation_calls_client(self, mock_client_class: MagicM
assert response.data == {"run_id": 789}

# Verify client was called correctly
mock_client_class.assert_called_once_with(self.organization, ANY, is_interactive=True)
mock_client_class.assert_called_once_with(
self.organization, ANY, is_interactive=True, enable_coding=True
)
mock_client.continue_run.assert_called_once_with(
run_id=789, prompt="Follow up question", insert_index=2, on_page_context=None
)
Expand Down
Loading
Loading