Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sqlit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,8 @@ def main() -> int:
description=f"{schema.display_name} connection options",
)
add_schema_arguments(provider_parser, schema, include_name=True, name_required=True)
provider_parser.add_argument("--password-command", dest="password_command", help="Shell command to retrieve the database password")
provider_parser.add_argument("--ssh-password-command", dest="ssh_password_command", help="Shell command to retrieve the SSH password")

edit_parser = conn_subparsers.add_parser("edit", help="Edit an existing connection")
edit_parser.add_argument("connection_name", help="Name of connection to edit")
Expand All @@ -528,6 +530,8 @@ def main() -> int:
help="Authentication type (SQL Server only)",
)
edit_parser.add_argument("--file-path", help="Database file path (SQLite only)")
edit_parser.add_argument("--password-command", dest="password_command", help="Shell command to retrieve the database password")
edit_parser.add_argument("--ssh-password-command", dest="ssh_password_command", help="Shell command to retrieve the SSH password")

delete_parser = conn_subparsers.add_parser("delete", help="Delete a connection")
delete_parser.add_argument("connection_name", help="Name of connection to delete")
Expand All @@ -542,6 +546,8 @@ def main() -> int:
description=f"{schema.display_name} connection options",
)
add_schema_arguments(provider_parser, schema, include_name=True, name_required=False)
provider_parser.add_argument("--password-command", dest="password_command", help="Shell command to retrieve the database password")
provider_parser.add_argument("--ssh-password-command", dest="ssh_password_command", help="Shell command to retrieve the SSH password")

query_parser = subparsers.add_parser("query", help="Execute a SQL query")
query_parser.add_argument("--connection", "-c", required=True, help="Connection name to use")
Expand Down
14 changes: 14 additions & 0 deletions sqlit/domains/connections/app/connection_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from typing import Any, Protocol

from sqlit.domains.connections.domain.config import ConnectionConfig
from sqlit.domains.connections.domain.password_command import (
PasswordCommandError,
run_password_command,
)
from sqlit.domains.connections.domain.passwords import needs_db_password, needs_ssh_password
from sqlit.shared.app import AppServices

Expand Down Expand Up @@ -49,10 +53,20 @@ def populate_credentials_if_missing(self, config: ConnectionConfig) -> None:
password = service.get_password(config.name)
if password is not None:
endpoint.password = password
elif endpoint.password_command:
try:
endpoint.password = run_password_command(endpoint.password_command)
except PasswordCommandError:
pass
if config.tunnel and config.tunnel.password is None:
ssh_password = service.get_ssh_password(config.name)
if ssh_password is not None:
config.tunnel.password = ssh_password
elif config.tunnel.password_command:
try:
config.tunnel.password = run_password_command(config.tunnel.password_command)
except PasswordCommandError:
pass

def start(self, config: ConnectionConfig, on_ready: Any) -> None:
"""Start the connection flow, prompting for missing passwords as needed."""
Expand Down
8 changes: 8 additions & 0 deletions sqlit/domains/connections/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ def cmd_connection_edit(args: Any, *, services: AppServices | None = None) -> in
if args.password is not None:
endpoint.password = args.password

password_command = getattr(args, "password_command", None)
if password_command is not None and endpoint:
endpoint.password_command = password_command or None

ssh_password_command = getattr(args, "ssh_password_command", None)
if ssh_password_command is not None and conn.tunnel:
conn.tunnel.password_command = ssh_password_command or None

file_path = getattr(args, "file_path", None)
if file_path is not None:
if conn.file_endpoint:
Expand Down
12 changes: 11 additions & 1 deletion sqlit/domains/connections/cli/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def build_connection_config_from_args(
}

# Fields where None means "not set" vs "" means "explicitly empty"
nullable_fields = {"password", "ssh_password"}
nullable_fields = {"password", "ssh_password", "password_command", "ssh_password_command"}

