diff --git a/sqlit/cli.py b/sqlit/cli.py index 702c769e..61c162a9 100644 --- a/sqlit/cli.py +++ b/sqlit/cli.py @@ -511,6 +511,8 @@ def main() -> int: description=f"{schema.display_name} connection options", ) add_schema_arguments(provider_parser, schema, include_name=True, name_required=True) + provider_parser.add_argument("--password-command", dest="password_command", help="Shell command to retrieve the database password") + provider_parser.add_argument("--ssh-password-command", dest="ssh_password_command", help="Shell command to retrieve the SSH password") edit_parser = conn_subparsers.add_parser("edit", help="Edit an existing connection") edit_parser.add_argument("connection_name", help="Name of connection to edit") @@ -528,6 +530,8 @@ def main() -> int: help="Authentication type (SQL Server only)", ) edit_parser.add_argument("--file-path", help="Database file path (SQLite only)") + edit_parser.add_argument("--password-command", dest="password_command", help="Shell command to retrieve the database password") + edit_parser.add_argument("--ssh-password-command", dest="ssh_password_command", help="Shell command to retrieve the SSH password") delete_parser = conn_subparsers.add_parser("delete", help="Delete a connection") delete_parser.add_argument("connection_name", help="Name of connection to delete") @@ -542,6 +546,8 @@ def main() -> int: description=f"{schema.display_name} connection options", ) add_schema_arguments(provider_parser, schema, include_name=True, name_required=False) + provider_parser.add_argument("--password-command", dest="password_command", help="Shell command to retrieve the database password") + provider_parser.add_argument("--ssh-password-command", dest="ssh_password_command", help="Shell command to retrieve the SSH password") query_parser = subparsers.add_parser("query", help="Execute a SQL query") query_parser.add_argument("--connection", "-c", required=True, help="Connection name to use") diff --git a/sqlit/domains/connections/app/connection_flow.py b/sqlit/domains/connections/app/connection_flow.py index cdbb6ad8..88c16dfa 100644 --- a/sqlit/domains/connections/app/connection_flow.py +++ b/sqlit/domains/connections/app/connection_flow.py @@ -7,6 +7,10 @@ from typing import Any, Protocol from sqlit.domains.connections.domain.config import ConnectionConfig +from sqlit.domains.connections.domain.password_command import ( + PasswordCommandError, + run_password_command, +) from sqlit.domains.connections.domain.passwords import needs_db_password, needs_ssh_password from sqlit.shared.app import AppServices @@ -49,10 +53,20 @@ def populate_credentials_if_missing(self, config: ConnectionConfig) -> None: password = service.get_password(config.name) if password is not None: endpoint.password = password + elif endpoint.password_command: + try: + endpoint.password = run_password_command(endpoint.password_command) + except PasswordCommandError: + pass if config.tunnel and config.tunnel.password is None: ssh_password = service.get_ssh_password(config.name) if ssh_password is not None: config.tunnel.password = ssh_password + elif config.tunnel.password_command: + try: + config.tunnel.password = run_password_command(config.tunnel.password_command) + except PasswordCommandError: + pass def start(self, config: ConnectionConfig, on_ready: Any) -> None: """Start the connection flow, prompting for missing passwords as needed.""" diff --git a/sqlit/domains/connections/cli/commands.py b/sqlit/domains/connections/cli/commands.py index 3299b48f..ffc08b42 100644 --- a/sqlit/domains/connections/cli/commands.py +++ b/sqlit/domains/connections/cli/commands.py @@ -253,6 +253,14 @@ def cmd_connection_edit(args: Any, *, services: AppServices | None = None) -> in if args.password is not None: endpoint.password = args.password + password_command = getattr(args, "password_command", None) + if password_command is not None and endpoint: + endpoint.password_command = password_command or None + + ssh_password_command = getattr(args, "ssh_password_command", None) + if ssh_password_command is not None and conn.tunnel: + conn.tunnel.password_command = ssh_password_command or None + file_path = getattr(args, "file_path", None) if file_path is not None: if conn.file_endpoint: diff --git a/sqlit/domains/connections/cli/helpers.py b/sqlit/domains/connections/cli/helpers.py index 03dba680..6d5d3255 100644 --- a/sqlit/domains/connections/cli/helpers.py +++ b/sqlit/domains/connections/cli/helpers.py @@ -99,7 +99,7 @@ def build_connection_config_from_args( } # Fields where None means "not set" vs "" means "explicitly empty" - nullable_fields = {"password", "ssh_password"} + nullable_fields = {"password", "ssh_password", "password_command", "ssh_password_command"} for field in schema.fields: value = raw_values.get(field.name, "") @@ -115,6 +115,14 @@ def build_connection_config_from_args( config_values[field.name] = value + # Pick up password_command / ssh_password_command from CLI args (not schema fields) + password_command = getattr(args, "password_command", None) + if password_command: + config_values["password_command"] = password_command + ssh_password_command = getattr(args, "ssh_password_command", None) + if ssh_password_command: + config_values["ssh_password_command"] = ssh_password_command + if "port" in config_values and not config_values["port"]: config_values["port"] = schema.default_port or "" @@ -134,6 +142,7 @@ def build_connection_config_from_args( "database": config_values.pop("database", ""), "username": config_values.pop("username", ""), "password": config_values.pop("password", None), + "password_command": config_values.pop("password_command", None), } ssh_enabled = config_values.pop("ssh_enabled", False) if ssh_enabled: @@ -144,6 +153,7 @@ def build_connection_config_from_args( "username": config_values.pop("ssh_username", ""), "auth_type": config_values.pop("ssh_auth_type", "key"), "password": config_values.pop("ssh_password", None), + "password_command": config_values.pop("ssh_password_command", None), "key_path": config_values.pop("ssh_key_path", ""), } else: diff --git a/sqlit/domains/connections/cli/prompts.py b/sqlit/domains/connections/cli/prompts.py index 636a0aa1..786840c7 100644 --- a/sqlit/domains/connections/cli/prompts.py +++ b/sqlit/domains/connections/cli/prompts.py @@ -3,21 +3,66 @@ from __future__ import annotations import getpass +import sys from sqlit.domains.connections.domain.config import ConnectionConfig -from sqlit.domains.connections.domain.passwords import needs_db_password, needs_ssh_password +from sqlit.domains.connections.domain.password_command import ( + PasswordCommandError, + run_password_command, +) + + +def _needs_ssh_prompt(config: ConnectionConfig) -> bool: + """Check if SSH password is still missing (ignoring password_command).""" + if not config.tunnel or not config.tunnel.enabled: + return False + if config.tunnel.auth_type != "password": + return False + return config.tunnel.password is None + + +def _needs_db_prompt(config: ConnectionConfig) -> bool: + """Check if DB password is still missing (ignoring password_command).""" + from sqlit.domains.connections.providers.metadata import is_file_based, requires_auth + + if is_file_based(config.db_type): + return False + if not requires_auth(config.db_type): + return False + auth_type = config.get_option("auth_type") + if auth_type in ("ad_default", "ad_integrated", "windows"): + return False + endpoint = config.tcp_endpoint + return bool(endpoint and endpoint.password is None) def prompt_for_password(config: ConnectionConfig) -> ConnectionConfig: """Prompt for passwords if they are not set (None).""" new_config = config - if needs_ssh_password(config): - ssh_password = getpass.getpass(f"SSH password for '{config.name}': ") - new_config = new_config.with_tunnel(password=ssh_password) + # SSH password + if config.tunnel and config.tunnel.password is None: + if config.tunnel.password_command: + try: + ssh_password = run_password_command(config.tunnel.password_command) + new_config = new_config.with_tunnel(password=ssh_password) + except PasswordCommandError as exc: + print(f"Warning: SSH password command failed: {exc}", file=sys.stderr) + if _needs_ssh_prompt(new_config): + ssh_password = getpass.getpass(f"SSH password for '{config.name}': ") + new_config = new_config.with_tunnel(password=ssh_password) - if needs_db_password(config): - db_password = getpass.getpass(f"Password for '{config.name}': ") - new_config = new_config.with_endpoint(password=db_password) + # DB password + endpoint = config.tcp_endpoint + if endpoint and endpoint.password is None: + if endpoint.password_command: + try: + db_password = run_password_command(endpoint.password_command) + new_config = new_config.with_endpoint(password=db_password) + except PasswordCommandError as exc: + print(f"Warning: password command failed: {exc}", file=sys.stderr) + if _needs_db_prompt(new_config): + db_password = getpass.getpass(f"Password for '{config.name}': ") + new_config = new_config.with_endpoint(password=db_password) return new_config diff --git a/sqlit/domains/connections/domain/config.py b/sqlit/domains/connections/domain/config.py index 01e96706..9ecf6657 100644 --- a/sqlit/domains/connections/domain/config.py +++ b/sqlit/domains/connections/domain/config.py @@ -100,6 +100,7 @@ class TcpEndpoint: database: str = "" username: str = "" password: str | None = None + password_command: str | None = None kind: str = "tcp" @@ -117,6 +118,7 @@ class TunnelConfig: username: str = "" auth_type: str = "key" # key|password password: str | None = None + password_command: str | None = None key_path: str = "" @@ -165,6 +167,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig: database=str(endpoint_data.get("database", "")), username=str(endpoint_data.get("username", "")), password=endpoint_data.get("password", None), + password_command=endpoint_data.get("password_command", None), ) else: file_path = payload.pop("file_path", None) @@ -179,6 +182,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig: database=str(payload.pop("database", "")), username=str(payload.pop("username", "")), password=payload.pop("password", None), + password_command=payload.pop("password_command", None), ) tunnel = None @@ -193,6 +197,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig: username=str(tunnel_data.get("username", "")), auth_type=str(tunnel_data.get("auth_type", "key")), password=tunnel_data.get("password", None), + password_command=tunnel_data.get("password_command", None), key_path=str(tunnel_data.get("key_path", "")), ) else: @@ -202,6 +207,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig: ssh_username = str(payload.pop("ssh_username", "")) ssh_auth_type = str(payload.pop("ssh_auth_type", "key")) ssh_password = payload.pop("ssh_password", None) + ssh_password_command = payload.pop("ssh_password_command", None) ssh_key_path = str(payload.pop("ssh_key_path", "")) enabled_flag = str(ssh_enabled).lower() if ssh_enabled is not None else "" @@ -213,6 +219,7 @@ def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig: username=ssh_username, auth_type=ssh_auth_type or "key", password=ssh_password, + password_command=ssh_password_command, key_path=ssh_key_path, ) @@ -281,6 +288,7 @@ def to_form_values(self) -> dict[str, Any]: "database": self.endpoint.database, "username": self.endpoint.username, "password": self.endpoint.password, + "password_command": self.endpoint.password_command, } ) @@ -293,6 +301,7 @@ def to_form_values(self) -> dict[str, Any]: "ssh_username": self.tunnel.username, "ssh_auth_type": self.tunnel.auth_type, "ssh_password": self.tunnel.password, + "ssh_password_command": self.tunnel.password_command, "ssh_key_path": self.tunnel.key_path, } ) @@ -319,7 +328,7 @@ def to_dict(self, *, include_passwords: bool = True) -> dict[str, Any]: "path": self.endpoint.path, } else: - data["endpoint"] = { + endpoint_dict: dict[str, Any] = { "kind": "tcp", "host": self.endpoint.host, "port": self.endpoint.port, @@ -327,9 +336,12 @@ def to_dict(self, *, include_passwords: bool = True) -> dict[str, Any]: "username": self.endpoint.username, "password": self.endpoint.password if include_passwords else None, } + if self.endpoint.password_command: + endpoint_dict["password_command"] = self.endpoint.password_command + data["endpoint"] = endpoint_dict if self.tunnel and self.tunnel.enabled: - data["tunnel"] = { + tunnel_dict: dict[str, Any] = { "enabled": True, "host": self.tunnel.host, "port": self.tunnel.port, @@ -338,6 +350,9 @@ def to_dict(self, *, include_passwords: bool = True) -> dict[str, Any]: "password": self.tunnel.password if include_passwords else None, "key_path": self.tunnel.key_path, } + if self.tunnel.password_command: + tunnel_dict["password_command"] = self.tunnel.password_command + data["tunnel"] = tunnel_dict else: data["tunnel"] = {"enabled": False} diff --git a/sqlit/domains/connections/domain/password_command.py b/sqlit/domains/connections/domain/password_command.py new file mode 100644 index 00000000..864e9260 --- /dev/null +++ b/sqlit/domains/connections/domain/password_command.py @@ -0,0 +1,35 @@ +"""Run a shell command to retrieve a password.""" + +from __future__ import annotations + +import subprocess + + +class PasswordCommandError(Exception): + """Raised when a password command fails.""" + + +def run_password_command(command: str, *, timeout: int = 30) -> str: + """Run a shell command and return its stripped stdout as the password.""" + try: + result = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + timeout=timeout, + ) + except FileNotFoundError as exc: + raise PasswordCommandError(f"Command not found: {exc}") from exc + except subprocess.TimeoutExpired as exc: + raise PasswordCommandError( + f"Password command timed out after {timeout}s: {command}" + ) from exc + + if result.returncode != 0: + stderr = result.stderr.strip() + raise PasswordCommandError( + f"Password command failed (exit {result.returncode}): {stderr}" + ) + + return result.stdout.strip() diff --git a/sqlit/domains/connections/domain/passwords.py b/sqlit/domains/connections/domain/passwords.py index 40957b82..62cedf8d 100644 --- a/sqlit/domains/connections/domain/passwords.py +++ b/sqlit/domains/connections/domain/passwords.py @@ -19,7 +19,11 @@ def needs_db_password(config: ConnectionConfig) -> bool: return False endpoint = config.tcp_endpoint - return bool(endpoint and endpoint.password is None) + if not endpoint or endpoint.password is not None: + return False + if endpoint.password_command: + return False + return True def needs_ssh_password(config: ConnectionConfig) -> bool: @@ -30,4 +34,8 @@ def needs_ssh_password(config: ConnectionConfig) -> bool: if config.tunnel.auth_type != "password": return False - return config.tunnel.password is None + if config.tunnel.password is not None: + return False + if config.tunnel.password_command: + return False + return True diff --git a/tests/helpers.py b/tests/helpers.py index 50746f60..2447096a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -42,6 +42,7 @@ def __call__(self, **kwargs: Any) -> _ConnectionConfig: database=str(kwargs.pop("database", "")), username=str(kwargs.pop("username", "")), password=kwargs.pop("password", None), + password_command=kwargs.pop("password_command", None), ) if tunnel is None: @@ -51,6 +52,7 @@ def __call__(self, **kwargs: Any) -> _ConnectionConfig: ssh_username = str(kwargs.pop("ssh_username", "")) ssh_auth_type = str(kwargs.pop("ssh_auth_type", "key")) ssh_password = kwargs.pop("ssh_password", None) + ssh_password_command = kwargs.pop("ssh_password_command", None) ssh_key_path = str(kwargs.pop("ssh_key_path", "")) enabled_flag = str(ssh_enabled).lower() if ssh_enabled is not None else "" @@ -62,6 +64,7 @@ def __call__(self, **kwargs: Any) -> _ConnectionConfig: username=ssh_username, auth_type=ssh_auth_type or "key", password=ssh_password, + password_command=ssh_password_command, key_path=ssh_key_path, ) diff --git a/tests/test_password_prompts.py b/tests/test_password_prompts.py index c0b4b988..4a638574 100644 --- a/tests/test_password_prompts.py +++ b/tests/test_password_prompts.py @@ -134,6 +134,28 @@ def test_mssql_windows_auth_with_empty_password_no_prompt(self) -> None: assert not needs_db_password(config) + def test_password_command_set_does_not_need_prompt(self) -> None: + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + username="user", + password=None, + password_command="echo secret", + ) + assert not needs_db_password(config) + + def test_no_password_no_command_needs_prompt(self) -> None: + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + username="user", + password=None, + ) + assert needs_db_password(config) + + class TestNeedsSshPassword: """Test needs_ssh_password helper function.""" @@ -203,6 +225,34 @@ def test_ssh_password_auth_with_stored_password_does_not_need_prompt(self) -> No assert not needs_ssh_password(config) + def test_ssh_password_command_set_does_not_need_prompt(self) -> None: + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion", + ssh_username="user", + ssh_password=None, + ssh_password_command="echo sshpw", + ) + assert not needs_ssh_password(config) + + def test_ssh_no_password_no_command_needs_prompt(self) -> None: + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion", + ssh_username="user", + ssh_password=None, + ) + assert needs_ssh_password(config) + + class TestCliPromptForPassword: """Test CLI prompt_for_password function.""" @@ -356,6 +406,82 @@ def test_original_config_not_modified(self) -> None: assert result is not original +class TestPasswordCommandPrompt: + """Tests for password_command integration in CLI prompts.""" + + @patch("sqlit.domains.connections.cli.prompts.getpass.getpass") + @patch("sqlit.domains.connections.cli.prompts.run_password_command", return_value="cmd_password") + def test_password_command_resolves_db_password(self, mock_run: MagicMock, mock_getpass: MagicMock) -> None: + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password=None, + password_command="op read op://vault/item/pw", + ) + result = prompt_for_password(config) + mock_run.assert_called_once_with("op read op://vault/item/pw") + mock_getpass.assert_not_called() + assert result.password == "cmd_password" + + @patch("sqlit.domains.connections.cli.prompts.getpass.getpass") + @patch("sqlit.domains.connections.cli.prompts.run_password_command", return_value="ssh_cmd_pw") + def test_password_command_resolves_ssh_password(self, mock_run: MagicMock, mock_getpass: MagicMock) -> None: + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password="stored", + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion", + ssh_username="sshuser", + ssh_password=None, + ssh_password_command="echo sshpw", + ) + result = prompt_for_password(config) + mock_run.assert_called_once_with("echo sshpw") + mock_getpass.assert_not_called() + assert result.ssh_password == "ssh_cmd_pw" + + @patch("sqlit.domains.connections.cli.prompts.getpass.getpass", return_value="fallback") + @patch("sqlit.domains.connections.cli.prompts.run_password_command") + def test_password_command_failure_falls_back_to_getpass(self, mock_run: MagicMock, mock_getpass: MagicMock) -> None: + from sqlit.domains.connections.domain.password_command import PasswordCommandError + + mock_run.side_effect = PasswordCommandError("command failed") + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password=None, + password_command="bad-cmd", + ) + result = prompt_for_password(config) + mock_run.assert_called_once() + mock_getpass.assert_called_once() + assert result.password == "fallback" + + @patch("sqlit.domains.connections.cli.prompts.getpass.getpass") + @patch("sqlit.domains.connections.cli.prompts.run_password_command") + def test_explicit_password_skips_command(self, mock_run: MagicMock, mock_getpass: MagicMock) -> None: + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password="explicit", + password_command="echo should-not-run", + ) + result = prompt_for_password(config) + mock_run.assert_not_called() + mock_getpass.assert_not_called() + assert result.password == "explicit" + + class TestPasswordPromptIntegration: """Integration tests for the full password prompt flow.""" diff --git a/tests/unit/test_connection_config_from_dict.py b/tests/unit/test_connection_config_from_dict.py index 2d719837..86a64860 100644 --- a/tests/unit/test_connection_config_from_dict.py +++ b/tests/unit/test_connection_config_from_dict.py @@ -68,3 +68,140 @@ def test_from_dict_folder_path_normalized() -> None: config = ConnectionConfig.from_dict(data) assert config.folder_path == "Potato/Ninja" + + +def test_from_dict_endpoint_password_command() -> None: + data = { + "name": "pc-test", + "db_type": "postgresql", + "endpoint": { + "kind": "tcp", + "host": "localhost", + "port": "5432", + "database": "db", + "username": "user", + "password": None, + "password_command": "op read op://vault/item/password", + }, + } + config = ConnectionConfig.from_dict(data) + assert config.tcp_endpoint is not None + assert config.tcp_endpoint.password_command == "op read op://vault/item/password" + + +def test_from_dict_tunnel_password_command() -> None: + data = { + "name": "pc-test", + "db_type": "postgresql", + "endpoint": {"kind": "tcp", "host": "localhost", "port": "5432", "database": "db", "username": "user"}, + "tunnel": { + "enabled": True, + "host": "bastion", + "port": "22", + "username": "sshuser", + "auth_type": "password", + "password": None, + "password_command": "bw get password ssh-bastion", + }, + } + config = ConnectionConfig.from_dict(data) + assert config.tunnel is not None + assert config.tunnel.password_command == "bw get password ssh-bastion" + + +def test_from_dict_legacy_ssh_password_command() -> None: + data = { + "name": "legacy-pc", + "db_type": "postgresql", + "server": "localhost", + "port": "5432", + "database": "db", + "username": "user", + "password_command": "echo dbpass", + "ssh_enabled": True, + "ssh_host": "bastion", + "ssh_password_command": "echo sshpass", + } + config = ConnectionConfig.from_dict(data) + assert config.tcp_endpoint is not None + assert config.tcp_endpoint.password_command == "echo dbpass" + assert config.tunnel is not None + assert config.tunnel.password_command == "echo sshpass" + + +def test_to_dict_includes_password_command() -> None: + config = ConnectionConfig.from_dict({ + "name": "t", + "db_type": "postgresql", + "endpoint": { + "kind": "tcp", + "host": "h", + "port": "5432", + "database": "d", + "username": "u", + "password_command": "echo pw", + }, + }) + d = config.to_dict() + assert d["endpoint"]["password_command"] == "echo pw" + + +def test_to_dict_omits_password_command_when_none() -> None: + config = ConnectionConfig.from_dict({ + "name": "t", + "db_type": "postgresql", + "endpoint": {"kind": "tcp", "host": "h", "port": "5432", "database": "d", "username": "u"}, + }) + d = config.to_dict() + assert "password_command" not in d["endpoint"] + + +def test_round_trip_password_command() -> None: + original = { + "name": "rt", + "db_type": "postgresql", + "endpoint": { + "kind": "tcp", + "host": "h", + "port": "5432", + "database": "d", + "username": "u", + "password": None, + "password_command": "vault kv get -field=pw secret/db", + }, + "tunnel": { + "enabled": True, + "host": "bastion", + "port": "22", + "username": "ssh", + "auth_type": "password", + "password": None, + "password_command": "echo sshpw", + }, + } + config = ConnectionConfig.from_dict(original) + d = config.to_dict() + config2 = ConnectionConfig.from_dict(d) + assert config2.tcp_endpoint is not None + assert config2.tcp_endpoint.password_command == "vault kv get -field=pw secret/db" + assert config2.tunnel is not None + assert config2.tunnel.password_command == "echo sshpw" + + +def test_to_dict_include_passwords_false_keeps_password_command() -> None: + config = ConnectionConfig.from_dict({ + "name": "t", + "db_type": "postgresql", + "endpoint": { + "kind": "tcp", + "host": "h", + "port": "5432", + "database": "d", + "username": "u", + "password": "secret", + "password_command": "echo pw", + }, + }) + d = config.to_dict(include_passwords=False) + assert d["endpoint"]["password"] is None + assert d["endpoint"]["password_command"] == "echo pw" diff --git a/tests/unit/test_connection_flow.py b/tests/unit/test_connection_flow.py new file mode 100644 index 00000000..33f09b0d --- /dev/null +++ b/tests/unit/test_connection_flow.py @@ -0,0 +1,48 @@ +"""Tests for connection flow password_command integration.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from sqlit.domains.connections.app.connection_flow import ConnectionFlow +from tests.helpers import ConnectionConfig + + +class TestPopulateCredentialsPasswordCommand: + def _make_flow(self, *, keyring_password: str | None = None) -> ConnectionFlow: + services = MagicMock() + services.credentials_service.get_password.return_value = keyring_password + services.credentials_service.get_ssh_password.return_value = None + return ConnectionFlow(services=services) + + @patch("sqlit.domains.connections.app.connection_flow.run_password_command", return_value="cmd_pw") + def test_runs_password_command_when_keyring_empty(self, mock_run: MagicMock) -> None: + flow = self._make_flow(keyring_password=None) + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + username="user", + password=None, + password_command="echo cmd_pw", + ) + flow.populate_credentials_if_missing(config) + mock_run.assert_called_once_with("echo cmd_pw") + assert config.tcp_endpoint is not None + assert config.tcp_endpoint.password == "cmd_pw" + + @patch("sqlit.domains.connections.app.connection_flow.run_password_command") + def test_keyring_password_wins_over_command(self, mock_run: MagicMock) -> None: + flow = self._make_flow(keyring_password="keyring_pw") + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + username="user", + password=None, + password_command="echo cmd_pw", + ) + flow.populate_credentials_if_missing(config) + mock_run.assert_not_called() + assert config.tcp_endpoint is not None + assert config.tcp_endpoint.password == "keyring_pw" diff --git a/tests/unit/test_password_command.py b/tests/unit/test_password_command.py new file mode 100644 index 00000000..8b78e056 --- /dev/null +++ b/tests/unit/test_password_command.py @@ -0,0 +1,52 @@ +"""Tests for password command utility.""" + +from __future__ import annotations + +import subprocess +from unittest.mock import patch + +import pytest + +from sqlit.domains.connections.domain.password_command import ( + PasswordCommandError, + run_password_command, +) + + +class TestRunPasswordCommand: + def test_returns_stripped_stdout(self) -> None: + mock_result = subprocess.CompletedProcess( + args="cmd", returncode=0, stdout=" secret\n", stderr="" + ) + with patch("subprocess.run", return_value=mock_result): + assert run_password_command("cmd") == "secret" + + def test_raises_on_nonzero_exit(self) -> None: + mock_result = subprocess.CompletedProcess( + args="cmd", returncode=1, stdout="", stderr="access denied" + ) + with patch("subprocess.run", return_value=mock_result): + with pytest.raises(PasswordCommandError, match="exit 1.*access denied"): + run_password_command("cmd") + + def test_raises_on_timeout(self) -> None: + with patch( + "subprocess.run", + side_effect=subprocess.TimeoutExpired("cmd", 30), + ): + with pytest.raises(PasswordCommandError, match="timed out"): + run_password_command("cmd") + + def test_raises_on_file_not_found(self) -> None: + with patch( + "subprocess.run", + side_effect=FileNotFoundError("No such file"), + ): + with pytest.raises(PasswordCommandError, match="Command not found"): + run_password_command("cmd") + + def test_real_echo_command(self) -> None: + assert run_password_command("echo secret") == "secret" + + def test_real_echo_strips_whitespace(self) -> None: + assert run_password_command("echo ' hello '") == "hello"