Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design

| Flag | Default | Description |
|---|---|---|
| `--model` | auto-detect | LLM for SQL generation (e.g. `gpt-4o`, `claude-sonnet-4-6`, `anthropic/claude-opus-4-7`) |
| `--model` | auto-detect | LLM for SQL generation (e.g. `gpt-4o`, `claude-sonnet-4-6`, `anthropic/claude-opus-4-7`, `ollama/llama3`) |
| `--selector-model` | same as `--model` | LLM for the table-selector step. **A cheaper model is recommended** (e.g. `gpt-4o-mini`) |
| `--top-k` | 50 | TF-IDF candidates passed to the LLM selector |
| `--select` | 15 | Tables the LLM selector picks from those candidates |
Expand All @@ -158,8 +158,11 @@ See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design
|---|---|
| `OPENAI_API_KEY` | Use OpenAI as the LLM provider |
| `ANTHROPIC_API_KEY` | Use Anthropic as the LLM provider |
| `OLLAMA_BASE_URL` | Override the Ollama server URL (default `http://localhost:11434`) |

If both are set, Anthropic is preferred. Override either with `--model anthropic/<name>` or `--model openai/<name>`.
If both API keys are set, Anthropic is preferred. Override with `--model anthropic/<name>`,
`--model openai/<name>`, or `--model ollama/<name>`. Ollama is reached through its
OpenAI-compatible `/v1/chat/completions` endpoint at `OLLAMA_BASE_URL` and does not require an API key.

---

Expand Down
46 changes: 45 additions & 1 deletion src/promptquery/llm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import json
import os
import re
from abc import ABC, abstractmethod
from urllib import error as urlerror
from urllib import request


class LLMError(RuntimeError):
Expand Down Expand Up @@ -86,6 +89,43 @@ def generate(self, system: str, user: str) -> str:
return response.choices[0].message.content or ""


class OllamaClient(LLMClient):
    """LLM client that talks to an Ollama server over its OpenAI-compatible HTTP API.

    The server URL is resolved from, in order: the ``base_url`` argument, the
    ``OLLAMA_BASE_URL`` environment variable, and ``http://localhost:11434``.
    No API key or authorization header is sent.
    """

    name = "ollama"

    def __init__(self, model: str = "llama3", base_url: str | None = None):
        """Store the model name and the normalised server URL.

        Args:
            model: Model name sent in the ``model`` field of each request.
            base_url: Server root URL; falls back to ``OLLAMA_BASE_URL`` and
                then to ``http://localhost:11434`` when empty or ``None``.
        """
        self.model = model
        endpoint = base_url or os.environ.get("OLLAMA_BASE_URL") or "http://localhost:11434"
        # Drop trailing slashes so appending "/v1/..." never yields "//".
        self.base_url = endpoint.rstrip("/")

    def generate(self, system: str, user: str) -> str:
        """Send one system + user exchange and return the assistant's text.

        Args:
            system: Content of the ``system`` message.
            user: Content of the ``user`` message.

        Returns:
            ``choices[0].message.content`` from the response, or ``""`` when
            that value is empty or null.

        Raises:
            LLMError: If the HTTP request fails (connection error, HTTP error
                status, timeout, truncated body) or the response body is not
                UTF-8 JSON with the expected ``choices[0].message.content``
                structure.
        """
        # Imported locally so the block does not depend on a new file-level import.
        from http.client import HTTPException

        payload = json.dumps({
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "temperature": 0,
        }).encode("utf-8")
        req = request.Request(
            f"{self.base_url}/v1/chat/completions",
            data=payload,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=120) as response:
                body = response.read()
        except (OSError, HTTPException) as e:
            # URLError/HTTPError are OSError subclasses. Catching OSError also
            # covers socket timeouts and connection resets raised while reading
            # the body, which urllib does not wrap in URLError; HTTPException
            # covers a truncated body (IncompleteRead).
            raise LLMError(f"ollama request failed: {e}") from e

        try:
            data = json.loads(body.decode("utf-8"))
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            raise LLMError("ollama response missing choices[0].message.content") from e
        return content or ""


_SQL_BLOCK_RE = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


