diff --git a/sqlit/domains/query/store/history.py b/sqlit/domains/query/store/history.py index fce77328..06f5ca47 100644 --- a/sqlit/domains/query/store/history.py +++ b/sqlit/domains/query/store/history.py @@ -15,16 +15,20 @@ class QueryHistoryEntry: query: str timestamp: str # ISO format connection_name: str + database: str = "" # Active database when query was executed is_starred: bool = False # Computed at load time, not persisted is_starred_only: bool = False # True if only in starred store, not in history def to_dict(self) -> dict: """Convert to dictionary for JSON serialization.""" - return { + d: dict = { "query": self.query, "timestamp": self.timestamp, "connection_name": self.connection_name, } + if self.database: + d["database"] = self.database + return d @classmethod def from_dict(cls, data: dict) -> QueryHistoryEntry: @@ -33,6 +37,7 @@ def from_dict(cls, data: dict) -> QueryHistoryEntry: query=data["query"], timestamp=data["timestamp"], connection_name=data["connection_name"], + database=data.get("database", ""), ) @@ -88,7 +93,7 @@ def load_all(self) -> list[QueryHistoryEntry]: except (KeyError, TypeError): return [] - def save_query(self, connection_name: str, query: str) -> None: + def save_query(self, connection_name: str, query: str, database: str = "") -> None: """Save a query to history. If the exact query already exists for this connection, updates its timestamp. @@ -97,6 +102,7 @@ def save_query(self, connection_name: str, query: str) -> None: Args: connection_name: Name of the connection. query: SQL query text. + database: Active database when the query was executed. """ all_entries = self._load_all_entries() query_stripped = query.strip() @@ -106,6 +112,8 @@ def save_query(self, connection_name: str, query: str) -> None: for entry in all_entries: if entry.get("connection_name") == connection_name and entry.get("query", "").strip() == query_stripped: entry["timestamp"] = now + if database: + entry["database"] = database break else: # Add new entry @@ -113,6 +121,7 @@ def save_query(self, connection_name: str, query: str) -> None: query=query_stripped, timestamp=now, connection_name=connection_name, + database=database, ) all_entries.append(new_entry.to_dict()) diff --git a/sqlit/domains/query/store/memory.py b/sqlit/domains/query/store/memory.py index e482c8c3..99fdef48 100644 --- a/sqlit/domains/query/store/memory.py +++ b/sqlit/domains/query/store/memory.py @@ -39,7 +39,7 @@ def load_for_connection(self, connection_name: str) -> list[QueryHistoryEntry]: def load_all(self) -> list[QueryHistoryEntry]: return [QueryHistoryEntry.from_dict(entry) for entry in self._entries] - def save_query(self, connection_name: str, query: str) -> None: + def save_query(self, connection_name: str, query: str, database: str = "") -> None: query_stripped = query.strip() now = datetime.now().isoformat() @@ -47,15 +47,18 @@ def save_query(self, connection_name: str, query: str) -> None: for entry in self._entries: if entry.get("connection_name") == connection_name and entry.get("query", "").strip() == query_stripped: entry["timestamp"] = now + if database: + entry["database"] = database break else: - self._entries.append( - { - "query": query_stripped, - "timestamp": now, - "connection_name": connection_name, - } - ) + d: dict[str, Any] = { + "query": query_stripped, + "timestamp": now, + "connection_name": connection_name, + } + if database: + d["database"] = database + self._entries.append(d) def delete_entry(self, connection_name: str, timestamp: str) -> bool: return False diff --git a/sqlit/domains/query/ui/mixins/query_execution.py b/sqlit/domains/query/ui/mixins/query_execution.py index 5a9e793e..08ba29fe 100644 --- a/sqlit/domains/query/ui/mixins/query_execution.py +++ b/sqlit/domains/query/ui/mixins/query_execution.py @@ -298,8 +298,13 @@ def _should_save_query_history(self: QueryMixinHost, config: Any) -> bool: def _save_query_history(self: QueryMixinHost, config: Any, query: str) -> None: """Save query history only for saved connections.""" + database = getattr(self, "_active_database", None) or "" + if not database: + endpoint = getattr(config, "tcp_endpoint", None) + if endpoint: + database = getattr(endpoint, "database", "") or "" if self._should_save_query_history(config): - self._get_history_store().save_query(config.name, query) + self._get_history_store().save_query(config.name, query, database=database) return self._get_unsaved_history_store().save_query(config.name, query) @@ -376,10 +381,14 @@ def _maybe_run_pending_telescope_query(self: QueryMixinHost) -> None: or self.current_config is None ): return - connection_name, query = pending + connection_name = pending[0] + query = pending[1] + database = pending[2] if len(pending) > 2 else "" if self.current_config.name != connection_name: return self._pending_telescope_query = None + if database: + self._active_database = database self._apply_history_query(query) @property @@ -842,7 +851,8 @@ def _handle_telescope_result(self: QueryMixinHost, result: Any) -> None: if action == "select": query = data.get("query", "") connection_name = data.get("connection_name", "") - self._run_telescope_query(connection_name, query) + database = data.get("database", "") + self._run_telescope_query(connection_name, query, database=database) elif action == "delete": timestamp = data.get("timestamp", "") connection_name = data.get("connection_name", "") @@ -866,7 +876,9 @@ def _handle_telescope_result(self: QueryMixinHost, result: Any) -> None: else: self.action_telescope() - def _run_telescope_query(self: QueryMixinHost, connection_name: str, query: str) -> None: + def _run_telescope_query( + self: QueryMixinHost, connection_name: str, query: str, *, database: str = "" + ) -> None: if not query or not connection_name: return @@ -875,6 +887,12 @@ def _run_telescope_query(self: QueryMixinHost, connection_name: str, query: str) self.notify(f"Connection '{connection_name}' not found", severity="warning") return + # Resolve the active database from when the query was originally executed. + # For old history entries without a database field, try to find one + # from other entries for the same connection. + if not database: + database = self._infer_database_from_history(connection_name) + self._apply_history_query(query) if ( @@ -882,10 +900,14 @@ def _run_telescope_query(self: QueryMixinHost, connection_name: str, query: str) and self.current_config is not None and self.current_config.name == connection_name ): + if database: + self._active_database = database self._pending_telescope_query = None return - self._pending_telescope_query = None + # Store database in the pending tuple so _maybe_run_pending_telescope_query + # can restore it AFTER the connection callback resets _active_database. + self._pending_telescope_query = (connection_name, query, database) self._connect_like_explorer(connection_name, config) def _get_telescope_connection_map(self: QueryMixinHost) -> dict[str, Any]: @@ -926,6 +948,19 @@ def _connect_like_explorer(self: QueryMixinHost, connection_name: str, config: A self.connect_to_server(config) + def _infer_database_from_history(self: QueryMixinHost, connection_name: str) -> str: + """Try to find the most recently used database for a connection from history.""" + try: + history_store = self._get_history_store() + entries = history_store.load_for_connection(connection_name) + for entry in entries: + db = getattr(entry, "database", "") or "" + if db: + return db + except Exception: + pass + return "" + def _format_telescope_connection_label(self: QueryMixinHost, config: Any) -> str: endpoint = getattr(config, "tcp_endpoint", None) if endpoint is None: diff --git a/sqlit/domains/query/ui/screens/query_history.py b/sqlit/domains/query/ui/screens/query_history.py index 2176eaba..01a31e2f 100644 --- a/sqlit/domains/query/ui/screens/query_history.py +++ b/sqlit/domains/query/ui/screens/query_history.py @@ -460,7 +460,9 @@ def _build_action_result(self, action: str, entry: QueryHistoryEntry) -> tuple[s if action == "delete": payload = {"timestamp": entry.timestamp, "connection_name": entry.connection_name} else: - payload = {"query": entry.query, "connection_name": entry.connection_name} + payload: dict[str, str] = {"query": entry.query, "connection_name": entry.connection_name} + if getattr(entry, "database", ""): + payload["database"] = entry.database return action, payload if action == "delete": return action, entry.timestamp diff --git a/sqlit/shared/core/protocols.py b/sqlit/shared/core/protocols.py index fd99dc4e..2505382e 100644 --- a/sqlit/shared/core/protocols.py +++ b/sqlit/shared/core/protocols.py @@ -40,12 +40,13 @@ class HistoryStoreProtocol(Protocol): query history. """ - def save_query(self, connection_name: str, query: str) -> None: + def save_query(self, connection_name: str, query: str, database: str = "") -> None: """Save a query to history. Args: connection_name: Name of the connection. query: The SQL query string. + database: Active database when the query was executed. """ ... diff --git a/tests/ui/mocks.py b/tests/ui/mocks.py index 8c294052..53f6cc7f 100644 --- a/tests/ui/mocks.py +++ b/tests/ui/mocks.py @@ -69,10 +69,13 @@ def __init__(self): def load_for_connection(self, connection_name: str) -> list: return self.entries.get(connection_name, []) - def save_query(self, connection_name: str, query: str) -> None: + def save_query(self, connection_name: str, query: str, database: str = "") -> None: if connection_name not in self.entries: self.entries[connection_name] = [] - self.entries[connection_name].append({"query": query}) + entry: dict = {"query": query} + if database: + entry["database"] = database + self.entries[connection_name].append(entry) def delete_entry(self, connection_name: str, timestamp: str) -> bool: return False diff --git a/tests/ui/test_telescope_fresh_start.py b/tests/ui/test_telescope_fresh_start.py new file mode 100644 index 00000000..3e7d109d --- /dev/null +++ b/tests/ui/test_telescope_fresh_start.py @@ -0,0 +1,102 @@ +"""UI test for telescope on fresh start (no active connection). + +Regression tests: +1. _pending_telescope_query must be set before connecting so the + post-connection callback can pick up the query. +2. When the connection config has no default database but the history + query was run against a specific database, telescope should try to + extract and apply the database context from the query. +""" + +from __future__ import annotations + +import pytest + +from sqlit.domains.query.store.history import QueryHistoryEntry +from sqlit.domains.shell.app.main import SSMSTUI + +from .mocks import ( + MockConnectionStore, + MockSettingsStore, + build_test_services, + create_test_connection, +) + + +class StubHistoryStore: + def __init__(self, entries): + self._entries = list(entries) + + def load_all(self): + return list(self._entries) + + def load_for_connection(self, cn): + return [e for e in self._entries if e.connection_name == cn] + + def delete_entry(self, cn, ts): + return False + + def save_query(self, cn, q): + pass + + +def _make_app(): + saved_conn = create_test_connection("my-server", "sqlite") + entry = QueryHistoryEntry( + query="SELECT * FROM users", + timestamp="2026-01-01T00:00:00", + connection_name="my-server", + ) + services = build_test_services( + connection_store=MockConnectionStore([saved_conn]), + settings_store=MockSettingsStore({"theme": "tokyo-night"}), + history_store=StubHistoryStore([entry]), + ) + return SSMSTUI(services=services), saved_conn + + +class TestTelescopePendingQuery: + """_pending_telescope_query must be set before connecting.""" + + @pytest.mark.asyncio + async def test_pending_query_set_before_connecting(self) -> None: + app, saved_conn = _make_app() + + async with app.run_test(size=(120, 40)) as pilot: + app.connections = [saved_conn] + await pilot.pause() + + assert app.current_connection is None + + app._run_telescope_query("my-server", "SELECT * FROM users") + + pending = getattr(app, "_pending_telescope_query", None) + assert pending is not None, ( + "_pending_telescope_query should be set before connecting, " + "but it was None (cleared prematurely)" + ) + assert pending[0] == "my-server" + assert pending[1] == "SELECT * FROM users" + + @pytest.mark.asyncio + async def test_no_error_after_telescope_select(self) -> None: + app, saved_conn = _make_app() + + async with app.run_test(size=(120, 40)) as pilot: + app.connections = [saved_conn] + await pilot.pause() + + assert app.current_connection is None + + app._handle_telescope_result(( + "select", + {"query": "SELECT * FROM users", "connection_name": "my-server"}, + )) + await pilot.pause(1.0) + + from sqlit.shared.ui.screens.error import ErrorScreen + error_screens = [s for s in app.screen_stack if isinstance(s, ErrorScreen)] + assert not error_screens, "ErrorScreen should not appear after telescope select" + + assert app.current_connection is not None + assert app.query_input.text == "SELECT * FROM users"