Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 68 additions & 16 deletions src/obsidian_rag/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@
help="Ollama API URL (only used with --provider ollama)")
@click.option("--lmstudio-url", default=None,
help="LM Studio API URL (only used with --provider lmstudio)")
@click.option("--ollama-api-key", default=None,
help="Bearer token for Ollama (only used with --provider ollama)")
@click.option("--lmstudio-api-key", default=None,
help="Bearer token for LM Studio (only used with --provider lmstudio)")
@click.option("--model", default=None, help="Override embedding model name")
@click.pass_context
def main(ctx, vault, data, provider, ollama_url, lmstudio_url, model):
def main(ctx, vault, data, provider, ollama_url, lmstudio_url, ollama_api_key, lmstudio_api_key, model):
"""Obsidian RAG - Semantic search for your Obsidian vault."""
ctx.ensure_object(dict)

Expand All @@ -51,6 +55,8 @@ def main(ctx, vault, data, provider, ollama_url, lmstudio_url, model):
ctx.obj["provider"] = provider or config.provider
ctx.obj["ollama_url"] = ollama_url or config.ollama_url
ctx.obj["lmstudio_url"] = lmstudio_url or config.lmstudio_url
ctx.obj["ollama_api_key"] = ollama_api_key or config.get_ollama_api_key()
ctx.obj["lmstudio_api_key"] = lmstudio_api_key or config.get_lmstudio_api_key()
ctx.obj["model"] = model # None means use provider default
ctx.obj["config"] = config

Expand Down Expand Up @@ -105,16 +111,20 @@ def setup():
)
config.ollama_url = ollama_url

# Optional Bearer token
if click.confirm("\nDoes your Ollama instance require a Bearer token?", default=False):
config.ollama_api_key = click.prompt("Enter Bearer token", hide_input=False)

# Verify connection and get available models
click.echo("Checking Ollama server...", nl=False)
server_running = is_ollama_running(ollama_url)
server_running = is_ollama_running(ollama_url, api_key=config.ollama_api_key)

if server_running:
click.echo(" ✓ connected")

# Fetch available embedding models
click.echo("Fetching available embedding models...", nl=False)
available_models = get_ollama_models(ollama_url)
available_models = get_ollama_models(ollama_url, api_key=config.ollama_api_key)

if available_models:
click.echo(f" found {len(available_models)}")
Expand Down Expand Up @@ -149,16 +159,20 @@ def setup():
)
config.lmstudio_url = lmstudio_url

# Optional Bearer token
if click.confirm("\nDoes your LM Studio instance require a Bearer token?", default=False):
config.lmstudio_api_key = click.prompt("Enter Bearer token", hide_input=False)

# Verify connection and get available models
click.echo("Checking LM Studio server...", nl=False)
server_running = is_lmstudio_running(lmstudio_url)
server_running = is_lmstudio_running(lmstudio_url, api_key=config.lmstudio_api_key)

if server_running:
click.echo(" ✓ connected")

# Fetch available embedding models
click.echo("Fetching available embedding models...", nl=False)
available_models = get_lmstudio_models(lmstudio_url)
available_models = get_lmstudio_models(lmstudio_url, api_key=config.lmstudio_api_key)

if available_models:
click.echo(f" found {len(available_models)}")
Expand Down Expand Up @@ -226,13 +240,15 @@ def setup():
embedder = create_embedder(
provider="ollama",
model=config.ollama_model,
base_url=config.ollama_url
base_url=config.ollama_url,
api_key=config.get_ollama_api_key(),
)
else: # lmstudio
embedder = create_embedder(
provider="lmstudio",
model=config.lmstudio_model,
base_url=config.lmstudio_url
base_url=config.lmstudio_url,
api_key=config.get_lmstudio_api_key(),
)