Expand All @@ -105,14 +145,18 @@ def make_client(model_spec: str | None = None) -> LLMClient:
return AnthropicClient(model=provider)
if provider.startswith("gpt") or provider.startswith("o1") or provider.startswith("o3"):
return OpenAIClient(model=provider)
if provider == "ollama":
return OllamaClient()
raise LLMError(
f"Cannot infer provider from model {provider!r}. "
"Use 'anthropic/<model>' or 'openai/<model>'."
"Use 'anthropic/<model>', 'openai/<model>', or 'ollama/<model>'."
)
if provider == "anthropic":
return AnthropicClient(model=model)
if provider == "openai":
return OpenAIClient(model=model)
if provider == "ollama":
return OllamaClient(model=model)
raise LLMError(f"Unknown provider: {provider!r}")

if os.environ.get("ANTHROPIC_API_KEY"):
Expand Down
95 changes: 95 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

import json
from urllib import error as urlerror

import pytest

from promptquery.llm import LLMError, OllamaClient, make_client


class _Response:
    """Minimal stand-in for the context-manager object returned by ``urlopen``.

    Holds a JSON-serialised payload and hands it back from ``read()``.
    """

    def __init__(self, payload: dict):
        serialised = json.dumps(payload)
        self._body = serialised.encode("utf-8")

    def read(self) -> bytes:
        """Return the encoded JSON body."""
        return self._body

    def __enter__(self):
        return self

    def __exit__(self, *_args):
        # False means exceptions raised inside the ``with`` body propagate.
        return False


def test_ollama_client_posts_to_openai_compatible_endpoint(monkeypatch):
    """The client POSTs an OpenAI-style chat payload to ``/v1/chat/completions``."""
    seen = {}
    reply = "```sql\nSELECT 1\n```"

    def fake_urlopen(req, timeout):
        # Capture every observable part of the outgoing request.
        seen.update(
            url=req.full_url,
            timeout=timeout,
            headers=dict(req.header_items()),
            payload=json.loads(req.data.decode("utf-8")),
        )
        return _Response({"choices": [{"message": {"content": reply}}]})

    monkeypatch.setattr("promptquery.llm.request.urlopen", fake_urlopen)

    # The trailing slash on base_url must not leak into the request URL.
    client = OllamaClient(model="llama3.1", base_url="http://localhost:11434/")
    result = client.generate("system prompt", "user question")

    expected_messages = [
        {"role": "system", "content": "system prompt"},
        {"role": "user", "content": "user question"},
    ]
    expected_payload = {
        "model": "llama3.1",
        "messages": expected_messages,
        "temperature": 0,
    }

    assert result == reply
    assert seen["url"] == "http://localhost:11434/v1/chat/completions"
    assert seen["timeout"] == 120
    # urllib normalises header names with str.capitalize(), hence "Content-type".
    assert seen["headers"]["Content-type"] == "application/json"
    assert seen["payload"] == expected_payload


def test_ollama_client_uses_env_base_url(monkeypatch):
    """Without an explicit ``base_url`` the client uses ``OLLAMA_BASE_URL``."""
    seen = {}

    def fake_urlopen(req, timeout):
        seen["url"] = req.full_url
        canned = {"choices": [{"message": {"content": "SELECT 1"}}]}
        return _Response(canned)

    monkeypatch.setattr("promptquery.llm.request.urlopen", fake_urlopen)
    monkeypatch.setenv("OLLAMA_BASE_URL", "http://ollama.test:11434/")

    answer = OllamaClient().generate("system", "user")

    assert answer == "SELECT 1"
    assert seen["url"] == "http://ollama.test:11434/v1/chat/completions"


def test_ollama_client_wraps_http_errors(monkeypatch):
    """A transport-level ``URLError`` surfaces to callers as ``LLMError``."""

    def fake_urlopen(req, timeout):
        raise urlerror.URLError("connection refused")

    monkeypatch.setattr("promptquery.llm.request.urlopen", fake_urlopen)
    client = OllamaClient()

    with pytest.raises(LLMError, match="ollama request failed"):
        client.generate("system", "user")


def test_ollama_client_rejects_malformed_response(monkeypatch):
    """A response with an empty ``choices`` list is reported as ``LLMError``."""

    def fake_urlopen(req, timeout):
        empty_reply = {"choices": []}
        return _Response(empty_reply)

    monkeypatch.setattr("promptquery.llm.request.urlopen", fake_urlopen)
    client = OllamaClient()

    with pytest.raises(LLMError, match="choices"):
        client.generate("system", "user")


def test_make_client_routes_ollama_models():
    """An ``ollama/<model>`` spec yields an ``OllamaClient`` for that model."""
    routed = make_client("ollama/llama3")

    assert isinstance(routed, OllamaClient)
    assert routed.model == "llama3"