Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions sqlit/domains/query/store/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@ class QueryHistoryEntry:
query: str
timestamp: str # ISO format
connection_name: str
database: str = "" # Active database when query was executed
is_starred: bool = False # Computed at load time, not persisted
is_starred_only: bool = False # True if only in starred store, not in history

def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
d: dict = {
"query": self.query,
"timestamp": self.timestamp,
"connection_name": self.connection_name,
}
if self.database:
d["database"] = self.database
return d

@classmethod
def from_dict(cls, data: dict) -> QueryHistoryEntry:
Expand All @@ -33,6 +37,7 @@ def from_dict(cls, data: dict) -> QueryHistoryEntry:
query=data["query"],
timestamp=data["timestamp"],
connection_name=data["connection_name"],
database=data.get("database", ""),
)


Expand Down Expand Up @@ -88,7 +93,7 @@ def load_all(self) -> list[QueryHistoryEntry]:
except (KeyError, TypeError):
return []

def save_query(self, connection_name: str, query: str) -> None:
def save_query(self, connection_name: str, query: str, database: str = "") -> None:
"""Save a query to history.

If the exact query already exists for this connection, updates its timestamp.
Expand All @@ -97,6 +102,7 @@ def save_query(self, connection_name: str, query: str) -> None:
Args:
connection_name: Name of the connection.
query: SQL query text.
database: Active database when the query was executed.
"""
all_entries = self._load_all_entries()
query_stripped = query.strip()
Expand All @@ -106,13 +112,16 @@ def save_query(self, connection_name: str, query: str) -> None:
for entry in all_entries:
if entry.get("connection_name") == connection_name and entry.get("query", "").strip() == query_stripped:
entry["timestamp"] = now
if database:
entry["database"] = database
break
else:
# Add new entry
new_entry = QueryHistoryEntry(
query=query_stripped,
timestamp=now,
connection_name=connection_name,
database=database,
)
all_entries.append(new_entry.to_dict())

Expand Down
19 changes: 11 additions & 8 deletions sqlit/domains/query/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,26 @@ def load_for_connection(self, connection_name: str) -> list[QueryHistoryEntry]:
def load_all(self) -> list[QueryHistoryEntry]:
return [QueryHistoryEntry.from_dict(entry) for entry in self._entries]

def save_query(self, connection_name: str, query: str) -> None:
def save_query(self, connection_name: str, query: str, database: str = "") -> None:
query_stripped = query.strip()
now = datetime.now().isoformat()

# Check if query already exists
for entry in self._entries:
if entry.get("connection_name") == connection_name and entry.get("query", "").strip() == query_stripped:
entry["timestamp"] = now
if database:
entry["database"] = database
break
else:
self._entries.append(
{
"query": query_stripped,
"timestamp": now,
"connection_name": connection_name,
}
)
d: dict[str, Any] = {
"query": query_stripped,
"timestamp": now,
"connection_name": connection_name,
}
if database:
d["database"] = database
self._entries.append(d)

def delete_entry(self, connection_name: str, timestamp: str) -> bool:
return False
Expand Down
45 changes: 40 additions & 5 deletions sqlit/domains/query/ui/mixins/query_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,13 @@ def _should_save_query_history(self: QueryMixinHost, config: Any) -> bool:

def _save_query_history(self: QueryMixinHost, config: Any, query: str) -> None:
"""Save query history only for saved connections."""
database = getattr(self, "_active_database", None) or ""
if not database:
endpoint = getattr(config, "tcp_endpoint", None)
if endpoint:
database = getattr(endpoint, "database", "") or ""
if self._should_save_query_history(config):
self._get_history_store().save_query(config.name, query)
self._get_history_store().save_query(config.name, query, database=database)
return
self._get_unsaved_history_store().save_query(config.name, query)

Expand Down Expand Up @@ -376,10 +381,14 @@ def _maybe_run_pending_telescope_query(self: QueryMixinHost) -> None:
or self.current_config is None
):
return
connection_name, query = pending
connection_name = pending[0]
query = pending[1]
database = pending[2] if len(pending) > 2 else ""
if self.current_config.name != connection_name:
return
self._pending_telescope_query = None
if database:
self._active_database = database
self._apply_history_query(query)

@property
Expand Down Expand Up @@ -842,7 +851,8 @@ def _handle_telescope_result(self: QueryMixinHost, result: Any) -> None:
if action == "select":
query = data.get("query", "")
connection_name = data.get("connection_name", "")
self._run_telescope_query(connection_name, query)
database = data.get("database", "")
self._run_telescope_query(connection_name, query, database=database)
elif action == "delete":
timestamp = data.get("timestamp", "")
connection_name = data.get("connection_name", "")
Expand All @@ -866,7 +876,9 @@ def _handle_telescope_result(self: QueryMixinHost, result: Any) -> None:
else:
self.action_telescope()

def _run_telescope_query(self: QueryMixinHost, connection_name: str, query: str) -> None:
def _run_telescope_query(
self: QueryMixinHost, connection_name: str, query: str, *, database: str = ""
) -> None:
if not query or not connection_name:
return

Expand All @@ -875,17 +887,27 @@ def _run_telescope_query(self: QueryMixinHost, connection_name: str, query: str)
self.notify(f"Connection '{connection_name}' not found", severity="warning")
return

# Resolve the active database from when the query was originally executed.
# For old history entries without a database field, try to find one
# from other entries for the same connection.
if not database:
database = self._infer_database_from_history(connection_name)

self._apply_history_query(query)

if (
self.current_connection is not None
and self.current_config is not None
and self.current_config.name == connection_name
):
if database:
self._active_database = database
self._pending_telescope_query = None
return

