Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 105 additions & 8 deletions src/vaultctl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import base64
import builtins
import json
import re
Expand Down Expand Up @@ -29,7 +30,7 @@
from .search import MAX_PATTERN_LENGTH, SearchMatch, filter_keys, search_values
from .types import detect_entry_type, get_entry_fields, get_field_value
from .vault import VaultError, decrypt_vault, edit_vault, encrypt_vault
from .yaml_util import dump_yaml
from .yaml_util import clean_multiline_value, dump_yaml


def _format_value(value: Any) -> str:
Expand Down Expand Up @@ -365,9 +366,17 @@ def _print_context_results(matches: list[SearchMatch], *, show_match: bool) -> N
@click.argument("key")
@click.option("--field", default=None, help="Access a specific field of a structured entry.")
@click.option("--json", "output_json", is_flag=True, default=False, help="Output as JSON.")
@click.option("--raw", is_flag=True, default=False, help="Output raw value without headers or formatting.")
@click.option("--base64", "output_base64", is_flag=True, default=False, help="Output value as base64-encoded string.")
@pass_ctx
def get(vctx: VaultContext, key: str, field: str | None, output_json: bool) -> None:
def get(vctx: VaultContext, key: str, field: str | None, output_json: bool, raw: bool, output_base64: bool) -> None:
"""Show the value of a vault key."""
# Validate mutually exclusive output flags
output_flags = sum([output_json, raw, output_base64])
if output_flags > 1:
click.echo("Error: --json, --raw, and --base64 are mutually exclusive.", err=True)
sys.exit(1)

try:
data = decrypt_vault(vctx.config.vault_file, vctx.password)
except VaultError as exc:
Expand All @@ -388,6 +397,10 @@ def get(vctx: VaultContext, key: str, field: str | None, output_json: bool) -> N
sys.exit(1)
if output_json:
click.echo(json.dumps(field_val, indent=2, ensure_ascii=False))
elif raw:
_output_raw(field_val)
elif output_base64:
_output_base64_encoded(field_val)
else:
click.echo(_format_value(field_val))
return
Expand All @@ -396,6 +409,14 @@ def get(vctx: VaultContext, key: str, field: str | None, output_json: bool) -> N
click.echo(json.dumps(value, indent=2, ensure_ascii=False))
return

if raw:
_output_raw(value)
return

if output_base64:
_output_base64_encoded(value)
return

entry_type = detect_entry_type(value)
if isinstance(value, dict):
click.echo(f"Type: {entry_type}")
Expand All @@ -405,11 +426,37 @@ def get(vctx: VaultContext, key: str, field: str | None, output_json: bool) -> N
click.echo(value, nl=not isinstance(value, str) or not value.endswith("\n"))


def _output_raw(value: Any) -> None:
"""Output a value in raw mode: cleaned multiline string, no headers."""
text = _format_value(value)
if isinstance(value, str):
text = clean_multiline_value(text)
# Write to stdout directly, no extra newline (clean_multiline_value ensures trailing \n)
click.get_text_stream("stdout").write(text if text.endswith("\n") else text + "\n")


def _output_base64_encoded(value: Any) -> None:
"""Output a value as a single base64-encoded line."""
text = _format_value(value)
if isinstance(value, str):
text = clean_multiline_value(text)
encoded = base64.b64encode(text.encode("utf-8")).decode("ascii")
click.echo(encoded)


@main.command()
@click.argument("key")
@click.argument("value", required=False)
@click.option("--prompt", "use_prompt", is_flag=True, help="Enter value interactively.")
@click.option("--file", "from_file", type=click.Path(exists=True), help="Read value from file.")
@click.option("--base64", "from_base64", default=None, help="Set value from base64-encoded string.")
@click.option(
"--base64-file",
"from_base64_file",
type=click.Path(),
default=None,
help="Read base64-encoded value from file (use '-' for stdin).",
)
@click.option("--backup/--no-backup", default=True, help="Save previous value as <key>_previous.")
@click.option("--expires", default=None, help="Expiry date (YYYY-MM-DD) for vault-keys.yml.")
@click.option("--force", is_flag=True, default=False, help="Skip confirmation prompts.")
Expand All @@ -420,12 +467,14 @@ def set(
value: str | None,
use_prompt: bool,
from_file: str | None,
from_base64: str | None,
from_base64_file: str | None,
backup: bool,
expires: str | None,
force: bool,
) -> None:
"""Set a vault key."""
value = _resolve_set_value(value, use_prompt, from_file, key)
value = _resolve_set_value(value, use_prompt, from_file, from_base64, from_base64_file, key)

try:
data = decrypt_vault(vctx.config.vault_file, vctx.password)
Expand Down Expand Up @@ -686,22 +735,70 @@ def _print_detection_json(results: list[DetectionResult]) -> None:
click.echo(json.dumps(items, indent=2))


