diff --git a/e2e-tests/src/databases/duckdb_utils.py b/e2e-tests/src/databases/duckdb_utils.py index 7d5c1e21..fe54c29a 100644 --- a/e2e-tests/src/databases/duckdb_utils.py +++ b/e2e-tests/src/databases/duckdb_utils.py @@ -6,7 +6,7 @@ @dataclass(frozen=True) class DuckdbDB: - datasource_name: str | None = "test duckdb conn" + datasource_name: str | None = "test_duckdb_conn" datasource_type: str = "duckdb" database_path: Path | None = None check_connection: bool = False diff --git a/e2e-tests/src/databases/sqlite_utils.py b/e2e-tests/src/databases/sqlite_utils.py index 3ebb2e2c..dc12997b 100644 --- a/e2e-tests/src/databases/sqlite_utils.py +++ b/e2e-tests/src/databases/sqlite_utils.py @@ -5,7 +5,7 @@ @dataclass(frozen=True) class SqliteDB: - datasource_name: str | None = "test sqlite conn" + datasource_name: str | None = "test_sqlite_conn" datasource_type: str = "sqlite" database_path: Path | None = None check_connection: bool = False diff --git a/e2e-tests/tests/resources/duckdb_introspections.yaml b/e2e-tests/tests/resources/duckdb_introspections.yaml index eca0a34d..23a95bdf 100644 --- a/e2e-tests/tests/resources/duckdb_introspections.yaml +++ b/e2e-tests/tests/resources/duckdb_introspections.yaml @@ -1,5 +1,5 @@ -# ===== test duckdb conn.yaml ===== -datasource_id: test duckdb conn.yaml +# ===== test_duckdb_conn.yaml ===== +datasource_id: test_duckdb_conn.yaml datasource_type: duckdb context_built_at: 2026-02-27 16:38:36.999625 context: diff --git a/e2e-tests/tests/resources/sqlite_introspections.yaml b/e2e-tests/tests/resources/sqlite_introspections.yaml index b1af0fe5..6c5da4db 100644 --- a/e2e-tests/tests/resources/sqlite_introspections.yaml +++ b/e2e-tests/tests/resources/sqlite_introspections.yaml @@ -1,5 +1,5 @@ -# ===== test sqlite conn.yaml ===== -datasource_id: test sqlite conn.yaml +# ===== test_sqlite_conn.yaml ===== +datasource_id: test_sqlite_conn.yaml datasource_type: sqlite context_built_at: 2026-02-27 16:15:31.010532 context: diff --git 
a/src/databao_cli/features/datasource/validation.py b/src/databao_cli/features/datasource/validation.py new file mode 100644 index 00000000..43448b1b --- /dev/null +++ b/src/databao_cli/features/datasource/validation.py @@ -0,0 +1,102 @@ +"""Validation rules for datasource fields. + +These rules are shared between the CLI workflow and the Streamlit UI so +that users get the same feedback regardless of how they create a +datasource. +""" + +import ipaddress +import re + +# The agent requires source names to match this pattern so they can be +# used unquoted in SQL queries. See databao-agent domain.py. +_SOURCE_NAME_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$") + +# Folder segments in the datasource path are more permissive — they only +# need to be valid filesystem names. +_FOLDER_SEGMENT_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$") + +MAX_DATASOURCE_NAME_LENGTH = 255 + +MIN_PORT = 1 +MAX_PORT = 65535 + +# Hostname: RFC 952 / RFC 1123 — labels separated by dots. +_HOSTNAME_RE = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)(\.(?!-)[A-Za-z0-9-]{1,63}(?<!-))*$") + + +def validate_datasource_name(name: str) -> str | None: + """Return an error message if *name* is invalid, or ``None`` if it is OK. + + Datasource names may contain forward-slash separators to organise + datasources into folders (e.g. ``resources/my_db``). Folder segments + are validated as filesystem-safe names. The final segment (the actual + source name) must match the agent's ``^[A-Za-z][A-Za-z0-9_]*$`` + pattern so it can be used unquoted in SQL queries. + """ + if not name or not name.strip(): + return "Datasource name must not be empty." + + if len(name) > MAX_DATASOURCE_NAME_LENGTH: + return f"Datasource name must be at most {MAX_DATASOURCE_NAME_LENGTH} characters." + + if re.search(r"\s", name): + return "Datasource name must not contain whitespace." + + segments = name.split("/") + for segment in segments: + if not segment: + return ( + "Datasource name must not contain empty path segments (for example, leading, trailing, or repeated slashes)." 
+ ) + + # Validate folder segments (all but the last). + for segment in segments[:-1]: + if not _FOLDER_SEGMENT_RE.match(segment): + return ( + f"Folder segment '{segment}' may only contain letters, digits, " + "hyphens, underscores, and dots, and must start and end with " + "a letter or digit." + ) + + # Validate the source name (last segment) against the agent's pattern. + source_name = segments[-1] + if not _SOURCE_NAME_RE.match(source_name): + return "Datasource name must start with a letter and contain only letters, digits, and underscores." + + return None + + +def validate_port(value: str) -> str | None: + """Return an error message if *value* is not a valid port number.""" + try: + port = int(value) + except ValueError: + return "Port must be a number." + if port < MIN_PORT or port > MAX_PORT: + return f"Port must be between {MIN_PORT} and {MAX_PORT}." + return None + + +def validate_hostname(value: str) -> str | None: + """Return an error message if *value* is not a valid hostname or IP.""" + value = value.strip() + if not value: + return "Hostname must not be empty." + # Allow localhost and IP addresses as-is. + if value == "localhost" or _is_ip_address(value): + return None + if len(value) > 253: + return "Hostname must not exceed 253 characters." + if not _HOSTNAME_RE.match(value): + return "Hostname contains invalid characters." 
+ return None + + +def _is_ip_address(value: str) -> bool: + """Return True if *value* is a valid IPv4 or IPv6 address.""" + try: + ipaddress.ip_address(value) + return True + except ValueError: + return False diff --git a/src/databao_cli/features/ui/components/datasource_form.py b/src/databao_cli/features/ui/components/datasource_form.py index a6662cc4..2fcbdd56 100644 --- a/src/databao_cli/features/ui/components/datasource_form.py +++ b/src/databao_cli/features/ui/components/datasource_form.py @@ -14,8 +14,107 @@ ConfigUnionPropertyDefinition, ) +from databao_cli.features.datasource.validation import validate_hostname, validate_port + SKIP_TOP_LEVEL_KEYS = {"type", "name"} +# Field keys that represent a network host. +_HOST_KEYS = {"host", "hostname"} + +# Field keys that represent a network port. +_PORT_KEYS = {"port"} + + +def validate_config_fields( + config_fields: list[ConfigPropertyDefinition], + values: dict[str, Any], +) -> list[str]: + """Return human-readable error strings for invalid or missing fields. + + Checks are applied recursively but only to leaf + ``ConfigSinglePropertyDefinition`` nodes. Checks performed: + + * Required fields must not be empty. + * ``int``-typed fields must be valid integers. + * Fields named ``port`` must be in the 1-65535 range. + * Fields named ``host`` / ``hostname`` must be valid hostnames or IPs. + """ + return _validate_fields_recursive(config_fields, values) + + +def _validate_fields_recursive( + config_fields: list[ConfigPropertyDefinition], + values: dict[str, Any], + path_prefix: str = "", +) -> list[str]: + errors: list[str] = [] + for prop in config_fields: + # Only skip special top-level keys like "type" / "name". 
+ if not path_prefix and prop.property_key in SKIP_TOP_LEVEL_KEYS: + continue + + full_key = f"{path_prefix}{prop.property_key}" if path_prefix else prop.property_key + + if isinstance(prop, ConfigUnionPropertyDefinition): + union_vals = values.get(prop.property_key, {}) + if not isinstance(union_vals, dict): + union_vals = {} + type_choices = {t.__name__: t for t in prop.types} + selected_name = union_vals.get("type") or _infer_union_type(union_vals, type_choices, prop.type_properties) + if selected_name and selected_name in type_choices: + selected_type = type_choices[selected_name] + nested_props = prop.type_properties.get(selected_type, []) + errors.extend(_validate_fields_recursive(nested_props, union_vals, f"{full_key}.")) + elif len(prop.types) == 1: + sole_type = prop.types[0] + nested_props = prop.type_properties.get(sole_type, []) + errors.extend(_validate_fields_recursive(nested_props, union_vals, f"{full_key}.")) + elif len(prop.types) > 1: + errors.append( + f"{full_key}: union type could not be determined; " + "select a configuration variant and provide required fields" + ) + continue + + if isinstance(prop, ConfigSinglePropertyDefinition): + if prop.nested_properties: + nested_vals = values.get(prop.property_key, {}) + if not isinstance(nested_vals, dict): + nested_vals = {} + errors.extend(_validate_fields_recursive(prop.nested_properties, nested_vals, f"{full_key}.")) + continue + + raw = values.get(prop.property_key) + text = str(raw).strip() if raw is not None else "" + + # Required check. + if prop.required and not text: + errors.append(f"{full_key}: required field is empty") + continue + + if not text: + continue + + # Type-specific checks. 
+ if prop.property_key in _PORT_KEYS: + port_err = validate_port(text) + if port_err: + errors.append(f"{full_key}: {port_err}") + continue + elif prop.property_type is int: + try: + int(text) + except ValueError: + errors.append(f"{full_key}: must be a valid integer") + continue + + if prop.property_key in _HOST_KEYS: + host_err = validate_hostname(text) + if host_err: + errors.append(f"{full_key}: {host_err}") + + return errors + def render_datasource_config_form( config_fields: list[ConfigPropertyDefinition], diff --git a/src/databao_cli/features/ui/components/datasource_manager.py b/src/databao_cli/features/ui/components/datasource_manager.py index 4da0f250..0a3f56a0 100644 --- a/src/databao_cli/features/ui/components/datasource_manager.py +++ b/src/databao_cli/features/ui/components/datasource_manager.py @@ -10,9 +10,14 @@ import streamlit as st from databao_context_engine import ConfiguredDatasource, DatasourceConnectionStatus +from databao_context_engine.pluginlib.config import ConfigPropertyDefinition +from databao_cli.features.datasource.validation import validate_datasource_name from databao_cli.features.ui.app import invalidate_agent -from databao_cli.features.ui.components.datasource_form import render_datasource_config_form +from databao_cli.features.ui.components.datasource_form import ( + render_datasource_config_form, + validate_config_fields, +) from databao_cli.features.ui.services.dce_operations import ( add_datasource, get_available_datasource_types, @@ -56,6 +61,22 @@ def render_datasource_manager(project_dir: Path, *, read_only: bool = False) -> _render_existing_datasource(project_dir, ds, idx, read_only=read_only) +def _validate_new_datasource_inputs(ds_name: str | None, selected_type: str | None) -> str | None: + """Validate the datasource name and type, showing errors via ``st.error``. + + Returns the stripped name on success, or ``None`` if validation failed. 
+ """ + stripped_name = (ds_name or "").strip() + name_error = validate_datasource_name(stripped_name) + if name_error: + st.error(name_error) + return None + if not selected_type: + st.error("Please select a datasource type.") + return None + return stripped_name + + def _get_form_version() -> int: """Get the current form version counter used to reset widget keys.""" if "_new_ds_form_version" not in st.session_state: @@ -89,6 +110,7 @@ def _render_add_datasource_section(project_dir: Path) -> None: ) config_values: dict[str, Any] = {} + config_fields: list[ConfigPropertyDefinition] = [] if selected_type: try: config_fields = get_datasource_config_fields(selected_type) @@ -105,13 +127,15 @@ def _render_add_datasource_section(project_dir: Path) -> None: with col_add: if st.button("Add datasource", key="add_ds_btn", type="primary", use_container_width=True): - if not ds_name or not ds_name.strip(): - st.error("Please provide a datasource name.") - elif not selected_type: - st.error("Please select a datasource type.") + validated = _validate_new_datasource_inputs(ds_name, selected_type) + if validated is None: + pass # errors already shown + elif config_fields and (field_errors := validate_config_fields(config_fields, config_values)): + for err in field_errors: + st.error(err) else: try: - add_datasource(project_dir, selected_type, ds_name.strip(), config_values) + add_datasource(project_dir, selected_type, validated, config_values) _clear_new_datasource_form() st.rerun() except Exception as e: @@ -120,11 +144,15 @@ def _render_add_datasource_section(project_dir: Path) -> None: with col_verify_new: if st.button("Verify connection", key="verify_new_ds_btn", use_container_width=True): - if not ds_name or not ds_name.strip() or not selected_type: - st.error("Please provide a datasource name and type first.") + validated = _validate_new_datasource_inputs(ds_name, selected_type) + if validated is None: + pass # errors already shown + elif config_fields and (field_errors := 
validate_config_fields(config_fields, config_values)): + for err in field_errors: + st.error(err) else: try: - result = verify_datasource_config(selected_type, ds_name.strip(), config_values) + result = verify_datasource_config(selected_type, validated, config_values) if result.connection_status == DatasourceConnectionStatus.VALID: st.success("Connection valid.") elif result.connection_status == DatasourceConnectionStatus.UNKNOWN: @@ -177,12 +205,16 @@ def _render_existing_datasource(project_dir: Path, ds: ConfiguredDatasource, idx disabled=not has_changes, use_container_width=True, ): - try: - save_datasource(project_dir, ds_type, ds_name, edited_values) - st.success("Saved.") - st.rerun() - except Exception as e: - st.error(f"Save failed: {e}") + if config_fields and (field_errors := validate_config_fields(config_fields, edited_values)): + for err in field_errors: + st.error(err) + else: + try: + save_datasource(project_dir, ds_type, ds_name, edited_values) + st.success("Saved.") + st.rerun() + except Exception as e: + st.error(f"Save failed: {e}") with col_verify: if st.button("Verify", key=f"verify_ds_{idx}", use_container_width=True): diff --git a/src/databao_cli/workflows/datasource/add.py b/src/databao_cli/workflows/datasource/add.py index 6517998c..835bb78d 100644 --- a/src/databao_cli/workflows/datasource/add.py +++ b/src/databao_cli/workflows/datasource/add.py @@ -9,6 +9,7 @@ DatasourceType, ) +from databao_cli.features.datasource.validation import validate_datasource_name from databao_cli.shared.context_engine_cli import ClickUserInputCallback from databao_cli.shared.project.layout import ProjectLayout from databao_cli.workflows.datasource.check import print_connection_check_results @@ -24,7 +25,13 @@ def add_workflow(project_layout: ProjectLayout, domain: str) -> None: click.echo(f"We will guide you to add a new datasource into {domain} domain, at {domain_dir.resolve()}") datasource_type = 
_ask_for_datasource_type(plugin_loader.get_all_supported_datasource_types(exclude_file_plugins=True)) - datasource_name = click.prompt("Datasource name?", type=str) + + while True: + datasource_name = click.prompt("Datasource name?", type=str).strip() + name_error = validate_datasource_name(datasource_name) + if name_error is None: + break + click.secho(name_error, fg="red", err=True) overwrite_existing = False existing_id = datasource_config_exists(project_layout, domain, datasource_name) diff --git a/tests/test_datasource_form_validation.py b/tests/test_datasource_form_validation.py new file mode 100644 index 00000000..6c45772c --- /dev/null +++ b/tests/test_datasource_form_validation.py @@ -0,0 +1,162 @@ +"""Tests for validate_config_fields() in datasource_form.py.""" + +from databao_context_engine.pluginlib.config import ( + ConfigPropertyDefinition, + ConfigSinglePropertyDefinition, + ConfigUnionPropertyDefinition, +) + +from databao_cli.features.ui.components.datasource_form import validate_config_fields + + +def _single( + key: str, + *, + required: bool = False, + property_type: type = str, + nested: list[ConfigPropertyDefinition] | None = None, +) -> ConfigSinglePropertyDefinition: + return ConfigSinglePropertyDefinition( + property_key=key, + required=required, + property_type=property_type, + nested_properties=nested, + ) + + +def _fields(*props: ConfigPropertyDefinition) -> list[ConfigPropertyDefinition]: + """Helper to build a correctly-typed field list.""" + return list(props) + + +class TestIntVsPortValidation: + """Port range check only applies to port-named fields, not all int fields.""" + + def test_non_port_int_field_accepts_large_number(self) -> None: + fields = _fields(_single("timeout", property_type=int)) + errors = validate_config_fields(fields, {"timeout": "120000"}) + assert errors == [] + + def test_non_port_int_field_rejects_non_numeric(self) -> None: + fields = _fields(_single("retries", property_type=int)) + errors = 
validate_config_fields(fields, {"retries": "abc"}) + assert len(errors) == 1 + assert "integer" in errors[0].lower() + + def test_port_field_rejects_out_of_range(self) -> None: + fields = _fields(_single("port", property_type=int)) + errors = validate_config_fields(fields, {"port": "99999"}) + assert len(errors) == 1 + assert "between" in errors[0].lower() + + def test_port_field_accepts_valid_port(self) -> None: + fields = _fields(_single("port", property_type=int)) + errors = validate_config_fields(fields, {"port": "5432"}) + assert errors == [] + + +class TestSkipTopLevelKeysScope: + """SKIP_TOP_LEVEL_KEYS should only apply at the top level.""" + + def test_top_level_type_is_skipped(self) -> None: + fields = _fields(_single("type", required=True), _single("host", required=True)) + errors = validate_config_fields(fields, {"host": ""}) + # "type" should be skipped, only "host" error expected + assert len(errors) == 1 + assert "host" in errors[0] + + def test_nested_name_field_is_validated(self) -> None: + fields = _fields( + _single( + "connection", + nested=[_single("name", required=True)], + ), + ) + errors = validate_config_fields(fields, {"connection": {"name": ""}}) + assert len(errors) == 1 + assert "connection.name" in errors[0] + + +class TestUnionPropertyValidation: + """Union properties should be validated recursively.""" + + def test_union_required_field_missing(self) -> None: + class VariantA: + pass + + union = ConfigUnionPropertyDefinition( + property_key="auth", + types=(VariantA,), + type_properties={ + VariantA: [_single("username", required=True)], + }, + ) + fields = _fields(union) + errors = validate_config_fields(fields, {"auth": {"type": "VariantA", "username": ""}}) + assert len(errors) == 1 + assert "username" in errors[0] + + def test_union_valid_fields_pass(self) -> None: + class VariantA: + pass + + union = ConfigUnionPropertyDefinition( + property_key="auth", + types=(VariantA,), + type_properties={ + VariantA: [_single("username", 
required=True)], + }, + ) + fields = _fields(union) + errors = validate_config_fields(fields, {"auth": {"type": "VariantA", "username": "admin"}}) + assert errors == [] + + def test_single_variant_union_validates_with_empty_dict(self) -> None: + """When there's only one variant and no type discriminator, still validate.""" + + class OnlyVariant: + pass + + union = ConfigUnionPropertyDefinition( + property_key="auth", + types=(OnlyVariant,), + type_properties={ + OnlyVariant: [_single("token", required=True)], + }, + ) + fields = _fields(union) + errors = validate_config_fields(fields, {"auth": {}}) + assert len(errors) == 1 + assert "token" in errors[0] + + def test_multi_variant_union_errors_when_type_unknown(self) -> None: + """When multiple variants exist and type can't be inferred, emit an error.""" + + class VariantA: + pass + + class VariantB: + pass + + union = ConfigUnionPropertyDefinition( + property_key="auth", + types=(VariantA, VariantB), + type_properties={ + VariantA: [_single("username", required=True)], + VariantB: [_single("token", required=True)], + }, + ) + fields = _fields(union) + errors = validate_config_fields(fields, {"auth": {}}) + assert len(errors) == 1 + assert "could not be determined" in errors[0] + + +class TestHostValidation: + """Host field validation within config fields.""" + + def test_hostname_with_port_rejected(self) -> None: + fields = _fields(_single("host")) + errors = validate_config_fields(fields, {"host": "db.example.com:5432"}) + assert len(errors) == 1 + assert "host" in errors[0] diff --git a/tests/test_datasource_validation.py b/tests/test_datasource_validation.py new file mode 100644 index 00000000..e3169b6d --- /dev/null +++ b/tests/test_datasource_validation.py @@ -0,0 +1,147 @@ +import pytest + +from databao_cli.features.datasource.validation import ( + MAX_DATASOURCE_NAME_LENGTH, + validate_datasource_name, + validate_hostname, + validate_port, +) + + +class TestValidateDatasourceName: + """Tests for 
validate_datasource_name(). + + The last segment must match the agent's pattern: ^[A-Za-z][A-Za-z0-9_]*$ + Folder segments (preceding the last) are more permissive. + """ + + @pytest.mark.parametrize( + "name", + [ + "my_datasource", + "ds1", + "a", + "A", + "SnowflakeProd", + "test_db", + "ab", + "resources/my_db", + "folder/sub/name", + "my-folder/my_db", + "v2.data/source1", + ], + ) + def test_valid_names(self, name: str) -> None: + assert validate_datasource_name(name) is None + + @pytest.mark.parametrize( + "name", + [ + "my-datasource", + "my.datasource", + "Snowflake-Prod", + "test_db.v2", + "1startsWithDigit", + "_leading_underscore", + ], + ) + def test_invalid_source_name_segment(self, name: str) -> None: + """Last segment must match agent pattern: letter then [A-Za-z0-9_]*.""" + assert validate_datasource_name(name) is not None + + def test_empty_name(self) -> None: + assert validate_datasource_name("") is not None + + def test_whitespace_only(self) -> None: + assert validate_datasource_name(" ") is not None + + def test_name_with_spaces(self) -> None: + error = validate_datasource_name("my datasource") + assert error is not None + assert "whitespace" in error.lower() + + @pytest.mark.parametrize("name", ["my\tdatasource", "my\ndatasource", "my\u00a0datasource"]) + def test_name_with_non_space_whitespace(self, name: str) -> None: + """Tabs, newlines, and non-breaking spaces should also be rejected.""" + error = validate_datasource_name(name) + assert error is not None + assert "whitespace" in error.lower() + + @pytest.mark.parametrize("char", ["@", "#", "$", "%", "!", "?", "\\", ":", "*"]) + def test_forbidden_characters(self, char: str) -> None: + assert validate_datasource_name(f"ds{char}name") is not None + + def test_double_slash_rejected(self) -> None: + assert validate_datasource_name("a//b") is not None + + def test_leading_slash_rejected(self) -> None: + assert validate_datasource_name("/a") is not None + + def test_trailing_slash_rejected(self) 
-> None: + assert validate_datasource_name("a/") is not None + + def test_invalid_folder_segment(self) -> None: + assert validate_datasource_name(".hidden/my_db") is not None + + def test_name_too_long(self) -> None: + long_name = "a" * (MAX_DATASOURCE_NAME_LENGTH + 1) + error = validate_datasource_name(long_name) + assert error is not None + assert "255" in error + + def test_name_at_max_length(self) -> None: + name = "a" * MAX_DATASOURCE_NAME_LENGTH + assert validate_datasource_name(name) is None + + +class TestValidatePort: + """Tests for validate_port().""" + + @pytest.mark.parametrize("value", ["1", "80", "443", "5432", "65535"]) + def test_valid_ports(self, value: str) -> None: + assert validate_port(value) is None + + @pytest.mark.parametrize("value", ["0", "-1", "65536", "99999"]) + def test_out_of_range(self, value: str) -> None: + error = validate_port(value) + assert error is not None + assert "between" in error.lower() + + @pytest.mark.parametrize("value", ["abc", "12.5", "", " "]) + def test_non_numeric(self, value: str) -> None: + error = validate_port(value) + assert error is not None + assert "number" in error.lower() + + +class TestValidateHostname: + """Tests for validate_hostname().""" + + @pytest.mark.parametrize( + "value", + [ + "localhost", + "127.0.0.1", + "192.168.1.1", + "my-host", + "db.example.com", + "my-db.internal.corp.net", + "::1", + "2001:db8::1", + ], + ) + def test_valid_hostnames(self, value: str) -> None: + assert validate_hostname(value) is None + + def test_empty_hostname(self) -> None: + assert validate_hostname("") is not None + assert validate_hostname(" ") is not None + + @pytest.mark.parametrize("value", ["-leading-hyphen", "trailing-hyphen-"]) + def test_invalid_hostnames(self, value: str) -> None: + assert validate_hostname(value) is not None + + @pytest.mark.parametrize("value", ["not-an-ip:still-not", "db.example.com:5432"]) + def test_colon_strings_not_treated_as_ip(self, value: str) -> None: + """Strings with 
colons that are not valid IPv6 should not pass as IPs.""" + assert validate_hostname(value) is not None