self._pending_telescope_query = None
# Store database in the pending tuple so _maybe_run_pending_telescope_query
# can restore it AFTER the connection callback resets _active_database.
self._pending_telescope_query = (connection_name, query, database)
self._connect_like_explorer(connection_name, config)

def _get_telescope_connection_map(self: QueryMixinHost) -> dict[str, Any]:
Expand Down Expand Up @@ -926,6 +948,19 @@ def _connect_like_explorer(self: QueryMixinHost, connection_name: str, config: A

self.connect_to_server(config)

def _infer_database_from_history(self: QueryMixinHost, connection_name: str) -> str:
"""Try to find the most recently used database for a connection from history."""
try:
history_store = self._get_history_store()
entries = history_store.load_for_connection(connection_name)
for entry in entries:
db = getattr(entry, "database", "") or ""
if db:
return db
except Exception:
pass
return ""

def _format_telescope_connection_label(self: QueryMixinHost, config: Any) -> str:
endpoint = getattr(config, "tcp_endpoint", None)
if endpoint is None:
Expand Down
4 changes: 3 additions & 1 deletion sqlit/domains/query/ui/screens/query_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ def _build_action_result(self, action: str, entry: QueryHistoryEntry) -> tuple[s
if action == "delete":
payload = {"timestamp": entry.timestamp, "connection_name": entry.connection_name}
else:
payload = {"query": entry.query, "connection_name": entry.connection_name}
payload: dict[str, str] = {"query": entry.query, "connection_name": entry.connection_name}
if getattr(entry, "database", ""):
payload["database"] = entry.database
return action, payload
if action == "delete":
return action, entry.timestamp
Expand Down
3 changes: 2 additions & 1 deletion sqlit/shared/core/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ class HistoryStoreProtocol(Protocol):
query history.
"""

def save_query(self, connection_name: str, query: str) -> None:
def save_query(self, connection_name: str, query: str, database: str = "") -> None:
"""Save a query to history.

Args:
connection_name: Name of the connection.
query: The SQL query string.
database: Active database when the query was executed.
"""
...

Expand Down
7 changes: 5 additions & 2 deletions tests/ui/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,13 @@ def __init__(self):
def load_for_connection(self, connection_name: str) -> list:
return self.entries.get(connection_name, [])

def save_query(self, connection_name: str, query: str) -> None:
def save_query(self, connection_name: str, query: str, database: str = "") -> None:
if connection_name not in self.entries:
self.entries[connection_name] = []
self.entries[connection_name].append({"query": query})
entry: dict = {"query": query}
if database:
entry["database"] = database
self.entries[connection_name].append(entry)

def delete_entry(self, connection_name: str, timestamp: str) -> bool:
return False
Expand Down
102 changes: 102 additions & 0 deletions tests/ui/test_telescope_fresh_start.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""UI test for telescope on fresh start (no active connection).

Regression tests:
1. _pending_telescope_query must be set before connecting so the
post-connection callback can pick up the query.
2. When the connection config has no default database but the history
query was run against a specific database, telescope should try to
extract and apply the database context from the query.
"""

from __future__ import annotations

import pytest

from sqlit.domains.query.store.history import QueryHistoryEntry
from sqlit.domains.shell.app.main import SSMSTUI

from .mocks import (
MockConnectionStore,
MockSettingsStore,
build_test_services,
create_test_connection,
)


class StubHistoryStore:
def __init__(self, entries):
self._entries = list(entries)

def load_all(self):
return list(self._entries)

def load_for_connection(self, cn):
return [e for e in self._entries if e.connection_name == cn]

def delete_entry(self, cn, ts):
return False

def save_query(self, cn, q):
pass


def _make_app():
saved_conn = create_test_connection("my-server", "sqlite")
entry = QueryHistoryEntry(
query="SELECT * FROM users",
timestamp="2026-01-01T00:00:00",
connection_name="my-server",
)
services = build_test_services(
connection_store=MockConnectionStore([saved_conn]),
settings_store=MockSettingsStore({"theme": "tokyo-night"}),
history_store=StubHistoryStore([entry]),
)
return SSMSTUI(services=services), saved_conn


class TestTelescopePendingQuery:
"""_pending_telescope_query must be set before connecting."""

@pytest.mark.asyncio
async def test_pending_query_set_before_connecting(self) -> None:
app, saved_conn = _make_app()

async with app.run_test(size=(120, 40)) as pilot:
app.connections = [saved_conn]
await pilot.pause()

assert app.current_connection is None

app._run_telescope_query("my-server", "SELECT * FROM users")

pending = getattr(app, "_pending_telescope_query", None)
assert pending is not None, (
"_pending_telescope_query should be set before connecting, "
"but it was None (cleared prematurely)"
)
assert pending[0] == "my-server"
assert pending[1] == "SELECT * FROM users"

@pytest.mark.asyncio
async def test_no_error_after_telescope_select(self) -> None:
app, saved_conn = _make_app()

async with app.run_test(size=(120, 40)) as pilot:
app.connections = [saved_conn]
await pilot.pause()

assert app.current_connection is None

app._handle_telescope_result((
"select",
{"query": "SELECT * FROM users", "connection_name": "my-server"},
))
await pilot.pause(1.0)

from sqlit.shared.ui.screens.error import ErrorScreen
error_screens = [s for s in app.screen_stack if isinstance(s, ErrorScreen)]
assert not error_screens, "ErrorScreen should not appear after telescope select"

assert app.current_connection is not None
assert app.query_input.text == "SELECT * FROM users"
Loading