def _resolve_set_value(value: str | None, use_prompt: bool, from_file: str | None, key: str) -> str:
"""Resolve the value for a set operation from the three input modes.
def _resolve_set_value(
value: str | None,
use_prompt: bool,
from_file: str | None,
from_base64: str | None,
from_base64_file: str | None,
key: str,
) -> str:
"""Resolve the value for a set operation from the available input modes.

Returns the resolved value string. Calls ``sys.exit(1)`` if no value
can be determined.
can be determined or if multiple input modes are specified.
"""

# Count how many input sources are provided
sources = sum(
[
value is not None,
use_prompt,
from_file is not None,
from_base64 is not None,
from_base64_file is not None,
]
)
if sources > 1:
click.echo(
"Error: Specify only one of: <value>, --prompt, --file, --base64, --base64-file.",
err=True,
)
sys.exit(1)

if use_prompt:
resolved: str = click.prompt(f"Value for {key}", hide_input=True)
if not resolved:
click.echo("Error: Empty value.", err=True)
sys.exit(1)
return resolved

if from_file:
return Path(from_file).read_text(encoding="utf-8")
return clean_multiline_value(Path(from_file).read_text(encoding="utf-8"))

if from_base64 is not None:
try:
return base64.b64decode(from_base64).decode("utf-8")
except Exception:
click.echo("Error: Invalid base64 input.", err=True)
sys.exit(1)

if from_base64_file is not None:
try:
if from_base64_file == "-":
raw = sys.stdin.read().strip()
else:
raw = Path(from_base64_file).read_text(encoding="utf-8").strip()
return base64.b64decode(raw).decode("utf-8")
except Exception:
click.echo("Error: Invalid base64 input.", err=True)
sys.exit(1)

if value is None:
click.echo("Error: No value provided. Use <value>, --prompt or --file.", err=True)
click.echo(
"Error: No value provided. Use <value>, --prompt, --file, --base64, or --base64-file.",
err=True,
)
sys.exit(1)
return value

Expand Down
11 changes: 11 additions & 0 deletions src/vaultctl/yaml_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,14 @@ def load_yaml_text(text: str) -> dict[str, Any]:
def dump_yaml_text(data: dict[str, Any]) -> str:
"""Serialize *data* to a YAML string with stable formatting."""
return yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=True)


def clean_multiline_value(value: str) -> str:
"""Strip trailing whitespace per line and ensure exactly one trailing newline.