store = VectorStore(data_path=config.get_data_path())
Expand Down Expand Up @@ -291,7 +307,9 @@ def setup():
config.data_path or str(get_data_dir()),
config.provider,
config.ollama_url,
None # model
None, # model
config.get_ollama_api_key(),
config.get_lmstudio_api_key(),
)
plist_path.write_text(plist_content)

Expand Down Expand Up @@ -327,6 +345,8 @@ def index(ctx, clear, path_filter):
provider = ctx.obj["provider"]
ollama_url = ctx.obj["ollama_url"]
lmstudio_url = ctx.obj["lmstudio_url"]
ollama_api_key = ctx.obj["ollama_api_key"]
lmstudio_api_key = ctx.obj["lmstudio_api_key"]
config = ctx.obj["config"]

# Get model from CLI override or config file based on provider
Expand All @@ -344,16 +364,19 @@ def index(ctx, clear, path_filter):
click.echo(f"Provider: {provider}")
click.echo(f"Model: {model}")

# Determine the correct base_url based on provider
# Determine the correct base_url and api_key based on provider
if provider == "ollama":
base_url = ollama_url
api_key = ollama_api_key
elif provider == "lmstudio":
base_url = lmstudio_url
api_key = lmstudio_api_key
else:
base_url = None
api_key = None

# Initialize components
embedder = create_embedder(provider=provider, model=model, base_url=base_url)
embedder = create_embedder(provider=provider, model=model, base_url=base_url, api_key=api_key)
store = VectorStore(data_path=data_path)
indexer = VaultIndexer(vault_path=vault_path, embedder=embedder, config=config.indexer)

Expand Down Expand Up @@ -413,6 +436,8 @@ def search(ctx, query, limit, note_type):
provider = ctx.obj["provider"]
ollama_url = ctx.obj["ollama_url"]
lmstudio_url = ctx.obj["lmstudio_url"]
ollama_api_key = ctx.obj["ollama_api_key"]
lmstudio_api_key = ctx.obj["lmstudio_api_key"]
config = ctx.obj["config"]

# Get model from CLI override or config file based on provider
Expand All @@ -425,16 +450,19 @@ def search(ctx, query, limit, note_type):
elif provider == "lmstudio":
model = config.lmstudio_model

# Determine the correct base_url based on provider
# Determine the correct base_url and api_key based on provider
if provider == "ollama":
base_url = ollama_url
api_key = ollama_api_key
elif provider == "lmstudio":
base_url = lmstudio_url
api_key = lmstudio_api_key
else:
base_url = None
api_key = None

# Initialize components
embedder = create_embedder(provider=provider, model=model, base_url=base_url)
embedder = create_embedder(provider=provider, model=model, base_url=base_url, api_key=api_key)
store = VectorStore(data_path=data_path)

# Generate query embedding
Expand Down Expand Up @@ -486,6 +514,8 @@ def similar(ctx, note_path, limit):
provider = ctx.obj["provider"]
ollama_url = ctx.obj["ollama_url"]
lmstudio_url = ctx.obj["lmstudio_url"]
ollama_api_key = ctx.obj["ollama_api_key"]
lmstudio_api_key = ctx.obj["lmstudio_api_key"]
config = ctx.obj["config"]

model = ctx.obj["model"]
Expand All @@ -499,12 +529,15 @@ def similar(ctx, note_path, limit):

if provider == "ollama":
base_url = ollama_url
api_key = ollama_api_key
elif provider == "lmstudio":
base_url = lmstudio_url
api_key = lmstudio_api_key
else:
base_url = None
api_key = None

embedder = create_embedder(provider=provider, model=model, base_url=base_url)
embedder = create_embedder(provider=provider, model=model, base_url=base_url, api_key=api_key)
store = VectorStore(data_path=data_path)

click.echo(f"Finding notes similar to: {note_path}\n")
Expand Down Expand Up @@ -552,6 +585,8 @@ def context(ctx, note_path, limit):
provider = ctx.obj["provider"]
ollama_url = ctx.obj["ollama_url"]
lmstudio_url = ctx.obj["lmstudio_url"]
ollama_api_key = ctx.obj["ollama_api_key"]
lmstudio_api_key = ctx.obj["lmstudio_api_key"]
config = ctx.obj["config"]

