Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 165 additions & 64 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,127 @@
import duckdb
import logging
import os
import re
import time
from contextlib import asynccontextmanager
from typing import List, Optional, Any
from typing import Any, List, Literal, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
import uvicorn

# Configure logging
# Root logger at INFO so connection/config events below are visible in
# container logs; module-level logger per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# --- Input validation -------------------------------------------------------
# Patterns tight enough to reject SQL-breaking characters (quotes, semicolons,
# whitespace, control chars) without blocking legitimate values. Applied at
# the API boundary by Pydantic validators so they run before any config value
# reaches DuckDB.

_ENDPOINT_RE = re.compile(r"^[A-Za-z0-9\-._:/@+%]+$")  # host[:port]/path-style endpoints
_REGION_RE = re.compile(r"^[A-Za-z0-9\-]+$")  # e.g. "us-east-1"
_SESSION_TOKEN_RE = re.compile(r"^[A-Za-z0-9+/=\-_.]+$")  # base64/base64url alphabet
_URL_RE = re.compile(r"^https?://[A-Za-z0-9\-._:/@]+$")  # http(s) URL, no query string
_SQL_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")  # bare SQL identifier
_S3_PATH_RE = re.compile(r"^s3://[A-Za-z0-9\-._/]+$")  # s3://bucket/key... path


def _require_match(value: str, pattern: "re.Pattern[str]", field: str) -> str:
if not pattern.fullmatch(value):
raise ValueError(f"{field} contains invalid characters")
return value


def _sql_string_literal(value: str) -> str:
"""Quote a pre-validated string as a SQL string literal.

Used only for statements where DuckDB's ``?`` parameter binding does not
apply (``CREATE SECRET``, ``ATTACH``). The caller must have validated
``value`` with one of the regexes above; this helper doubles any embedded
single-quote and rejects control characters as a last line of defence.
"""
if "\x00" in value or any(ord(c) < 0x20 and c != "\t" for c in value):
raise ValueError("Value contains control characters")
return "'" + value.replace("'", "''") + "'"


# Request/Response Models
class ConnectionConfig(BaseModel):
    """Connection settings for S3-compatible storage plus an optional Iceberg catalog.

    Every user-supplied field is validated here, at the API boundary, so no
    SQL-breaking characters (quotes, semicolons, whitespace, control chars)
    reach DuckDB configuration statements.

    Note: the merged diff left stale pre-change declarations of
    ``storageType`` and ``catalogType`` (plain ``str``/``Optional[str]``)
    shadowed above their ``Literal`` replacements; only the ``Literal``
    versions are kept here.
    """

    storageType: Literal["s3", "r2", "minio"] = Field(
        ..., description="Storage type: s3, r2, minio"
    )
    endpoint: str = Field(..., description="S3 endpoint or bucket path")
    accessKey: str = Field(..., description="Access key ID")
    secretKey: str = Field(..., description="Secret access key")
    sessionToken: Optional[str] = Field(None, description="Session token for STS")
    region: str = Field(default="us-east-1", description="Region")

    # Iceberg-specific fields
    catalogType: Literal["none", "rest", "glue"] = Field(
        default="none", description="Catalog type: none, rest, glue"
    )
    catalogEndpoint: Optional[str] = Field(None, description="REST catalog endpoint URL")
    namespace: Optional[str] = Field(default="default", description="Iceberg namespace/database")
    tablePath: Optional[str] = Field(None, description="Direct path to Iceberg table root")

    @field_validator("endpoint")
    @classmethod
    def _validate_endpoint(cls, v: str) -> str:
        # An empty endpoint is allowed (plain AWS S3 needs only a region).
        if v == "":
            return v
        return _require_match(v, _ENDPOINT_RE, "endpoint")

    @field_validator("region")
    @classmethod
    def _validate_region(cls, v: str) -> str:
        return _require_match(v, _REGION_RE, "region")

    @field_validator("accessKey", "secretKey")
    @classmethod
    def _validate_key(cls, v: str, info) -> str:
        # AWS/MinIO credentials can contain characters that would need
        # escaping, so we rely on parameter binding or _sql_string_literal for
        # safe interpolation.  This guards only against obvious smuggling.
        if "\x00" in v or "\n" in v or "\r" in v:
            raise ValueError(f"{info.field_name} contains invalid characters")
        return v

    @field_validator("sessionToken")
    @classmethod
    def _validate_session_token(cls, v: Optional[str]) -> Optional[str]:
        if not v:
            return v
        return _require_match(v, _SESSION_TOKEN_RE, "sessionToken")

    @field_validator("catalogEndpoint")
    @classmethod
    def _validate_catalog_endpoint(cls, v: Optional[str]) -> Optional[str]:
        if not v:
            return v
        return _require_match(v, _URL_RE, "catalogEndpoint")

    @field_validator("namespace")
    @classmethod
    def _validate_namespace(cls, v: Optional[str]) -> Optional[str]:
        if not v:
            return v
        return _require_match(v, _SQL_IDENT_RE, "namespace")

    @field_validator("tablePath")
    @classmethod
    def _validate_table_path(cls, v: Optional[str]) -> Optional[str]:
        if not v:
            return v
        # Normalise first so downstream code can rely on a canonical value:
        # strip a trailing "/" and a trailing "/metadata" component.
        v = v.rstrip("/")
        if v.endswith("/metadata"):
            v = v[: -len("/metadata")]
        return _require_match(v, _S3_PATH_RE, "tablePath")