This is essential for SSH keys and certificates stored in YAML,
where multiline formatting may introduce trailing spaces.
"""
lines = value.splitlines()
cleaned = "\n".join(line.rstrip() for line in lines)
return cleaned.rstrip("\n") + "\n"
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def vault_file(tmp_path):
"username": "deploy",
"password": "d3ploy",
},
"ssh_key": "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEA \nAAAAGnRlc3Qga2V5 \n-----END OPENSSH PRIVATE KEY-----\n",
}
plain = tmp_path / "vault-plain.yml"
encrypted = tmp_path / "vault.yml"
Expand Down
168 changes: 168 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,171 @@ def test_search_context_top_level_string(runner, cli_env):
result = runner.invoke(main, ["search", "test_value", "--context"])
assert result.exit_code == 0
assert "test_key" in result.output


# --- get --raw tests ---


def test_get_raw_plain_string(runner, cli_env):
"""--raw outputs a plain string value without any formatting."""
result = runner.invoke(main, ["get", "test_key", "--raw"])
assert result.exit_code == 0
assert result.output == "test_value\n"


def test_get_raw_multiline_strips_trailing_whitespace(runner, cli_env):
"""--raw strips trailing whitespace from each line of a multiline value."""
result = runner.invoke(main, ["get", "ssh_key", "--raw"])
assert result.exit_code == 0
# The test fixture has trailing spaces on some lines
for line in result.output.splitlines():
assert line == line.rstrip(), f"Trailing whitespace found: {line!r}"
assert result.output.endswith("\n")
assert "BEGIN OPENSSH PRIVATE KEY" in result.output


def test_get_raw_with_field(runner, cli_env):
"""--raw --field outputs only the field value, no 'Type:' header."""
result = runner.invoke(main, ["get", "db_creds", "--field", "username", "--raw"])
assert result.exit_code == 0
assert result.output == "admin\n"
assert "Type:" not in result.output


def test_get_raw_structured_entry(runner, cli_env):
"""--raw on a structured entry outputs YAML without 'Type:' header."""
result = runner.invoke(main, ["get", "db_creds", "--raw"])
assert result.exit_code == 0
assert "Type:" not in result.output
assert "username: admin" in result.output


# --- get --base64 tests ---


def test_get_base64_plain_string(runner, cli_env):
"""--base64 outputs the value as base64-encoded string."""
import base64

result = runner.invoke(main, ["get", "test_key", "--base64"])
assert result.exit_code == 0
decoded = base64.b64decode(result.output.strip()).decode("utf-8")
assert "test_value" in decoded


def test_get_base64_multiline(runner, cli_env):
"""--base64 on a multiline value produces a single base64 line."""
import base64

result = runner.invoke(main, ["get", "ssh_key", "--base64"])
assert result.exit_code == 0
# Output should be a single line (base64 encoded)
assert "\n" not in result.output.strip()
decoded = base64.b64decode(result.output.strip()).decode("utf-8")
assert "BEGIN OPENSSH PRIVATE KEY" in decoded
# Decoded value should have no trailing whitespace on lines
for line in decoded.splitlines():
assert line == line.rstrip()


def test_get_base64_with_field(runner, cli_env):
"""--base64 --field outputs the field value base64-encoded."""
import base64

result = runner.invoke(main, ["get", "db_creds", "--field", "password", "--base64"])
assert result.exit_code == 0
decoded = base64.b64decode(result.output.strip()).decode("utf-8")
assert "s3cret" in decoded


# --- mutually exclusive output flags ---


def test_get_mutually_exclusive_flags(runner, cli_env):
"""--json, --raw, and --base64 are mutually exclusive."""
result = runner.invoke(main, ["get", "test_key", "--raw", "--json"])
assert result.exit_code == 1
assert "mutually exclusive" in result.output

result = runner.invoke(main, ["get", "test_key", "--raw", "--base64"])
assert result.exit_code == 1
assert "mutually exclusive" in result.output

result = runner.invoke(main, ["get", "test_key", "--json", "--base64"])
assert result.exit_code == 1
assert "mutually exclusive" in result.output


# --- set --base64 tests ---


def test_set_base64_inline(runner, cli_env):
"""--base64 decodes a base64 value before storing."""
import base64

encoded = base64.b64encode(b"decoded_secret").decode("ascii")
result = runner.invoke(main, ["set", "b64_key", "--base64", encoded, "--force", "--no-backup"])
assert result.exit_code == 0
assert "Added" in result.output

# Verify the decoded value was stored
result = runner.invoke(main, ["get", "b64_key", "--raw"])
assert result.exit_code == 0
assert "decoded_secret" in result.output


def test_set_base64_invalid(runner, cli_env):
"""--base64 with invalid base64 input should fail."""
result = runner.invoke(main, ["set", "b64_key", "--base64", "not-valid-base64!!!", "--force"])
assert result.exit_code == 1
assert "Invalid base64" in result.output


def test_set_base64_file_from_file(runner, cli_env, tmp_path):
"""--base64-file reads base64 from a file and decodes it."""
import base64

secret = "file_based_secret"
b64_file = tmp_path / "encoded.b64"
b64_file.write_text(base64.b64encode(secret.encode()).decode())

result = runner.invoke(main, ["set", "b64f_key", "--base64-file", str(b64_file), "--force", "--no-backup"])
assert result.exit_code == 0

result = runner.invoke(main, ["get", "b64f_key", "--raw"])
assert result.exit_code == 0
assert "file_based_secret" in result.output


def test_set_base64_file_stdin(runner, cli_env):
"""--base64-file - reads base64 from stdin."""
import base64

encoded = base64.b64encode(b"stdin_secret").decode("ascii")
result = runner.invoke(main, ["set", "stdin_key", "--base64-file", "-", "--force", "--no-backup"], input=encoded)
assert result.exit_code == 0

result = runner.invoke(main, ["get", "stdin_key", "--raw"])
assert result.exit_code == 0
assert "stdin_secret" in result.output


def test_set_multiple_sources_rejected(runner, cli_env):
"""Specifying multiple input sources should fail."""
result = runner.invoke(main, ["set", "key", "value", "--base64", "abc", "--force"])
assert result.exit_code == 1
assert "Specify only one" in result.output


def test_set_file_cleans_whitespace(runner, cli_env, tmp_path):
"""--file import applies whitespace cleanup to multiline values."""
key_file = tmp_path / "key.pem"
key_file.write_text("-----BEGIN KEY-----\nline1 \nline2\t\n-----END KEY-----\n")

result = runner.invoke(main, ["set", "clean_key", "--file", str(key_file), "--force", "--no-backup"])
assert result.exit_code == 0

result = runner.invoke(main, ["get", "clean_key", "--raw"])
assert result.exit_code == 0
for line in result.output.splitlines():
assert line == line.rstrip(), f"Trailing whitespace found: {line!r}"
Loading
Loading