diff --git a/README.md b/README.md index 7d59d92..0b8a79f 100644 --- a/README.md +++ b/README.md @@ -144,7 +144,7 @@ See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design | Flag | Default | Description | |---|---|---| -| `--model` | auto-detect | LLM for SQL generation (e.g. `gpt-4o`, `claude-sonnet-4-6`, `anthropic/claude-opus-4-7`) | +| `--model` | auto-detect | LLM for SQL generation (e.g. `gpt-4o`, `claude-sonnet-4-6`, `anthropic/claude-opus-4-7`, `ollama/llama3`) | | `--selector-model` | same as `--model` | LLM for the table-selector step. **A cheaper model is recommended** (e.g. `gpt-4o-mini`) | | `--top-k` | 50 | TF-IDF candidates passed to the LLM selector | | `--select` | 15 | Tables the LLM selector picks from those candidates | @@ -158,8 +158,11 @@ See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design |---|---| | `OPENAI_API_KEY` | Use OpenAI as the LLM provider | | `ANTHROPIC_API_KEY` | Use Anthropic as the LLM provider | +| `OLLAMA_BASE_URL` | Override the Ollama server URL (default `http://localhost:11434`) | -If both are set, Anthropic is preferred. Override either with `--model anthropic/` or `--model openai/`. +If both API keys are set, Anthropic is preferred. Override with `--model anthropic/`, +`--model openai/`, or `--model ollama/`. Ollama uses the local OpenAI-compatible +endpoint and does not require an API key. --- diff --git a/src/promptquery/llm.py b/src/promptquery/llm.py index de30afa..87a5e14 100644 --- a/src/promptquery/llm.py +++ b/src/promptquery/llm.py @@ -1,8 +1,11 @@ from __future__ import annotations +import json import os import re from abc import ABC, abstractmethod +from urllib import error as urlerror +from urllib import request class LLMError(RuntimeError): @@ -86,6 +89,43 @@ def generate(self, system: str, user: str) -> str: return response.choices[0].message.content or "" +class OllamaClient(LLMClient): + name = "ollama" + + def __init__(self, model: str = "llama3", base_url: str | None = None): + self.model = model + endpoint = base_url or os.environ.get("OLLAMA_BASE_URL") or "http://localhost:11434" + self.base_url = endpoint.rstrip("/") + + def generate(self, system: str, user: str) -> str: + payload = json.dumps({ + "model": self.model, + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + "temperature": 0, + }).encode("utf-8") + req = request.Request( + f"{self.base_url}/v1/chat/completions", + data=payload, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with request.urlopen(req, timeout=120) as response: + body = response.read() + except urlerror.URLError as e: + raise LLMError(f"ollama request failed: {e}") from e + + try: + data = json.loads(body.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError, json.JSONDecodeError) as e: + raise LLMError("ollama response missing choices[0].message.content") from e + return content or "" + + _SQL_BLOCK_RE = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE) @@ -105,14 +145,18 @@ def make_client(model_spec: str | None = None) -> LLMClient: return AnthropicClient(model=provider) if provider.startswith("gpt") or provider.startswith("o1") or provider.startswith("o3"): return OpenAIClient(model=provider) + if provider == "ollama": + return OllamaClient() raise LLMError( f"Cannot infer provider from model {provider!r}. " - "Use 'anthropic/' or 'openai/'." + "Use 'anthropic/', 'openai/', or 'ollama/'." ) if provider == "anthropic": return AnthropicClient(model=model) if provider == "openai": return OpenAIClient(model=model) + if provider == "ollama": + return OllamaClient(model=model) raise LLMError(f"Unknown provider: {provider!r}") if os.environ.get("ANTHROPIC_API_KEY"): diff --git a/tests/test_llm.py b/tests/test_llm.py new file mode 100644 index 0000000..ef96561 --- /dev/null +++ b/tests/test_llm.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json +from urllib import error as urlerror + +import pytest + +from promptquery.llm import LLMError, OllamaClient, make_client + + +class _Response: + def __init__(self, payload: dict): + self._body = json.dumps(payload).encode("utf-8") + + def __enter__(self): + return self + + def __exit__(self, *_args): + return False + + def read(self) -> bytes: + return self._body + + +def test_ollama_client_posts_to_openai_compatible_endpoint(monkeypatch): + seen = {} + + def fake_urlopen(req, timeout): + seen["url"] = req.full_url + seen["timeout"] = timeout + seen["headers"] = dict(req.header_items()) + seen["payload"] = json.loads(req.data.decode("utf-8")) + return _Response({ + "choices": [ + {"message": {"content": "```sql\nSELECT 1\n```"}}, + ], + }) + + monkeypatch.setattr("promptquery.llm.request.urlopen", fake_urlopen) + + client = OllamaClient(model="llama3.1", base_url="http://localhost:11434/") + + assert client.generate("system prompt", "user question") == "```sql\nSELECT 1\n```" + assert seen["url"] == "http://localhost:11434/v1/chat/completions" + assert seen["timeout"] == 120 + assert seen["headers"]["Content-type"] == "application/json" + assert seen["payload"] == { + "model": "llama3.1", + "messages": [ + {"role": "system", "content": "system prompt"}, + {"role": "user", "content": "user question"}, + ], + "temperature": 0, + } + + +def test_ollama_client_uses_env_base_url(monkeypatch): + seen = {} + + def fake_urlopen(req, timeout): + seen["url"] = req.full_url + return _Response({"choices": [{"message": {"content": "SELECT 1"}}]}) + + monkeypatch.setenv("OLLAMA_BASE_URL", "http://ollama.test:11434/") + monkeypatch.setattr("promptquery.llm.request.urlopen", fake_urlopen) + + assert OllamaClient().generate("system", "user") == "SELECT 1" + assert seen["url"] == "http://ollama.test:11434/v1/chat/completions" + + +def test_ollama_client_wraps_http_errors(monkeypatch): + def fake_urlopen(req, timeout): + raise urlerror.URLError("connection refused") + + monkeypatch.setattr("promptquery.llm.request.urlopen", fake_urlopen) + + with pytest.raises(LLMError, match="ollama request failed"): + OllamaClient().generate("system", "user") + + +def test_ollama_client_rejects_malformed_response(monkeypatch): + def fake_urlopen(req, timeout): + return _Response({"choices": []}) + + monkeypatch.setattr("promptquery.llm.request.urlopen", fake_urlopen) + + with pytest.raises(LLMError, match="choices"): + OllamaClient().generate("system", "user") + + +def test_make_client_routes_ollama_models(): + client = make_client("ollama/llama3") + + assert isinstance(client, OllamaClient) + assert client.model == "llama3"