class TestConnectionRequest(BaseModel):
connection: ConnectionConfig
Expand Down Expand Up @@ -118,40 +210,39 @@ def _apply_s3_config(self, config: ConnectionConfig):
self.connection.execute("SET memory_limit='2GB'")
self.connection.execute("SET threads=4")

# Apply new settings based on storage type
# Apply new settings based on storage type. All user-supplied
# values are sent through DuckDB parameter binding (?).
if config.storageType == "minio":
# MinIO configuration - handle both localhost and container endpoints
endpoint = config.endpoint
if "localhost" in endpoint:
# Replace localhost with container name for internal access
endpoint = endpoint.replace("localhost", "minio")
endpoint = endpoint.replace("http://", "").replace("https://", "")
logger.info(f"Final MinIO endpoint: {endpoint}")
self.connection.execute(f"SET s3_endpoint='{endpoint}'")
self.connection.execute("SET s3_endpoint=?", [endpoint])
self.connection.execute("SET s3_url_style='path'")
self.connection.execute("SET s3_use_ssl=false")
# MinIO requires AWS signature v4
self.connection.execute("SET s3_region='us-east-1'") # MinIO default
elif config.storageType == "r2":
# Cloudflare R2 configuration
endpoint = config.endpoint.replace("https://", "")
self.connection.execute(f"SET s3_endpoint='{endpoint}'")
self.connection.execute("SET s3_endpoint=?", [endpoint])
self.connection.execute("SET s3_url_style='path'")
self.connection.execute("SET s3_use_ssl=true")
else:
# AWS S3 configuration
logger.info(f"Setting S3 region: {config.region}")
self.connection.execute(f"SET s3_region='{config.region}'")
self.connection.execute("SET s3_region=?", [config.region])
self.connection.execute("SET s3_use_ssl=true")

# Set credentials
logger.info(f"Setting S3 credentials - Access Key starts with: {config.accessKey[:8] if config.accessKey else 'EMPTY'}...")
# Use parameter binding to handle special characters safely
# Set credentials (always via parameter binding).
logger.info(
f"Setting S3 credentials - Access Key starts with: {config.accessKey[:8] if config.accessKey else 'EMPTY'}..."
)
self.connection.execute("SET s3_access_key_id=?", [config.accessKey])
self.connection.execute("SET s3_secret_access_key=?", [config.secretKey])

if config.sessionToken:
self.connection.execute(f"SET s3_session_token='{config.sessionToken}'")
self.connection.execute("SET s3_session_token=?", [config.sessionToken])

logger.info(f"Applied {config.storageType} configuration")

Expand All @@ -164,34 +255,47 @@ def _apply_s3_config(self, config: ConnectionConfig):