model = ctx.obj["model"]
Expand All @@ -565,12 +600,15 @@ def context(ctx, note_path, limit):

if provider == "ollama":
base_url = ollama_url
api_key = ollama_api_key
elif provider == "lmstudio":
base_url = lmstudio_url
api_key = lmstudio_api_key
else:
base_url = None
api_key = None

embedder = create_embedder(provider=provider, model=model, base_url=base_url)
embedder = create_embedder(provider=provider, model=model, base_url=base_url, api_key=api_key)
store = VectorStore(data_path=data_path)

click.echo(f"Getting context for: {note_path}\n")
Expand Down Expand Up @@ -634,6 +672,8 @@ def watch(ctx, debounce):
provider = ctx.obj["provider"]
ollama_url = ctx.obj["ollama_url"]
lmstudio_url = ctx.obj["lmstudio_url"]
ollama_api_key = ctx.obj["ollama_api_key"]
lmstudio_api_key = ctx.obj["lmstudio_api_key"]
model = ctx.obj["model"]

click.echo(f"Watching vault: {vault_path}")
Expand All @@ -648,6 +688,8 @@ def watch(ctx, debounce):
provider=provider,
ollama_url=ollama_url,
lmstudio_url=lmstudio_url,
ollama_api_key=ollama_api_key,
lmstudio_api_key=lmstudio_api_key,
model=model,
debounce_delay=debounce,
)
Expand Down Expand Up @@ -698,7 +740,7 @@ def _uninstall_wrapper_script():
wrapper_path.unlink()