for field in schema.fields:
value = raw_values.get(field.name, "")
Expand All @@ -115,6 +115,14 @@ def build_connection_config_from_args(

config_values[field.name] = value

# Pick up password_command / ssh_password_command from CLI args (not schema fields)
password_command = getattr(args, "password_command", None)
if password_command:
config_values["password_command"] = password_command
ssh_password_command = getattr(args, "ssh_password_command", None)
if ssh_password_command:
config_values["ssh_password_command"] = ssh_password_command

if "port" in config_values and not config_values["port"]:
config_values["port"] = schema.default_port or ""

Expand All @@ -134,6 +142,7 @@ def build_connection_config_from_args(
"database": config_values.pop("database", ""),
"username": config_values.pop("username", ""),
"password": config_values.pop("password", None),
"password_command": config_values.pop("password_command", None),
}
ssh_enabled = config_values.pop("ssh_enabled", False)
if ssh_enabled:
Expand All @@ -144,6 +153,7 @@ def build_connection_config_from_args(
"username": config_values.pop("ssh_username", ""),
"auth_type": config_values.pop("ssh_auth_type", "key"),
"password": config_values.pop("ssh_password", None),
"password_command": config_values.pop("ssh_password_command", None),
"key_path": config_values.pop("ssh_key_path", ""),
}
else:
Expand Down
59 changes: 52 additions & 7 deletions sqlit/domains/connections/cli/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,66 @@
from __future__ import annotations

import getpass
import sys

from sqlit.domains.connections.domain.config import ConnectionConfig
from sqlit.domains.connections.domain.passwords import needs_db_password, needs_ssh_password
from sqlit.domains.connections.domain.password_command import (
PasswordCommandError,
run_password_command,
)


def _needs_ssh_prompt(config: ConnectionConfig) -> bool:
"""Check if SSH password is still missing (ignoring password_command)."""
if not config.tunnel or not config.tunnel.enabled:
return False
if config.tunnel.auth_type != "password":
return False
return config.tunnel.password is None


def _needs_db_prompt(config: ConnectionConfig) -> bool:
"""Check if DB password is still missing (ignoring password_command)."""
from sqlit.domains.connections.providers.metadata import is_file_based, requires_auth

if is_file_based(config.db_type):
return False
if not requires_auth(config.db_type):
return False
auth_type = config.get_option("auth_type")
if auth_type in ("ad_default", "ad_integrated", "windows"):
return False
endpoint = config.tcp_endpoint
return bool(endpoint and endpoint.password is None)


def prompt_for_password(config: ConnectionConfig) -> ConnectionConfig:
"""Prompt for passwords if they are not set (None)."""
new_config = config

if needs_ssh_password(config):
ssh_password = getpass.getpass(f"SSH password for '{config.name}': ")
new_config = new_config.with_tunnel(password=ssh_password)
# SSH password
if config.tunnel and config.tunnel.password is None:
if config.tunnel.password_command:
try:
ssh_password = run_password_command(config.tunnel.password_command)
new_config = new_config.with_tunnel(password=ssh_password)
except PasswordCommandError as exc:
print(f"Warning: SSH password command failed: {exc}", file=sys.stderr)
if _needs_ssh_prompt(new_config):
ssh_password = getpass.getpass(f"SSH password for '{config.name}': ")
new_config = new_config.with_tunnel(password=ssh_password)

if needs_db_password(config):
db_password = getpass.getpass(f"Password for '{config.name}': ")
new_config = new_config.with_endpoint(password=db_password)
# DB password
endpoint = config.tcp_endpoint
if endpoint and endpoint.password is None:
if endpoint.password_command:
try:
db_password = run_password_command(endpoint.password_command)
new_config = new_config.with_endpoint(password=db_password)
except PasswordCommandError as exc:
print(f"Warning: password command failed: {exc}", file=sys.stderr)
if _needs_db_prompt(new_config):
db_password = getpass.getpass(f"Password for '{config.name}': ")
new_config = new_config.with_endpoint(password=db_password)

return new_config
19 changes: 17 additions & 2 deletions sqlit/domains/connections/domain/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class TcpEndpoint:
database: str = ""
username: str = ""
password: str | None = None
password_command: str | None = None
kind: str = "tcp"


Expand All @@ -117,6 +118,7 @@ class TunnelConfig:
username: str = ""
auth_type: str = "key" # key|password
password: str | None = None
password_command: str | None = None
key_path: str = ""


Expand Down Expand Up @@ -165,6 +167,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig:
database=str(endpoint_data.get("database", "")),
username=str(endpoint_data.get("username", "")),
password=endpoint_data.get("password", None),
password_command=endpoint_data.get("password_command", None),
)
else:
file_path = payload.pop("file_path", None)
Expand All @@ -179,6 +182,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig:
database=str(payload.pop("database", "")),
username=str(payload.pop("username", "")),
password=payload.pop("password", None),
password_command=payload.pop("password_command", None),
)

tunnel = None
Expand All @@ -193,6 +197,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig:
username=str(tunnel_data.get("username", "")),
auth_type=str(tunnel_data.get("auth_type", "key")),
password=tunnel_data.get("password", None),
password_command=tunnel_data.get("password_command", None),
key_path=str(tunnel_data.get("key_path", "")),
)
else:
Expand All @@ -202,6 +207,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig:
ssh_username = str(payload.pop("ssh_username", ""))
ssh_auth_type = str(payload.pop("ssh_auth_type", "key"))
ssh_password = payload.pop("ssh_password", None)
ssh_password_command = payload.pop("ssh_password_command", None)
ssh_key_path = str(payload.pop("ssh_key_path", ""))

