Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion e2e-tests/src/databases/duckdb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@dataclass(frozen=True)
class DuckdbDB:
datasource_name: str | None = "test duckdb conn"
datasource_name: str | None = "test_duckdb_conn"
datasource_type: str = "duckdb"
database_path: Path | None = None
check_connection: bool = False
Expand Down
2 changes: 1 addition & 1 deletion e2e-tests/src/databases/sqlite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

@dataclass(frozen=True)
class SqliteDB:
datasource_name: str | None = "test sqlite conn"
datasource_name: str | None = "test_sqlite_conn"
datasource_type: str = "sqlite"
database_path: Path | None = None
check_connection: bool = False
Expand Down
4 changes: 2 additions & 2 deletions e2e-tests/tests/resources/duckdb_introspections.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ===== test duckdb conn.yaml =====
datasource_id: test duckdb conn.yaml
# ===== test_duckdb_conn.yaml =====
datasource_id: test_duckdb_conn.yaml
datasource_type: duckdb
context_built_at: 2026-02-27 16:38:36.999625
context:
Expand Down
4 changes: 2 additions & 2 deletions e2e-tests/tests/resources/sqlite_introspections.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ===== test sqlite conn.yaml =====
datasource_id: test sqlite conn.yaml
# ===== test_sqlite_conn.yaml =====
datasource_id: test_sqlite_conn.yaml
datasource_type: sqlite
context_built_at: 2026-02-27 16:15:31.010532
context:
Expand Down
102 changes: 102 additions & 0 deletions src/databao_cli/features/datasource/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Validation rules for datasource fields.