def _attach_iceberg_catalog(self, config: ConnectionConfig):
    """Attach an Iceberg REST catalog to the DuckDB connection, if configured.

    No-op when ``config.catalogType`` is not ``"rest"``.

    Raises:
        HTTPException: 400 when ``catalogEndpoint`` or ``namespace`` is
            missing for a REST catalog.

    Note: the merged diff interleaved the pre-change body (which spliced
    credentials, namespace and endpoint into f-string SQL unescaped) with the
    hardened replacement; only the escaped version is kept here.
    """
    if config.catalogType != "rest":
        return

    if not config.catalogEndpoint:
        raise HTTPException(
            status_code=400,
            detail="catalogEndpoint required for REST catalog",
        )
    if not config.namespace:
        raise HTTPException(
            status_code=400,
            detail="namespace required for REST catalog",
        )

    logger.info(f"Attaching Iceberg REST catalog: {config.catalogEndpoint}")

    # CREATE SECRET and ATTACH do not support prepared-statement
    # placeholders for their option values, so we interpolate after
    # (a) Pydantic-level regex validation on namespace/catalogEndpoint and
    # (b) escaping through _sql_string_literal, which doubles quotes and
    # rejects any control characters.
    # Note: For now using S3 credentials; extend for OAuth2 later.
    token_literal = _sql_string_literal(f"{config.accessKey}:{config.secretKey}")
    namespace_literal = _sql_string_literal(config.namespace)
    endpoint_literal = _sql_string_literal(config.catalogEndpoint)

    self.connection.execute(f"""
        CREATE SECRET iceberg_catalog_secret (
            TYPE iceberg,
            TOKEN {token_literal}
        )
    """)

    self.connection.execute(f"""
        ATTACH {namespace_literal} AS iceberg_catalog (
            TYPE iceberg,
            SECRET iceberg_catalog_secret,
            ENDPOINT {endpoint_literal}
        )
    """)

    logger.info("Iceberg catalog attached")

def _validate_iceberg_table(self, table_path: str) -> dict:
"""
Expand All @@ -201,8 +305,9 @@ def _validate_iceberg_table(self, table_path: str) -> dict:
"""
try:
# Read metadata to check for delete files
metadata_query = f"SELECT * FROM iceberg_metadata('{table_path}')"
metadata = self.connection.execute(metadata_query).fetchdf()
metadata = self.connection.execute(
"SELECT * FROM iceberg_metadata(?)", [table_path]
).fetchdf()

# Check for delete files in manifests
has_deletes = any('DELETE' in str(v).upper() for v in metadata['manifest_content'].unique())
Expand Down Expand Up @@ -268,42 +373,38 @@ def replace_with_iceberg(match):
def test_connection(self, config: ConnectionConfig) -> bool:
"""Test if the connection configuration works."""
try:
# Clean up table path if needed
if config.tablePath:
# Remove trailing slash if present
if config.tablePath.endswith('/'):
config.tablePath = config.tablePath.rstrip('/')
logger.info(f"Cleaned table path (removed trailing slash): {config.tablePath}")

# Remove /metadata suffix if present
if config.tablePath.endswith('/metadata'):
config.tablePath = config.tablePath[:-9]
logger.info(f"Cleaned table path (removed /metadata): {config.tablePath}")

# tablePath is normalised (trailing / and /metadata stripped) and
# regex-validated by ConnectionConfig, so no further cleanup here.
self._apply_s3_config(config)

# Test Iceberg-specific connectivity
if config.catalogType == "rest":
# Test catalog access
test_query = f"SHOW TABLES FROM iceberg_catalog.{config.namespace}"
# namespace is validated at ingress against a SQL identifier
# pattern, so this interpolation is safe.
result = self.connection.execute(
f"SHOW TABLES FROM iceberg_catalog.{config.namespace}"
).fetchone()
elif config.tablePath:
# For direct table access, try to read version hint first to help DuckDB
# Read version hint first to help DuckDB
try:
version_hint = self.connection.execute(
f"SELECT * FROM read_text('{config.tablePath}/metadata/version-hint.text')"
"SELECT * FROM read_text(?)",
[f"{config.tablePath}/metadata/version-hint.text"],
).fetchone()
if version_hint:
logger.info(f"Found version hint: {version_hint[0]}")
except Exception as e:
logger.warning(f"Could not read version-hint.text: {e}")

# Test direct table access
test_query = f"SELECT COUNT(*) FROM iceberg_scan('{config.tablePath}') LIMIT 1"
result = self.connection.execute(
"SELECT COUNT(*) FROM iceberg_scan(?) LIMIT 1",
[config.tablePath],
).fetchone()
else:
# For demo MinIO setup, test with Iceberg scan
test_query = "SELECT COUNT(*) FROM iceberg_scan('s3://movies/warehouse/demo/movies') LIMIT 1"
# Demo MinIO setup (path is hardcoded, not user input)
result = self.connection.execute(
"SELECT COUNT(*) FROM iceberg_scan('s3://movies/warehouse/demo/movies') LIMIT 1"
).fetchone()

result = self.connection.execute(test_query).fetchone()
logger.info(f"Connection test successful: {result}")
return True

Expand Down
Loading