def _get_plist_content(vault_path: str, data_path: str, provider: str, ollama_url: str, model: str | None) -> str:
def _get_plist_content(vault_path: str, data_path: str, provider: str, ollama_url: str, model: str | None, ollama_api_key: str | None = None, lmstudio_api_key: str | None = None) -> str:
"""Generate launchd plist content."""
# Use wrapper script for better System Settings appearance
wrapper_path = WRAPPER_SCRIPT_DIR / WRAPPER_SCRIPT_NAME
Expand All @@ -715,6 +757,14 @@ def _get_plist_content(vault_path: str, data_path: str, provider: str, ollama_ur
env_vars += f"""
<key>OBSIDIAN_RAG_OLLAMA_URL</key>
<string>{ollama_url}</string>"""
if ollama_api_key:
env_vars += f"""
<key>OBSIDIAN_RAG_OLLAMA_API_KEY</key>
<string>{ollama_api_key}</string>"""
elif provider == "lmstudio" and lmstudio_api_key:
env_vars += f"""
<key>OBSIDIAN_RAG_LMSTUDIO_API_KEY</key>
<string>{lmstudio_api_key}</string>"""

if model:
env_vars += f"""
Expand Down Expand Up @@ -765,6 +815,8 @@ def install_service(ctx):
data_path = ctx.obj["data"]
provider = ctx.obj["provider"]
ollama_url = ctx.obj["ollama_url"]
ollama_api_key = ctx.obj["ollama_api_key"]
lmstudio_api_key = ctx.obj["lmstudio_api_key"]
model = ctx.obj["model"]

plist_path = LAUNCH_AGENTS_DIR / PLIST_NAME
Expand All @@ -783,7 +835,7 @@ def install_service(ctx):
click.echo(f"Created: {wrapper_path}")

# Write plist
plist_content = _get_plist_content(vault_path, data_path, provider, ollama_url, model)
plist_content = _get_plist_content(vault_path, data_path, provider, ollama_url, model, ollama_api_key, lmstudio_api_key)
plist_path.write_text(plist_content)
click.echo(f"Created: {plist_path}")

Expand Down
21 changes: 21 additions & 0 deletions src/obsidian_rag/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ class Config:
# Ollama settings
ollama_url: str = "http://localhost:11434"
ollama_model: str = "nomic-embed-text"
ollama_api_key: Optional[str] = None # Bearer token for protected Ollama instances

# LM Studio settings
lmstudio_url: str = "http://localhost:1234"
lmstudio_model: str = "text-embedding-nomic-embed-text-v1.5"
lmstudio_api_key: Optional[str] = None # Bearer token for protected LM Studio instances

# OpenAI model (optional override)
openai_model: str = "text-embedding-3-small"
Expand All @@ -97,6 +99,14 @@ def get_openai_api_key(self) -> Optional[str]:
"""Get OpenAI API key from config or environment."""
return self.openai_api_key or os.environ.get("OPENAI_API_KEY")

def get_ollama_api_key(self) -> Optional[str]:
"""Get Ollama Bearer token from config or environment."""
return self.ollama_api_key or os.environ.get("OBSIDIAN_RAG_OLLAMA_API_KEY")

def get_lmstudio_api_key(self) -> Optional[str]:
"""Get LM Studio Bearer token from config or environment."""
return self.lmstudio_api_key or os.environ.get("OBSIDIAN_RAG_LMSTUDIO_API_KEY")


def load_config() -> Config:
"""Load configuration from file with environment variable overrides.
Expand Down Expand Up @@ -130,11 +140,13 @@ def load_config() -> Config:
if "ollama" in data:
config.ollama_url = data["ollama"].get("url", config.ollama_url)
config.ollama_model = data["ollama"].get("model", config.ollama_model)
config.ollama_api_key = data["ollama"].get("api_key", config.ollama_api_key)

# LM Studio settings
if "lmstudio" in data:
config.lmstudio_url = data["lmstudio"].get("url", config.lmstudio_url)
config.lmstudio_model = data["lmstudio"].get("model", config.lmstudio_model)
config.lmstudio_api_key = data["lmstudio"].get("api_key", config.lmstudio_api_key)

# Indexer settings
if "indexer" in data:
Expand All @@ -154,6 +166,11 @@ def load_config() -> Config:
config.ollama_url = os.environ["OBSIDIAN_RAG_OLLAMA_URL"]
if os.environ.get("OBSIDIAN_RAG_LMSTUDIO_URL"):
config.lmstudio_url = os.environ["OBSIDIAN_RAG_LMSTUDIO_URL"]
# Bearer token overrides for local providers
if os.environ.get("OBSIDIAN_RAG_OLLAMA_API_KEY"):
config.ollama_api_key = os.environ["OBSIDIAN_RAG_OLLAMA_API_KEY"]
if os.environ.get("OBSIDIAN_RAG_LMSTUDIO_API_KEY"):
config.lmstudio_api_key = os.environ["OBSIDIAN_RAG_LMSTUDIO_API_KEY"]
if os.environ.get("OBSIDIAN_RAG_MODEL"):
if config.provider == "ollama":
config.ollama_model = os.environ["OBSIDIAN_RAG_MODEL"]
Expand Down Expand Up @@ -200,6 +217,8 @@ def save_config(config: Config) -> Path:
ollama_section["url"] = config.ollama_url
if config.ollama_model != "nomic-embed-text":
ollama_section["model"] = config.ollama_model
if config.ollama_api_key:
ollama_section["api_key"] = config.ollama_api_key
if ollama_section:
data["ollama"] = ollama_section

Expand All @@ -210,6 +229,8 @@ def save_config(config: Config) -> Path:
lmstudio_section["url"] = config.lmstudio_url
if config.lmstudio_model != "text-embedding-nomic-embed-text-v1.5":
lmstudio_section["model"] = config.lmstudio_model
if config.lmstudio_api_key:
lmstudio_section["api_key"] = config.lmstudio_api_key
if lmstudio_section:
data["lmstudio"] = lmstudio_section

Expand Down
Loading
Loading