diff --git a/src/sentry/seer/endpoints/organization_seer_explorer_chat.py b/src/sentry/seer/endpoints/organization_seer_explorer_chat.py index a2384a548f3fcd..c2c2ff65c01283 100644 --- a/src/sentry/seer/endpoints/organization_seer_explorer_chat.py +++ b/src/sentry/seer/endpoints/organization_seer_explorer_chat.py @@ -112,7 +112,12 @@ def post( on_page_context = validated_data.get("on_page_context") try: - client = SeerExplorerClient(organization, request.user, is_interactive=True) + client = SeerExplorerClient( + organization, + request.user, + is_interactive=True, + enable_coding=True, + ) if run_id: # Continue existing conversation result_run_id = client.continue_run( diff --git a/src/sentry/seer/explorer/client.py b/src/sentry/seer/explorer/client.py index 1fb2ae8f7e6461..386b902b873c73 100644 --- a/src/sentry/seer/explorer/client.py +++ b/src/sentry/seer/explorer/client.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import time from typing import Any, Literal import orjson @@ -117,6 +118,28 @@ def execute(cls, organization: Organization, run_id: int) -> None: on_completion=NotifyOnComplete ) run_id = client.start_run("Analyze this issue") + + # WITH CODE EDITING AND PR CREATION + client = SeerExplorerClient( + organization, + user, + enable_coding=True, # Enable code editing tools + ) + + run_id = client.start_run("Fix the null pointer exception in auth.py") + state = client.get_run(run_id, blocking=True) + + # Check if agent made code changes and if they need to be pushed + has_changes, is_synced = state.has_code_changes() + if has_changes and not is_synced: + # Push changes to PR (creates new PR or updates existing) + state = client.push_changes(run_id) + + # Get PR info for each repo + for repo_name in state.get_file_patches_by_repo().keys(): + pr_state = state.get_pr_state(repo_name) + if pr_state and pr_state.pr_url: + print(f"PR created: {pr_state.pr_url}") ``` Args: @@ -128,6 +151,7 @@ def execute(cls, organization: Organization, run_id: int) -> None: on_completion_hook: Optional `ExplorerOnCompletionHook` class to call when the agent completes. The hook's execute() method receives the organization and run ID. This is called whether or not the agent was successful. Hook classes must be module-level (not nested classes). intelligence_level: Optionally set the intelligence level of the agent. Higher intelligence gives better result quality at the cost of significantly higher latency and cost. is_interactive: Enable full interactive, human-like features of the agent. Only enable if you support *all* available interactions in Seer. An example use of this is the explorer chat in Sentry UI. + enable_coding: Enable code editing tools. When disabled, the agent cannot make code changes. Default is False. """ def __init__( @@ -140,6 +164,7 @@ def __init__( on_completion_hook: type[ExplorerOnCompletionHook] | None = None, intelligence_level: Literal["low", "medium", "high"] = "medium", is_interactive: bool = False, + enable_coding: bool = False, ): self.organization = organization self.user = user @@ -149,6 +174,7 @@ def __init__( self.category_key = category_key self.category_value = category_value self.is_interactive = is_interactive + self.enable_coding = enable_coding # Validate that category_key and category_value are provided together if category_key == "" or category_value == "": @@ -198,6 +224,7 @@ def start_run( "user_org_context": collect_user_org_context(self.user, self.organization), "intelligence_level": self.intelligence_level, "is_interactive": self.is_interactive, + "enable_coding": self.enable_coding, } # Add artifact key and schema if provided @@ -273,6 +300,7 @@ def continue_run( "insert_index": insert_index, "on_page_context": on_page_context, "is_interactive": self.is_interactive, + "enable_coding": self.enable_coding, } # Add artifact key and schema if provided @@ -384,3 +412,68 @@ def get_runs( runs = [ExplorerRun(**run) for run in result.get("data", [])] return runs + + def push_changes( + self, + run_id: int, + repo_name: str | None = None, + poll_interval: float = 2.0, + poll_timeout: float = 120.0, + ) -> SeerRunState: + """ + Push code changes to PR(s) and wait for completion. + + Creates new PRs or updates existing ones with current file patches. + Polls until all PR operations complete. + + Args: + run_id: The run ID + repo_name: Specific repo to push, or None for all repos with changes + poll_interval: Seconds between polls + poll_timeout: Maximum seconds to wait + + Returns: + SeerRunState: Final state with PR info + + Raises: + TimeoutError: If polling exceeds timeout + requests.HTTPError: If the Seer API request fails + """ + # Trigger PR creation + path = "/v1/automation/explorer/update" + payload = { + "run_id": run_id, + "payload": { + "type": "create_pr", + "repo_name": repo_name, + }, + } + body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS) + response = requests.post( + f"{settings.SEER_AUTOFIX_URL}{path}", + data=body, + headers={ + "content-type": "application/json;charset=utf-8", + **sign_with_seer_secret(body), + }, + ) + response.raise_for_status() + + # Poll until PR creation completes + start_time = time.time() + + while True: + state = fetch_run_status(run_id, self.organization) + + # Check if any PRs are still being created + any_creating = any( + pr.pr_creation_status == "creating" for pr in state.repo_pr_states.values() + ) + + if not any_creating: + return state + + if time.time() - start_time > poll_timeout: + raise TimeoutError(f"PR creation timed out after {poll_timeout}s") + + time.sleep(poll_interval) diff --git a/src/sentry/seer/explorer/client_models.py b/src/sentry/seer/explorer/client_models.py index 6458d708e030e5..034e02ed662e7c 100644 --- a/src/sentry/seer/explorer/client_models.py +++ b/src/sentry/seer/explorer/client_models.py @@ -44,6 +44,46 @@ class Config: extra = "allow" +class FilePatch(BaseModel): + """A file patch from code editing.""" + + path: str + type: Literal["A", "M", "D"] # A=add, M=modify, D=delete + added: int + removed: int + + class Config: + extra = "allow" + + +class ExplorerFilePatch(BaseModel): + """A file patch associated with a repository.""" + + repo_name: str + patch: FilePatch + + class Config: + extra = "allow" + + +class RepoPRState(BaseModel): + """PR state for a single repository.""" + + repo_name: str + branch_name: str | None = None + pr_number: int | None = None + pr_url: str | None = None + pr_id: int | None = None + commit_sha: str | None = None + pr_creation_status: Literal["creating", "completed", "error"] | None = None + pr_creation_error: str | None = None + title: str | None = None + description: str | None = None + + class Config: + extra = "allow" + + class MemoryBlock(BaseModel): """A block in the Explorer agent's conversation/memory.""" @@ -52,6 +92,10 @@ class MemoryBlock(BaseModel): timestamp: str loading: bool = False artifacts: list[Artifact] = [] + file_patches: list[ExplorerFilePatch] = [] + pr_commit_shas: dict[str, str] | None = ( + None # repository name -> commit SHA. Used to track which commit was associated with each repo's PR at the time this block was created. + ) class Config: extra = "allow" @@ -76,6 +120,7 @@ class SeerRunState(BaseModel): status: Literal["processing", "completed", "error", "awaiting_user_input"] updated_at: str pending_user_input: PendingUserInput | None = None + repo_pr_states: dict[str, RepoPRState] = {} class Config: extra = "allow" @@ -110,6 +155,52 @@ def get_artifact(self, key: str, schema: type[T]) -> T | None: return None return schema.parse_obj(artifact.data) + def get_file_patches_by_repo(self) -> dict[str, list[ExplorerFilePatch]]: + """Get file patches grouped by repository.""" + by_repo: dict[str, list[ExplorerFilePatch]] = {} + for block in self.blocks: + for fp in block.file_patches: + if fp.repo_name not in by_repo: + by_repo[fp.repo_name] = [] + by_repo[fp.repo_name].append(fp) + return by_repo + + def get_pr_state(self, repo_name: str) -> RepoPRState | None: + """Get PR state for a specific repository.""" + return self.repo_pr_states.get(repo_name) + + def _is_repo_synced(self, repo_name: str) -> bool: + """Check if PR for a repo is in sync with latest changes.""" + pr_state = self.repo_pr_states.get(repo_name) + if not pr_state or not pr_state.commit_sha: + return False # No PR yet = not synced + + # Find last block with patches for this repo + for block in reversed(self.blocks): + if any(fp.repo_name == repo_name for fp in block.file_patches): + block_sha = (block.pr_commit_shas or {}).get(repo_name) + return block_sha == pr_state.commit_sha + return True # No patches found = synced + + def has_code_changes(self) -> tuple[bool, bool]: + """ + Check if there are code changes and if all have been pushed to PRs. + + Returns: + (has_changes, all_changes_pushed): + - has_changes: True if any file patches exist + - all_changes_pushed: True if the current state of changes across all repos have all been pushed to PRs. + """ + patches_by_repo = self.get_file_patches_by_repo() + has_changes = len(patches_by_repo) > 0 + + if not has_changes: + return (False, True) + + # Check if all repos with changes are synced + all_changes_pushed = all(self._is_repo_synced(repo) for repo in patches_by_repo.keys()) + return (has_changes, all_changes_pushed) + class CustomToolDefinition(BaseModel): """Definition of a custom tool to be sent to Seer.""" diff --git a/tests/sentry/seer/endpoints/test_organization_seer_explorer_chat.py b/tests/sentry/seer/endpoints/test_organization_seer_explorer_chat.py index 343186bc7b6cf3..ec565240aca76d 100644 --- a/tests/sentry/seer/endpoints/test_organization_seer_explorer_chat.py +++ b/tests/sentry/seer/endpoints/test_organization_seer_explorer_chat.py @@ -69,7 +69,9 @@ def test_post_new_conversation_calls_client(self, mock_client_class: MagicMock): assert response.data == {"run_id": 456} # Verify client was called correctly - mock_client_class.assert_called_once_with(self.organization, ANY, is_interactive=True) + mock_client_class.assert_called_once_with( + self.organization, ANY, is_interactive=True, enable_coding=True + ) mock_client.start_run.assert_called_once_with( prompt="What is this error about?", on_page_context=None ) @@ -90,7 +92,9 @@ def test_post_continue_conversation_calls_client(self, mock_client_class: MagicM assert response.data == {"run_id": 789} # Verify client was called correctly - mock_client_class.assert_called_once_with(self.organization, ANY, is_interactive=True) + mock_client_class.assert_called_once_with( + self.organization, ANY, is_interactive=True, enable_coding=True + ) mock_client.continue_run.assert_called_once_with( run_id=789, prompt="Follow up question", insert_index=2, on_page_context=None ) diff --git a/tests/sentry/seer/explorer/test_explorer_client.py b/tests/sentry/seer/explorer/test_explorer_client.py index 860c59a9091501..897b81b870cc4d 100644 --- a/tests/sentry/seer/explorer/test_explorer_client.py +++ b/tests/sentry/seer/explorer/test_explorer_client.py @@ -6,7 +6,14 @@ from pydantic import BaseModel from sentry.seer.explorer.client import SeerExplorerClient -from sentry.seer.explorer.client_models import SeerRunState +from sentry.seer.explorer.client_models import ( + ExplorerFilePatch, + FilePatch, + MemoryBlock, + Message, + RepoPRState, + SeerRunState, +) from sentry.seer.models import SeerPermissionError from sentry.testutils.cases import TestCase @@ -582,3 +589,218 @@ class RootCause(BaseModel): root_cause = result.get_artifact("root_cause", RootCause) assert root_cause is not None assert root_cause.cause == "New cause" + + +class TestSeerExplorerClientPushChanges(TestCase): + """Test push_changes method""" + + def setUp(self): + super().setUp() + self.user = self.create_user() + self.organization = self.create_organization(owner=self.user) + + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") + @patch("sentry.seer.explorer.client.fetch_run_status") + @patch("sentry.seer.explorer.client.requests.post") + def test_push_changes_sends_correct_payload(self, mock_post, mock_fetch, mock_access): + """Test that push_changes sends correct payload""" + mock_access.return_value = (True, None) + mock_post.return_value = MagicMock() + mock_fetch.return_value = SeerRunState( + run_id=123, + blocks=[], + status="completed", + updated_at="2024-01-01T00:00:00Z", + repo_pr_states={ + "owner/repo": RepoPRState( + repo_name="owner/repo", + pr_creation_status="completed", + pr_url="https://github.com/owner/repo/pull/1", + ) + }, + ) + + client = SeerExplorerClient(self.organization, self.user, enable_coding=True) + result = client.push_changes(123, repo_name="owner/repo") + + body = orjson.loads(mock_post.call_args[1]["data"]) + assert body["run_id"] == 123 + assert body["payload"]["type"] == "create_pr" + assert body["payload"]["repo_name"] == "owner/repo" + assert result.repo_pr_states["owner/repo"].pr_url == "https://github.com/owner/repo/pull/1" + + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") + @patch("sentry.seer.explorer.client.fetch_run_status") + @patch("sentry.seer.explorer.client.requests.post") + @patch("sentry.seer.explorer.client.time.sleep") + def test_push_changes_polls_until_complete( + self, mock_sleep, mock_post, mock_fetch, mock_access + ): + """Test that push_changes polls until PR creation completes""" + mock_access.return_value = (True, None) + mock_post.return_value = MagicMock() + + creating_state = SeerRunState( + run_id=123, + blocks=[], + status="completed", + updated_at="2024-01-01T00:00:00Z", + repo_pr_states={ + "owner/repo": RepoPRState(repo_name="owner/repo", pr_creation_status="creating") + }, + ) + completed_state = SeerRunState( + run_id=123, + blocks=[], + status="completed", + updated_at="2024-01-01T00:00:00Z", + repo_pr_states={ + "owner/repo": RepoPRState(repo_name="owner/repo", pr_creation_status="completed") + }, + ) + mock_fetch.side_effect = [creating_state, completed_state] + + client = SeerExplorerClient(self.organization, self.user, enable_coding=True) + result = client.push_changes(123) + + assert mock_fetch.call_count == 2 + assert mock_sleep.call_count == 1 + assert result.repo_pr_states["owner/repo"].pr_creation_status == "completed" + + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") + @patch("sentry.seer.explorer.client.fetch_run_status") + @patch("sentry.seer.explorer.client.requests.post") + @patch("sentry.seer.explorer.client.time.time") + def test_push_changes_timeout(self, mock_time, mock_post, mock_fetch, mock_access): + """Test that push_changes raises TimeoutError after timeout""" + mock_access.return_value = (True, None) + mock_post.return_value = MagicMock() + mock_fetch.return_value = SeerRunState( + run_id=123, + blocks=[], + status="completed", + updated_at="2024-01-01T00:00:00Z", + repo_pr_states={ + "owner/repo": RepoPRState(repo_name="owner/repo", pr_creation_status="creating") + }, + ) + mock_time.side_effect = [0, 0, 200] # Exceeds 120s timeout + + client = SeerExplorerClient(self.organization, self.user, enable_coding=True) + with pytest.raises(TimeoutError, match="PR creation timed out"): + client.push_changes(123, poll_timeout=120.0) + + +class TestSeerRunStateCodeChanges(TestCase): + """Test SeerRunState helper methods for code changes""" + + def test_has_code_changes_no_patches(self): + """Test has_code_changes with no patches returns (False, True)""" + state = SeerRunState( + run_id=123, + blocks=[ + MemoryBlock( + id="block-1", + message=Message(role="assistant", content="Hello"), + timestamp="2024-01-01T00:00:00Z", + ) + ], + status="completed", + updated_at="2024-01-01T00:00:00Z", + ) + + has_changes, is_synced = state.has_code_changes() + assert has_changes is False + assert is_synced is True + + def test_has_code_changes_unsynced(self): + """Test has_code_changes with patches but no PR""" + state = SeerRunState( + run_id=123, + blocks=[ + MemoryBlock( + id="block-1", + message=Message(role="assistant", content="Fixed"), + timestamp="2024-01-01T00:00:00Z", + file_patches=[ + ExplorerFilePatch( + repo_name="owner/repo", + patch=FilePatch(path="file.py", type="M", added=10, removed=5), + ) + ], + ) + ], + status="completed", + updated_at="2024-01-01T00:00:00Z", + ) + + has_changes, is_synced = state.has_code_changes() + assert has_changes is True + assert is_synced is False + + def test_has_code_changes_synced(self): + """Test has_code_changes when changes are synced to PR""" + state = SeerRunState( + run_id=123, + blocks=[ + MemoryBlock( + id="block-1", + message=Message(role="assistant", content="Fixed"), + timestamp="2024-01-01T00:00:00Z", + file_patches=[ + ExplorerFilePatch( + repo_name="owner/repo", + patch=FilePatch(path="file.py", type="M", added=10, removed=5), + ) + ], + pr_commit_shas={"owner/repo": "abc123"}, + ) + ], + status="completed", + updated_at="2024-01-01T00:00:00Z", + repo_pr_states={ + "owner/repo": RepoPRState( + repo_name="owner/repo", + commit_sha="abc123", + pr_creation_status="completed", + ) + }, + ) + + has_changes, is_synced = state.has_code_changes() + assert has_changes is True + assert is_synced is True + + def test_get_file_patches_by_repo(self): + """Test get_file_patches_by_repo groups patches correctly""" + state = SeerRunState( + run_id=123, + blocks=[ + MemoryBlock( + id="block-1", + message=Message(role="assistant", content="Fixed"), + timestamp="2024-01-01T00:00:00Z", + file_patches=[ + ExplorerFilePatch( + repo_name="owner/repo1", + patch=FilePatch(path="file1.py", type="M", added=10, removed=5), + ), + ExplorerFilePatch( + repo_name="owner/repo2", + patch=FilePatch(path="file2.py", type="A", added=20, removed=0), + ), + ExplorerFilePatch( + repo_name="owner/repo1", + patch=FilePatch(path="file3.py", type="M", added=5, removed=2), + ), + ], + ) + ], + status="completed", + updated_at="2024-01-01T00:00:00Z", + ) + + result = state.get_file_patches_by_repo() + assert len(result) == 2 + assert len(result["owner/repo1"]) == 2 + assert len(result["owner/repo2"]) == 1