enabled_flag = str(ssh_enabled).lower() if ssh_enabled is not None else ""
Expand All @@ -213,6 +219,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig:
username=ssh_username,
auth_type=ssh_auth_type or "key",
password=ssh_password,
password_command=ssh_password_command,
key_path=ssh_key_path,
)

Expand Down Expand Up @@ -281,6 +288,7 @@ def to_form_values(self) -> dict[str, Any]:
"database": self.endpoint.database,
"username": self.endpoint.username,
"password": self.endpoint.password,
"password_command": self.endpoint.password_command,
}
)

Expand All @@ -293,6 +301,7 @@ def to_form_values(self) -> dict[str, Any]:
"ssh_username": self.tunnel.username,
"ssh_auth_type": self.tunnel.auth_type,
"ssh_password": self.tunnel.password,
"ssh_password_command": self.tunnel.password_command,
"ssh_key_path": self.tunnel.key_path,
}
)
Expand All @@ -319,17 +328,20 @@ def to_dict(self, *, include_passwords: bool = True) -> dict[str, Any]:
"path": self.endpoint.path,
}
else:
data["endpoint"] = {
endpoint_dict: dict[str, Any] = {
"kind": "tcp",
"host": self.endpoint.host,
"port": self.endpoint.port,
"database": self.endpoint.database,
"username": self.endpoint.username,
"password": self.endpoint.password if include_passwords else None,
}
if self.endpoint.password_command:
endpoint_dict["password_command"] = self.endpoint.password_command
data["endpoint"] = endpoint_dict

if self.tunnel and self.tunnel.enabled:
data["tunnel"] = {
tunnel_dict: dict[str, Any] = {
"enabled": True,
"host": self.tunnel.host,
"port": self.tunnel.port,
Expand All @@ -338,6 +350,9 @@ def to_dict(self, *, include_passwords: bool = True) -> dict[str, Any]:
"password": self.tunnel.password if include_passwords else None,
"key_path": self.tunnel.key_path,
}
if self.tunnel.password_command:
tunnel_dict["password_command"] = self.tunnel.password_command
data["tunnel"] = tunnel_dict
else:
data["tunnel"] = {"enabled": False}

Expand Down
35 changes: 35 additions & 0 deletions sqlit/domains/connections/domain/password_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Run a shell command to retrieve a password."""

from __future__ import annotations

import subprocess


class PasswordCommandError(Exception):
"""Raised when a password command fails."""


def run_password_command(command: str, *, timeout: int = 30) -> str:
"""Run a shell command and return its stripped stdout as the password."""
try:
result = subprocess.run(
command,
shell=True,
capture_output=True,
text=True,
timeout=timeout,
)
except FileNotFoundError as exc:
raise PasswordCommandError(f"Command not found: {exc}") from exc
except subprocess.TimeoutExpired as exc:
raise PasswordCommandError(
f"Password command timed out after {timeout}s: {command}"
) from exc

if result.returncode != 0:
stderr = result.stderr.strip()
raise PasswordCommandError(
f"Password command failed (exit {result.returncode}): {stderr}"
)

return result.stdout.strip()
12 changes: 10 additions & 2 deletions sqlit/domains/connections/domain/passwords.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ def needs_db_password(config: ConnectionConfig) -> bool:
return False

endpoint = config.tcp_endpoint
return bool(endpoint and endpoint.password is None)
if not endpoint or endpoint.password is not None:
return False
if endpoint.password_command:
return False
return True


def needs_ssh_password(config: ConnectionConfig) -> bool:
Expand All @@ -30,4 +34,8 @@ def needs_ssh_password(config: ConnectionConfig) -> bool:
if config.tunnel.auth_type != "password":
return False

return config.tunnel.password is None
if config.tunnel.password is not None:
return False
if config.tunnel.password_command:
return False
return True
3 changes: 3 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __call__(self, **kwargs: Any) -> _ConnectionConfig:
database=str(kwargs.pop("database", "")),
username=str(kwargs.pop("username", "")),
password=kwargs.pop("password", None),
password_command=kwargs.pop("password_command", None),
)

if tunnel is None:
Expand All @@ -51,6 +52,7 @@ def __call__(self, **kwargs: Any) -> _ConnectionConfig:
ssh_username = str(kwargs.pop("ssh_username", ""))
ssh_auth_type = str(kwargs.pop("ssh_auth_type", "key"))
ssh_password = kwargs.pop("ssh_password", None)
ssh_password_command = kwargs.pop("ssh_password_command", None)
ssh_key_path = str(kwargs.pop("ssh_key_path", ""))

enabled_flag = str(ssh_enabled).lower() if ssh_enabled is not None else ""
Expand All @@ -62,6 +64,7 @@ def __call__(self, **kwargs: Any) -> _ConnectionConfig:
username=ssh_username,
auth_type=ssh_auth_type or "key",
password=ssh_password,
password_command=ssh_password_command,
key_path=ssh_key_path,
)

Expand Down
Loading
Loading