These rules are shared between the CLI workflow and the Streamlit UI so
that users get the same feedback regardless of how they create a
datasource.
"""

import ipaddress
import re

# The agent requires source names to match this pattern so they can be
# used unquoted in SQL queries. See databao-agent domain.py.
_SOURCE_NAME_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$")

# Folder segments in the datasource path are more permissive — they only
# need to be valid filesystem names.
_FOLDER_SEGMENT_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$")

MAX_DATASOURCE_NAME_LENGTH = 255

MIN_PORT = 1
MAX_PORT = 65535

# Hostname: RFC 952 / RFC 1123 — labels separated by dots.
_HOSTNAME_RE = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z0-9-]{1,63})*$")


def validate_datasource_name(name: str) -> str | None:
    """Check *name* and return a human-readable error, or ``None`` when valid.

    A datasource name may use forward slashes to organise datasources
    into folders (e.g. ``resources/my_db``). Folder segments must be
    filesystem-safe names, while the final segment — the actual source
    name — must satisfy the agent's ``^[A-Za-z][A-Za-z0-9_]*$`` pattern
    so it can appear unquoted in SQL queries.
    """
    if not name.strip():
        return "Datasource name must not be empty."

    if len(name) > MAX_DATASOURCE_NAME_LENGTH:
        return f"Datasource name must be at most {MAX_DATASOURCE_NAME_LENGTH} characters."

    if re.search(r"\s", name) is not None:
        return "Datasource name must not contain whitespace."

    parts = name.split("/")

    # Reject leading/trailing/doubled slashes before inspecting segments.
    if any(part == "" for part in parts):
        return (
            "Datasource name must not contain empty path segments (for example, leading, trailing, or repeated slashes)."
        )

    # Everything before the last segment is a folder; the last segment is
    # the source name itself.
    *folders, source_name = parts

    for folder in folders:
        if _FOLDER_SEGMENT_RE.match(folder) is None:
            return (
                f"Folder segment '{folder}' may only contain letters, digits, "
                "hyphens, underscores, and dots, and must start and end with "
                "a letter or digit."
            )

    if _SOURCE_NAME_RE.match(source_name) is None:
        return "Datasource name must start with a letter and contain only letters, digits, and underscores."

    return None


def validate_port(value: str) -> str | None:
    """Check that *value* parses as a port number in the valid range.

    Returns an error message string, or ``None`` when *value* is a valid
    port.
    """
    try:
        port = int(value)
    except ValueError:
        return "Port must be a number."
    if MIN_PORT <= port <= MAX_PORT:
        return None
    return f"Port must be between {MIN_PORT} and {MAX_PORT}."


def validate_hostname(value: str) -> str | None:
    """Check that *value* is a plausible hostname or IP address.

    Returns an error message string, or ``None`` when *value* is
    acceptable. ``localhost`` and literal IPv4/IPv6 addresses are always
    accepted as-is.
    """
    candidate = value.strip()
    if not candidate:
        return "Hostname must not be empty."
    # Short-circuit for the common local / literal-address cases.
    if candidate == "localhost":
        return None
    if _is_ip_address(candidate):
        return None
    # RFC 1035 caps a full domain name at 253 characters.
    if len(candidate) > 253:
        return "Hostname must not exceed 253 characters."
    if _HOSTNAME_RE.match(candidate) is None:
        return "Hostname contains invalid characters."
    return None


def _is_ip_address(value: str) -> bool:
"""Return True if *value* is a valid IPv4 or IPv6 address."""
try:
ipaddress.ip_address(value)
return True
except ValueError:
return False
99 changes: 99 additions & 0 deletions src/databao_cli/features/ui/components/datasource_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,107 @@
ConfigUnionPropertyDefinition,
)

from databao_cli.features.datasource.validation import validate_hostname, validate_port

SKIP_TOP_LEVEL_KEYS = {"type", "name"}

# Field keys that represent a network host.
_HOST_KEYS = {"host", "hostname"}

# Field keys that represent a network port.
_PORT_KEYS = {"port"}


def validate_config_fields(
    config_fields: list[ConfigPropertyDefinition],
    values: dict[str, Any],
) -> list[str]:
    """Collect human-readable errors for invalid or missing config fields.

    The property tree is walked recursively, but only leaf
    ``ConfigSinglePropertyDefinition`` nodes are actually checked:

    * Required fields must not be empty.
    * ``int``-typed fields must be valid integers.
    * Fields named ``port`` must be in the 1-65535 range.
    * Fields named ``host`` / ``hostname`` must be valid hostnames or IPs.
    """
    errors = _validate_fields_recursive(config_fields, values)
    return errors


def _validate_fields_recursive(
    config_fields: list[ConfigPropertyDefinition],
    values: dict[str, Any],
    path_prefix: str = "",
) -> list[str]:
    """Recursively validate *values* against *config_fields*.

    *path_prefix* is the dotted path of the enclosing property (empty at
    the top level); it is prepended to each reported key so errors point
    at the full nested location (e.g. ``connection.port``). Returns a
    flat list of human-readable error strings.
    """
    errors: list[str] = []
    for prop in config_fields:
        # Only skip special top-level keys like "type" / "name".
        if not path_prefix and prop.property_key in SKIP_TOP_LEVEL_KEYS:
            continue

        full_key = f"{path_prefix}{prop.property_key}" if path_prefix else prop.property_key

        # Union properties: work out which variant the values describe,
        # then recurse into that variant's nested property definitions.
        if isinstance(prop, ConfigUnionPropertyDefinition):
            union_vals = values.get(prop.property_key, {})
            if not isinstance(union_vals, dict):
                # Malformed value for a union slot — treat as empty so the
                # nested required-field checks still report something useful.
                union_vals = {}
            type_choices = {t.__name__: t for t in prop.types}
            # An explicit "type" key wins; otherwise fall back to the
            # helper (presumably matching the provided keys to a variant —
            # see _infer_union_type elsewhere in this module).
            selected_name = union_vals.get("type") or _infer_union_type(union_vals, type_choices, prop.type_properties)
            if selected_name and selected_name in type_choices:
                selected_type = type_choices[selected_name]
                nested_props = prop.type_properties.get(selected_type, [])
                errors.extend(_validate_fields_recursive(nested_props, union_vals, f"{full_key}."))
            elif len(prop.types) == 1:
                # Only one possible variant — no ambiguity, validate it.
                sole_type = prop.types[0]
                nested_props = prop.type_properties.get(sole_type, [])
                errors.extend(_validate_fields_recursive(nested_props, union_vals, f"{full_key}."))
            elif len(prop.types) > 1:
                # Multiple variants and no way to tell which one applies.
                errors.append(
                    f"{full_key}: union type could not be determined; "
                    "select a configuration variant and provide required fields"
                )
            continue

        if isinstance(prop, ConfigSinglePropertyDefinition):
            # Composite single property: recurse into its children rather
            # than validating it as a leaf.
            if prop.nested_properties:
                nested_vals = values.get(prop.property_key, {})
                if not isinstance(nested_vals, dict):
                    nested_vals = {}
                errors.extend(_validate_fields_recursive(prop.nested_properties, nested_vals, f"{full_key}."))
                continue

            # Leaf property: normalise the value to a stripped string for
            # the checks below (None means "not provided").
            raw = values.get(prop.property_key)
            text = str(raw).strip() if raw is not None else ""

            # Required check.
            if prop.required and not text:
                errors.append(f"{full_key}: required field is empty")
                continue

            # Optional and empty — nothing further to check.
            if not text:
                continue

            # Type-specific checks.
            if prop.property_key in _PORT_KEYS:
                # Port fields get the full range check; this subsumes the
                # generic integer check, so skip the rest.
                port_err = validate_port(text)
                if port_err:
                    errors.append(f"{full_key}: {port_err}")
                continue
            elif prop.property_type is int:
                try:
                    int(text)
                except ValueError:
                    errors.append(f"{full_key}: must be a valid integer")
                    continue

            if prop.property_key in _HOST_KEYS:
                host_err = validate_hostname(text)
                if host_err:
                    errors.append(f"{full_key}: {host_err}")

    return errors


def render_datasource_config_form(
config_fields: list[ConfigPropertyDefinition],
Expand Down
62 changes: 47 additions & 15 deletions src/databao_cli/features/ui/components/datasource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@

import streamlit as st
from databao_context_engine import ConfiguredDatasource, DatasourceConnectionStatus
from databao_context_engine.pluginlib.config import ConfigPropertyDefinition

from databao_cli.features.datasource.validation import validate_datasource_name
from databao_cli.features.ui.app import invalidate_agent
from databao_cli.features.ui.components.datasource_form import render_datasource_config_form
from databao_cli.features.ui.components.datasource_form import (
render_datasource_config_form,
validate_config_fields,
)
from databao_cli.features.ui.services.dce_operations import (
add_datasource,
get_available_datasource_types,
Expand Down Expand Up @@ -56,6 +61,22 @@ def render_datasource_manager(project_dir: Path, *, read_only: bool = False) ->
_render_existing_datasource(project_dir, ds, idx, read_only=read_only)


def _validate_new_datasource_inputs(ds_name: str | None, selected_type: str | None) -> str | None:
    """Validate the new-datasource name and type, reporting via ``st.error``.

    Returns the stripped name when both inputs are valid, or ``None``
    when validation failed (errors have already been shown to the user).
    """
    name = "" if ds_name is None else ds_name.strip()
    error = validate_datasource_name(name)
    if error is not None:
        st.error(error)
        return None
    if not selected_type:
        st.error("Please select a datasource type.")
        return None
    return name


def _get_form_version() -> int:
"""Get the current form version counter used to reset widget keys."""
if "_new_ds_form_version" not in st.session_state:
Expand Down Expand Up @@ -89,6 +110,7 @@ def _render_add_datasource_section(project_dir: Path) -> None:
)

config_values: dict[str, Any] = {}
config_fields: list[ConfigPropertyDefinition] = []
if selected_type:
try:
config_fields = get_datasource_config_fields(selected_type)
Expand All @@ -105,13 +127,15 @@ def _render_add_datasource_section(project_dir: Path) -> None:

with col_add:
if st.button("Add datasource", key="add_ds_btn", type="primary", use_container_width=True):
if not ds_name or not ds_name.strip():
st.error("Please provide a datasource name.")
elif not selected_type:
st.error("Please select a datasource type.")
validated = _validate_new_datasource_inputs(ds_name, selected_type)
if validated is None:
pass # errors already shown
elif config_fields and (field_errors := validate_config_fields(config_fields, config_values)):
for err in field_errors:
st.error(err)
else:
try:
add_datasource(project_dir, selected_type, ds_name.strip(), config_values)
add_datasource(project_dir, selected_type, validated, config_values)
_clear_new_datasource_form()
st.rerun()
except Exception as e:
Expand All @@ -120,11 +144,15 @@ def _render_add_datasource_section(project_dir: Path) -> None:

with col_verify_new:
if st.button("Verify connection", key="verify_new_ds_btn", use_container_width=True):
if not ds_name or not ds_name.strip() or not selected_type:
st.error("Please provide a datasource name and type first.")
validated = _validate_new_datasource_inputs(ds_name, selected_type)
if validated is None:
pass # errors already shown
elif config_fields and (field_errors := validate_config_fields(config_fields, config_values)):
for err in field_errors:
st.error(err)
else:
try:
result = verify_datasource_config(selected_type, ds_name.strip(), config_values)
result = verify_datasource_config(selected_type, validated, config_values)
if result.connection_status == DatasourceConnectionStatus.VALID:
st.success("Connection valid.")
elif result.connection_status == DatasourceConnectionStatus.UNKNOWN:
Expand Down Expand Up @@ -177,12 +205,16 @@ def _render_existing_datasource(project_dir: Path, ds: ConfiguredDatasource, idx
disabled=not has_changes,
use_container_width=True,
):
try:
save_datasource(project_dir, ds_type, ds_name, edited_values)
st.success("Saved.")
st.rerun()
except Exception as e:
st.error(f"Save failed: {e}")
if config_fields and (field_errors := validate_config_fields(config_fields, edited_values)):
for err in field_errors:
st.error(err)
else:
try:
save_datasource(project_dir, ds_type, ds_name, edited_values)
st.success("Saved.")
st.rerun()
except Exception as e:
st.error(f"Save failed: {e}")

with col_verify:
if st.button("Verify", key=f"verify_ds_{idx}", use_container_width=True):
Expand Down
9 changes: 8 additions & 1 deletion src/databao_cli/workflows/datasource/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DatasourceType,
)

from databao_cli.features.datasource.validation import validate_datasource_name
from databao_cli.shared.context_engine_cli import ClickUserInputCallback
from databao_cli.shared.project.layout import ProjectLayout
from databao_cli.workflows.datasource.check import print_connection_check_results
Expand All @@ -24,7 +25,13 @@ def add_workflow(project_layout: ProjectLayout, domain: str) -> None:
click.echo(f"We will guide you to add a new datasource into {domain} domain, at {domain_dir.resolve()}")

datasource_type = _ask_for_datasource_type(plugin_loader.get_all_supported_datasource_types(exclude_file_plugins=True))
datasource_name = click.prompt("Datasource name?", type=str)

while True:
datasource_name = click.prompt("Datasource name?", type=str).strip()
name_error = validate_datasource_name(datasource_name)
if name_error is None:
break
click.secho(name_error, fg="red", err=True)

overwrite_existing = False
existing_id = datasource_config_exists(project_layout, domain, datasource_name)
Expand Down
Loading
Loading