diff --git a/CHANGELOG.md b/CHANGELOG.md index 413e6e8..d7bf631 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,20 @@ ## [Unreleased] +### Added +- **Multi-Dialect SQL Support** + - Introduced `Dialect` interface for pluggable SQL generation (`dialect/dialect.go`) + - PostgreSQL dialect extracted from converter into `dialect/postgres/` (zero behavior change) + - MySQL dialect implementation (`dialect/mysql/`) + - SQLite dialect implementation (`dialect/sqlite/`) + - DuckDB dialect implementation (`dialect/duckdb/`) + - BigQuery dialect implementation (`dialect/bigquery/`) + - `WithDialect()` option for `Convert()` and `ConvertParameterized()` (defaults to PostgreSQL) + - Per-dialect type providers: `mysql/provider.go`, `sqlite/provider.go`, `duckdb/provider.go`, `bigquery/provider.go` + - Dialect-agnostic schema types in `schema/` package + - Shared test case infrastructure (`testcases/`, `testutil/`) with per-dialect expected SQL + - Dialect registry for name-based lookup (`dialect.Register()`, `dialect.Get()`) + ## [3.5.0] - 2026-01-08 ### Changed diff --git a/CLAUDE.md b/CLAUDE.md index 4549260..a77e51b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## Project Overview -cel2sql converts CEL (Common Expression Language) expressions to PostgreSQL SQL conditions. It specifically targets PostgreSQL standard SQL and was recently migrated from BigQuery. +cel2sql converts CEL (Common Expression Language) expressions to SQL conditions. It supports multiple SQL dialects: PostgreSQL (default), MySQL, SQLite, DuckDB, and BigQuery. **Module**: `github.com/spandigital/cel2sql/v3` **Go Version**: 1.24+ @@ -70,10 +70,31 @@ go test -v -run TestFunctionName ./... 6. 
**`pg/provider.go`** - PostgreSQL type provider for CEL type system - Maps PostgreSQL types to CEL types - - Supports dynamic schema loading from live databases + - Supports dynamic schema loading from live databases via `LoadTableSchema` - Handles composite types and arrays -7. **`sqltypes/types.go`** - Custom SQL type definitions for CEL (DATE, TIME, DATETIME, INTERVAL) +7. **`mysql/provider.go`** - MySQL type provider + - Maps MySQL types to CEL types + - `LoadTableSchema` uses `information_schema.columns` with `table_schema = DATABASE()` + - Accepts `*sql.DB` (caller owns connection) + +8. **`sqlite/provider.go`** - SQLite type provider + - Maps SQLite type affinity to CEL types + - `LoadTableSchema` uses `PRAGMA table_info` with table name validation + - Accepts `*sql.DB` (caller owns connection) + +9. **`duckdb/provider.go`** - DuckDB type provider + - Maps DuckDB types to CEL types, detects array types from `[]` suffix + - `LoadTableSchema` uses `information_schema.columns` + - Accepts `*sql.DB` (works with any DuckDB driver) + +10. **`bigquery/provider.go`** - BigQuery type provider + - Maps BigQuery types to CEL types + - `LoadTableSchema` uses BigQuery client API (`Table.Metadata`) + - Handles nested RECORD types recursively + - Accepts `*bigquery.Client` + dataset ID + +11. **`sqltypes/types.go`** - Custom SQL type definitions for CEL (DATE, TIME, DATETIME, INTERVAL) ### Type System Integration @@ -198,11 +219,14 @@ These validations prevent PostgreSQL syntax errors and ensure predictable behavi - Include package comments for main packages ### Testing Guidelines -- Use PostgreSQL schemas (`pg.Schema`) in tests, not BigQuery -- Use `pg.NewTypeProvider()` for schema definitions +- Use the dialect-specific schema/provider for each dialect's tests +- Use `pg.NewTypeProvider()` for PostgreSQL, `mysql.NewTypeProvider()` for MySQL, etc. 
- Include tests for nested types, arrays, and JSON fields -- Verify SQL output matches PostgreSQL syntax (single quotes, proper functions) -- Use testcontainers for integration tests with real PostgreSQL +- Verify SQL output matches the target dialect's syntax +- Use testcontainers for integration tests (PostgreSQL, MySQL, BigQuery) +- Use in-memory databases for SQLite integration tests (no Docker needed) +- DuckDB integration tests require CGO; use unit tests for type mapping validation +- Provider tests live in `{dialect}/provider_test.go` ### Performance Benchmarks @@ -334,8 +358,12 @@ benchstat bench-old.txt bench-new.txt ## Common Patterns -### Creating Type Providers +### Creating Type Providers (Pre-defined Schemas) + +All dialects support pre-defined schemas via `NewTypeProvider`: + ```go +// PostgreSQL schema := pg.NewSchema([]pg.FieldSchema{ {Name: "field_name", Type: "text", Repeated: false}, {Name: "array_field", Type: "text", Repeated: true}, @@ -343,19 +371,57 @@ schema := pg.NewSchema([]pg.FieldSchema{ {Name: "composite_field", Type: "composite", Schema: []pg.FieldSchema{...}}, }) provider := pg.NewTypeProvider(map[string]pg.Schema{"TableName": schema}) + +// MySQL (same schema types, dialect-specific type names) +schema := mysql.NewSchema([]mysql.FieldSchema{ + {Name: "name", Type: "varchar"}, + {Name: "metadata", Type: "json", IsJSON: true}, +}) +provider := mysql.NewTypeProvider(map[string]mysql.Schema{"TableName": schema}) + +// SQLite, DuckDB, BigQuery follow the same pattern with their own type names ``` ### Dynamic Schema Loading + +All dialects support runtime schema introspection from live databases via `LoadTableSchema`: + ```go +// PostgreSQL — accepts connection string, manages its own pool provider, err := pg.NewTypeProviderWithConnection(ctx, connectionString) if err != nil { return err } defer provider.Close() +err = provider.LoadTableSchema(ctx, "tableName") + +// MySQL — accepts *sql.DB, caller owns connection +db, _ := 
sql.Open("mysql", "user:pass@tcp(host:3306)/db?parseTime=true") +provider, err := mysql.NewTypeProviderWithConnection(ctx, db) +err = provider.LoadTableSchema(ctx, "tableName") + +// SQLite — accepts *sql.DB, uses PRAGMA table_info (validates table name) +db, _ := sql.Open("sqlite", "mydb.sqlite") +provider, err := sqlite.NewTypeProviderWithConnection(ctx, db) +err = provider.LoadTableSchema(ctx, "tableName") + +// DuckDB — accepts *sql.DB, works with any DuckDB driver +db, _ := sql.Open("duckdb", "mydb.duckdb") +provider, err := duckdb.NewTypeProviderWithConnection(ctx, db) +err = provider.LoadTableSchema(ctx, "tableName") +// BigQuery — accepts *bigquery.Client + dataset ID +client, _ := bigquery.NewClient(ctx, "project-id") +provider, err := bqprovider.NewTypeProviderWithClient(ctx, client, "dataset_id") err = provider.LoadTableSchema(ctx, "tableName") ``` +**Key differences per dialect:** +- **PostgreSQL**: `NewTypeProviderWithConnection(ctx, connString)` — owns its pgxpool, `Close()` releases it +- **MySQL/SQLite/DuckDB**: `NewTypeProviderWithConnection(ctx, *sql.DB)` — caller owns DB, `Close()` is no-op +- **BigQuery**: `NewTypeProviderWithClient(ctx, *bigquery.Client, datasetID)` — caller owns client, `Close()` is no-op +- **SQLite**: Table name validated via regex (`^[a-zA-Z_][a-zA-Z0-9_]*$`) since PRAGMA doesn't support parameterized queries + ### CEL Environment Setup ```go env, err := cel.NewEnv( @@ -377,7 +443,18 @@ sqlCondition, err := cel2sql.Convert(ast) ### Query Analysis and Index Recommendations -cel2sql can analyze CEL expressions and recommend database indexes to optimize performance. +cel2sql can analyze CEL expressions and recommend **dialect-specific** database indexes to optimize performance. 
+ +#### Architecture + +Index analysis uses the **IndexAdvisor** interface (`dialect/index_advisor.go`): +- **Pattern detection** stays centralized in `analysis.go` (walks the CEL AST once) +- **DDL generation** is delegated to per-dialect `IndexAdvisor` implementations +- Each built-in dialect implements `IndexAdvisor` on its `*Dialect` struct +- Use `dialect.GetIndexAdvisor(d)` to type-assert a dialect to `IndexAdvisor` +- Unsupported patterns return `nil` (silently skipped) + +**PatternTypes** detected: `PatternComparison`, `PatternJSONAccess`, `PatternRegexMatch`, `PatternArrayMembership`, `PatternArrayComprehension`, `PatternJSONArrayComprehension`. #### Using AnalyzeQuery @@ -387,16 +464,18 @@ if issues != nil && issues.Err() != nil { return issues.Err() } +// PostgreSQL (default) sql, recommendations, err := cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) + +// Or with a specific dialect +sql, recommendations, err := cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(schemas), + cel2sql.WithDialect(mysql.New())) if err != nil { return err } -// Use the generated SQL -rows, err := db.Query("SELECT * FROM people WHERE " + sql) - -// Review and apply index recommendations for _, rec := range recommendations { fmt.Printf("Column: %s, Type: %s\n", rec.Column, rec.IndexType) fmt.Printf("Reason: %s\n", rec.Reason) @@ -404,42 +483,51 @@ for _, rec := range recommendations { } ``` -#### Index Recommendation Types - -AnalyzeQuery detects patterns and recommends appropriate index types: +#### Per-Dialect Index Types -- **B-tree indexes**: Comparison operations (`==, >, <, >=, <=`) - - Best for: Equality checks, range queries, sorting - - Example: `person.age > 18` → B-tree on `person.age` +| Pattern | PostgreSQL | MySQL | SQLite | DuckDB | BigQuery | +|---------|-----------|-------|--------|--------|----------| +| Comparison | BTREE | BTREE | BTREE | ART | CLUSTERING | +| JSON access | GIN | BTREE (functional) | _(nil)_ | ART | SEARCH_INDEX | +| Regex match | GIN 
+ pg_trgm | FULLTEXT | _(nil)_ | _(nil)_ | _(nil)_ | +| Array membership | GIN | _(nil)_ | _(nil)_ | ART | _(nil)_ | +| Array comprehension | GIN | _(nil)_ | _(nil)_ | ART | _(nil)_ | +| JSON array comprehension | GIN | BTREE (functional) | _(nil)_ | ART | SEARCH_INDEX | -- **GIN indexes**: JSON/JSONB path operations, array operations - - Best for: JSON field access, array membership, containment - - Example: `person.metadata.verified == true` → GIN on `person.metadata` - - Example: `"premium" in person.tags` → GIN on `person.tags` - -- **GIN indexes with pg_trgm**: Regex pattern matching - - Best for: Text search, pattern matching, fuzzy matching - - Requires: PostgreSQL pg_trgm extension - - Example: `person.email.matches(r"@example\.com$")` → GIN on `person.email` +**Per-dialect DDL examples:** +- **PostgreSQL**: `CREATE INDEX idx_col_gin ON table_name USING GIN (col);` +- **MySQL**: `CREATE INDEX idx_col_btree ON table_name (col);` / `CREATE FULLTEXT INDEX ...` +- **SQLite**: `CREATE INDEX idx_col ON table_name (col);` +- **DuckDB**: `CREATE INDEX idx_col ON table_name (col);` (ART by default) +- **BigQuery**: `ALTER TABLE t SET OPTIONS (clustering_columns=['col']);` / `CREATE SEARCH INDEX ...` #### IndexRecommendation Structure ```go type IndexRecommendation struct { Column string // Full column name (e.g., "person.metadata") - IndexType string // "BTREE", "GIN", or "GIST" - Expression string // Complete CREATE INDEX statement + IndexType string // Dialect-specific: "BTREE", "GIN", "ART", "CLUSTERING", "SEARCH_INDEX", etc. 
+ Expression string // Complete DDL statement for the target dialect Reason string // Explanation of why this index is recommended } ``` +#### Implementation Files + +- `dialect/index_advisor.go` — `IndexAdvisor` interface, `PatternType`, `IndexPattern`, `GetIndexAdvisor()` helper +- `dialect/postgres/index_advisor.go` — PostgreSQL: BTREE, GIN, GIN+pg_trgm +- `dialect/mysql/index_advisor.go` — MySQL: BTREE, FULLTEXT +- `dialect/sqlite/index_advisor.go` — SQLite: BTREE only +- `dialect/duckdb/index_advisor.go` — DuckDB: ART +- `dialect/bigquery/index_advisor.go` — BigQuery: CLUSTERING, SEARCH_INDEX + #### When to Use - **Development**: Discover which indexes your queries need - **Performance tuning**: Identify missing indexes causing slow queries - **Production monitoring**: Analyze user-generated filter expressions -See `examples/index_analysis/` for a complete working example. +See `examples/index_analysis/` for a complete working example with all 5 dialects. ### Logging and Observability @@ -731,23 +819,23 @@ For detailed security information, see the security documentation. ## Important Notes ### Migration Context -This project was migrated from BigQuery to PostgreSQL in v2.0: -- All `cloud.google.com/go/bigquery` dependencies removed -- `bq/` package removed entirely -- PostgreSQL-specific syntax (single quotes, POSITION(), ARRAY_LENGTH(,1), etc.) 
-- Comprehensive JSON/JSONB support added -- Dynamic schema loading added +This project was originally BigQuery-only, migrated to PostgreSQL in v2.0, and expanded to multi-dialect in v3.0: +- v2.0: All `cloud.google.com/go/bigquery` dependencies removed, `bq/` package removed +- v3.0: Multi-dialect support added (PostgreSQL, MySQL, SQLite, DuckDB, BigQuery) +- Each dialect has its own type provider with `LoadTableSchema` support +- BigQuery dependency re-added for BigQuery dialect support ### Things to Avoid -- Do NOT add BigQuery dependencies back - Do NOT remove protobuf dependencies (required by CEL) - Do NOT use direct SQL string concatenation (use proper escaping) - Do NOT ignore context cancellation in database operations +- Do NOT use `PRAGMA` with user-controlled table names without validation (SQLite) +- Do NOT assume a specific dialect — use the dialect interface for dialect-specific behavior ### When Adding Features -1. Consider PostgreSQL-specific SQL syntax -2. Add comprehensive tests with realistic schemas -3. Update type mappings in `pg/provider.go` if needed +1. Consider all supported SQL dialects, not just PostgreSQL +2. Add comprehensive tests with realistic schemas for each affected dialect +3. Update type mappings in the appropriate `{dialect}/provider.go` if needed 4. Document new CEL operators/functions in README.md 5. Ensure backward compatibility 6. 
Run `make ci` before committing @@ -756,18 +844,40 @@ This project was migrated from BigQuery to PostgreSQL in v2.0: ``` cel2sql/ ├── cel2sql.go # Main conversion engine +├── analysis.go # Query analysis and index recommendations (multi-dialect) ├── comprehensions.go # CEL comprehensions support ├── json.go # JSON/JSONB handling ├── operators.go # Operator conversion ├── timestamps.go # Timestamp/duration handling ├── utils.go # Utility functions +├── schema/ # Dialect-agnostic schema types +│ └── schema.go # FieldSchema, Schema with O(1) lookup ├── pg/ # PostgreSQL type provider -│ └── provider.go +│ └── provider.go # LoadTableSchema via information_schema + pgxpool +├── mysql/ # MySQL type provider +│ └── provider.go # LoadTableSchema via information_schema + *sql.DB +├── sqlite/ # SQLite type provider +│ └── provider.go # LoadTableSchema via PRAGMA table_info + *sql.DB +├── duckdb/ # DuckDB type provider +│ └── provider.go # LoadTableSchema via information_schema + *sql.DB +├── bigquery/ # BigQuery type provider +│ └── provider.go # LoadTableSchema via BigQuery client API +├── dialect/ # Dialect interface and implementations +│ ├── dialect.go # Core Dialect interface (~40 methods) +│ ├── index_advisor.go # IndexAdvisor interface, PatternType, IndexPattern +│ ├── postgres/ # PostgreSQL dialect + IndexAdvisor (BTREE, GIN, GIN+trgm) +│ ├── mysql/ # MySQL dialect + IndexAdvisor (BTREE, FULLTEXT) +│ ├── sqlite/ # SQLite dialect + IndexAdvisor (BTREE only) +│ ├── duckdb/ # DuckDB dialect + IndexAdvisor (ART) +│ └── bigquery/ # BigQuery dialect + IndexAdvisor (CLUSTERING, SEARCH_INDEX) ├── sqltypes/ # Custom SQL types for CEL │ └── types.go +├── testcases/ # Shared test cases with per-dialect expected SQL +├── testutil/ # Multi-dialect test runner + env factories └── examples/ # Usage examples ├── basic/ ├── comprehensions/ + ├── index_analysis/ # Multi-dialect index recommendation demo └── load_table_schema/ ``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 
b4206b8..cb312b2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -137,16 +137,44 @@ import ( 1. Add the function mapping in `cel2sql.go` 2. Add comprehensive tests in `cel2sql_test.go` 3. Update the README with documentation -4. Ensure PostgreSQL compatibility +4. Ensure the function works with dialect abstraction -### PostgreSQL Focus +### Multi-Dialect Architecture -This project targets PostgreSQL. When adding features: +cel2sql supports PostgreSQL (default), MySQL, SQLite, DuckDB, and BigQuery. When adding features: -- Use PostgreSQL-specific SQL syntax -- Test with realistic PostgreSQL schemas -- Use `pgx/v5` driver patterns -- Avoid BigQuery-specific features +- Call `con.dialect.*` methods for any SQL that differs between databases +- Standard SQL (AND, OR, =, !=, etc.) stays inline in the converter +- Add expected SQL for all dialects in `testcases/*.go` +- Run `make ci` to verify all dialects pass + +### Adding a New Dialect + +To add support for a new SQL dialect: + +1. **Create the dialect package**: `dialect/<name>/dialect.go` + - Implement the `dialect.Dialect` interface (~40 methods) + - Register with `dialect.Register()` in `init()` + +2. **Create regex conversion** (if applicable): `dialect/<name>/regex.go` + - Convert RE2 patterns to the dialect's regex format + - Include ReDoS protection (pattern length, nesting limits) + +3. **Create validation**: `dialect/<name>/validation.go` + - Field name validation, reserved keywords + +4. **Create type provider**: `<name>/provider.go` + - Map native database types to CEL types + +5. **Add env factory**: `testutil/env.go` + - Add `<Name>EnvFactory()` function + - Update `DialectEnvFactory()` switch + +6. **Add test runner**: `testutil/runner_<name>_test.go` + +7. **Add expected SQL to all test case files** in `testcases/` + +8. **Update `dialect/dialect.go`** to add the dialect name constant ## Pull Request Process @@ -174,30 +202,36 @@ This project targets PostgreSQL. 
When adding features: ``` cel2sql/ -├── cel2sql.go # Main conversion engine -├── cel2sql_test.go # Main tests -├── pg/ # PostgreSQL type provider -│ ├── provider.go # Type provider implementation -│ └── provider_test.go # Type provider tests -├── sqltypes/ # Custom SQL types -│ └── types.go # CEL type definitions -├── examples/ # Usage examples -│ ├── basic/ # Basic usage example -│ │ ├── main.go -│ │ └── README.md -│ ├── load_table_schema/ # Dynamic schema loading example -│ │ ├── main.go -│ │ └── README.md -│ └── README.md # Examples overview -└── test/ # Test utilities - └── testdata.go # Test schemas +├── cel2sql.go # Main conversion engine (uses dialect interface) +├── cel2sql_test.go # Main tests +├── dialect/ # Dialect interface + implementations +│ ├── dialect.go # Interface definition + Name type +│ ├── registry.go # Name→Dialect lookup +│ ├── postgres/ # PostgreSQL dialect +│ ├── mysql/ # MySQL dialect +│ ├── sqlite/ # SQLite dialect +│ ├── duckdb/ # DuckDB dialect +│ └── bigquery/ # BigQuery dialect +├── pg/ # PostgreSQL type provider +├── mysql/ # MySQL type provider +├── sqlite/ # SQLite type provider +├── duckdb/ # DuckDB type provider +├── bigquery/ # BigQuery type provider +├── schema/ # Dialect-agnostic schema types +├── sqltypes/ # Custom SQL types for CEL +├── testcases/ # Shared test cases with per-dialect expected SQL +├── testutil/ # Test runner + env factories +└── examples/ # Usage examples ``` ### Key Components -- **cel2sql.go**: Core conversion logic from CEL AST to SQL -- **pg/provider.go**: PostgreSQL type system integration +- **cel2sql.go**: Core conversion logic from CEL AST to SQL (calls dialect methods) +- **dialect/dialect.go**: Dialect interface defining all SQL generation points +- **dialect/*/dialect.go**: Per-dialect SQL generation implementations +- **pg/provider.go**, **mysql/provider.go**, etc.: Type system integration per dialect - **sqltypes/types.go**: Custom SQL type definitions for CEL +- **testcases/*.go**: Shared 
test cases with expected SQL for all dialects ## Debugging diff --git a/README.md b/README.md index 93ae0b4..66671d3 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,17 @@ # cel2sql -> Convert [CEL (Common Expression Language)](https://cel.dev/) expressions to PostgreSQL SQL +> Convert [CEL (Common Expression Language)](https://cel.dev/) expressions to SQL for PostgreSQL, MySQL, SQLite, DuckDB, and BigQuery [![Go Version](https://img.shields.io/badge/Go-1.24%2B-blue)](https://golang.org) -[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-17-blue)](https://www.postgresql.org) +[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-17-336791)](https://www.postgresql.org) +[![MySQL](https://img.shields.io/badge/MySQL-8.0-4479A1)](https://www.mysql.com) +[![SQLite](https://img.shields.io/badge/SQLite-3-003B57)](https://www.sqlite.org) +[![DuckDB](https://img.shields.io/badge/DuckDB-1.x-FFF000)](https://duckdb.org) +[![BigQuery](https://img.shields.io/badge/BigQuery-GCP-4285F4)](https://cloud.google.com/bigquery) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) [![Benchmarks](https://img.shields.io/badge/benchmarks-performance%20tracking-green)](https://spandigital.github.io/cel2sql/dev/bench/) -**cel2sql** makes it easy to build dynamic SQL queries using CEL expressions. Write type-safe, expressive filters in CEL and automatically convert them to PostgreSQL-compatible SQL. +**cel2sql** makes it easy to build dynamic SQL queries using CEL expressions. Write type-safe, expressive filters in CEL and automatically convert them to SQL for your database of choice. ## Quick Start @@ -61,10 +65,10 @@ func main() { ## Why cel2sql? 
+✅ **Multi-Dialect**: PostgreSQL, MySQL, SQLite, DuckDB, and BigQuery from a single API ✅ **Type-Safe**: Catch errors at compile time, not runtime -✅ **PostgreSQL 17**: Fully compatible with the latest PostgreSQL ✅ **Rich Features**: JSON/JSONB, arrays, regex, timestamps, and more -✅ **Well-Tested**: 100+ tests including integration tests with real PostgreSQL +✅ **Well-Tested**: 100+ tests including integration tests with real databases ✅ **Easy to Use**: Simple API, comprehensive documentation ✅ **Secure by Default**: Built-in protections against SQL injection and ReDoS attacks ✅ **Performance Tracked**: [Continuous benchmark monitoring](https://spandigital.github.io/cel2sql/dev/bench/) to prevent regressions @@ -117,29 +121,92 @@ sql, err := cel2sql.Convert(ast, ``` **Available Options:** +- `WithDialect(dialect.Dialect)` - Select target SQL dialect (default: PostgreSQL) - `WithSchemas(map[string]pg.Schema)` - Provide table schemas for JSON detection - `WithContext(context.Context)` - Enable cancellation and timeouts - `WithLogger(*slog.Logger)` - Enable structured logging - `WithMaxDepth(int)` - Set custom recursion depth limit (default: 100) +## Multi-Dialect Support + +cel2sql supports 5 SQL dialects. 
PostgreSQL is the default; select other dialects with `WithDialect()`: + +```go +import ( + "github.com/spandigital/cel2sql/v3" + "github.com/spandigital/cel2sql/v3/dialect/mysql" + "github.com/spandigital/cel2sql/v3/dialect/sqlite" + "github.com/spandigital/cel2sql/v3/dialect/duckdb" + "github.com/spandigital/cel2sql/v3/dialect/bigquery" +) + +// PostgreSQL (default - no option needed) +sql, err := cel2sql.Convert(ast) + +// MySQL +sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(mysql.New())) + +// SQLite +sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(sqlite.New())) + +// DuckDB +sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(duckdb.New())) + +// BigQuery +sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(bigquery.New())) +``` + +### Dialect Comparison + +| Feature | PostgreSQL | MySQL | SQLite | DuckDB | BigQuery | +|---------|-----------|-------|--------|--------|----------| +| String concat | `\|\|` | `CONCAT()` | `\|\|` | `\|\|` | `\|\|` | +| Regex | `~ / ~*` | `REGEXP` | unsupported | `~ / ~*` | `REGEXP_CONTAINS()` | +| JSON access | `->>'f'` | `->>'$.f'` | `json_extract()` | `->>'f'` | `JSON_VALUE()` | +| Arrays | `ARRAY[...]` | JSON arrays | JSON arrays | `[...]` | `[...]` | +| UNNEST | `UNNEST(x)` | `JSON_TABLE(...)` | `json_each(x)` | `UNNEST(x)` | `UNNEST(x)` | +| Param placeholder | `$1, $2` | `?, ?` | `?, ?` | `$1, $2` | `@p1, @p2` | +| Timestamp cast | `TIMESTAMP WITH TIME ZONE` | `DATETIME` | `datetime()` | `TIMESTAMPTZ` | `TIMESTAMP` | +| Contains | `POSITION()` | `LOCATE()` | `INSTR()` | `CONTAINS()` | `STRPOS()` | +| Index analysis | BTREE, GIN, GIN+trgm | BTREE, FULLTEXT | BTREE | ART | CLUSTERING, SEARCH_INDEX | + +### Per-Dialect Type Providers + +Each dialect has its own type provider for mapping database types to CEL types. 
All providers support both pre-defined schemas (`NewTypeProvider`) and dynamic schema loading (`LoadTableSchema`): + +```go +import "github.com/spandigital/cel2sql/v3/pg" // PostgreSQL (pgxpool connection string) +import "github.com/spandigital/cel2sql/v3/mysql" // MySQL (*sql.DB) +import "github.com/spandigital/cel2sql/v3/sqlite" // SQLite (*sql.DB) +import "github.com/spandigital/cel2sql/v3/duckdb" // DuckDB (*sql.DB) +import "github.com/spandigital/cel2sql/v3/bigquery" // BigQuery (*bigquery.Client) +``` + ## Query Analysis and Index Recommendations -cel2sql can analyze your CEL queries and recommend database indexes to optimize performance. The `AnalyzeQuery()` function returns both the converted SQL and actionable index recommendations. +cel2sql can analyze your CEL queries and recommend database indexes to optimize performance. The `AnalyzeQuery()` function returns both the converted SQL and **dialect-specific** index recommendations. ### How It Works -`AnalyzeQuery()` examines your CEL expression and detects patterns that would benefit from specific PostgreSQL index types: +`AnalyzeQuery()` examines your CEL expression and detects patterns that would benefit from indexing, then generates dialect-appropriate DDL: -- **JSON/JSONB path operations** (`->>, ?`) → GIN indexes -- **Array operations** (comprehensions, `IN` clauses) → GIN indexes -- **Regex matching** (`matches()`) → GIN indexes with `pg_trgm` extension -- **Comparison operations** (`==, >, <, >=, <=`) → B-tree indexes +- **Comparison operations** (`==, >, <, >=, <=`) → B-tree (PG/MySQL/SQLite), ART (DuckDB), Clustering (BigQuery) +- **JSON/JSONB path operations** (`->>, ?`) → GIN (PG), functional index (MySQL), Search Index (BigQuery), ART (DuckDB) +- **Regex matching** (`matches()`) → GIN with pg_trgm (PG), FULLTEXT (MySQL) +- **Array operations** (comprehensions, `IN` clauses) → GIN (PG), ART (DuckDB) ### Usage ```go +// PostgreSQL (default dialect) sql, recommendations, err := 
cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) + +// Or specify a dialect +sql, recommendations, err := cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(schemas), + cel2sql.WithDialect(mysql.New())) + if err != nil { log.Fatal(err) } @@ -153,38 +220,48 @@ for _, rec := range recommendations { fmt.Printf("Type: %s\n", rec.IndexType) fmt.Printf("Reason: %s\n", rec.Reason) fmt.Printf("Execute: %s\n\n", rec.Expression) - - // Apply the recommendation - // _, err := db.Exec(rec.Expression) } ``` +### Per-Dialect Index Types + +| Pattern | PostgreSQL | MySQL | SQLite | DuckDB | BigQuery | +|---------|-----------|-------|--------|--------|----------| +| Comparison | BTREE | BTREE | BTREE | ART | CLUSTERING | +| JSON access | GIN | BTREE (functional) | _(skip)_ | ART | SEARCH_INDEX | +| Regex | GIN + pg_trgm | FULLTEXT | _(skip)_ | _(skip)_ | _(skip)_ | +| Array membership | GIN | _(skip)_ | _(skip)_ | ART | _(skip)_ | +| Comprehension | GIN | _(skip)_ | _(skip)_ | ART | _(skip)_ | + +Unsupported patterns are silently skipped (no recommendation emitted). + ### Example ```go -// Query with multiple index-worthy patterns -celExpr := `person.age > 18 && - person.email.matches(r"@example\.com$") && - person.metadata.verified == true` - +celExpr := `person.age > 18 && person.metadata.verified == true` ast, _ := env.Compile(celExpr) -sql, recs, _ := cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) -// Generated SQL: -// person.age > 18 AND person.email ~ '@example\.com$' -// AND person.metadata->>'verified' = 'true' +// PostgreSQL recommendations +sql, recs, _ := cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) +// Recommendations: +// 1. CREATE INDEX idx_person_age_btree ON table_name (person.age); +// 2. CREATE INDEX idx_person_metadata_gin ON table_name USING GIN (person.metadata); +// MySQL recommendations +sql, recs, _ = cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(schemas), + cel2sql.WithDialect(mysql.New())) // Recommendations: // 1. 
CREATE INDEX idx_person_age_btree ON table_name (person.age); -// Reason: Comparison operations benefit from B-tree for range queries -// -// 2. CREATE INDEX idx_person_email_gin_trgm ON table_name -// USING GIN (person.email gin_trgm_ops); -// Reason: Regex matching benefits from GIN index with pg_trgm -// -// 3. CREATE INDEX idx_person_metadata_gin ON table_name -// USING GIN (person.metadata); -// Reason: JSON path operations benefit from GIN index +// 2. CREATE INDEX idx_person_metadata_json ON table_name ((CAST(person.metadata->>'$.path' AS CHAR(255)))); + +// BigQuery recommendations +sql, recs, _ = cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(schemas), + cel2sql.WithDialect(bigquery.New())) +// Recommendations: +// 1. ALTER TABLE table_name SET OPTIONS (clustering_columns=['person.age']); +// 2. CREATE SEARCH INDEX idx_person_metadata ON table_name (person.metadata); ``` ### When to Use @@ -193,7 +270,7 @@ sql, recs, _ := cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) - **Performance tuning**: Identify missing indexes causing slow queries - **Production monitoring**: Analyze user-generated filter expressions -See `examples/index_analysis/` for a complete working example. +See `examples/index_analysis/` for a complete working example with all 5 dialects. 
## Parameterized Queries @@ -393,42 +470,159 @@ See [Regex Matching documentation](docs/regex-matching.md) for complete details, ## Type Mapping -| CEL Type | PostgreSQL Type | -|----------|-----------------| -| `int` | `bigint` | -| `double` | `double precision` | -| `bool` | `boolean` | -| `string` | `text` | -| `bytes` | `bytea` | -| `list` | `ARRAY` | -| `timestamp` | `timestamp with time zone` | -| `duration` | `INTERVAL` | +| CEL Type | PostgreSQL | MySQL | SQLite | DuckDB | BigQuery | +|----------|-----------|-------|--------|--------|----------| +| `int` | `bigint` | `SIGNED` | `INTEGER` | `BIGINT` | `INT64` | +| `double` | `double precision` | `DECIMAL` | `REAL` | `DOUBLE` | `FLOAT64` | +| `bool` | `boolean` | `UNSIGNED` | `INTEGER` | `BOOLEAN` | `BOOL` | +| `string` | `text` | `CHAR` | `TEXT` | `VARCHAR` | `STRING` | +| `bytes` | `bytea` | `BINARY` | `BLOB` | `BLOB` | `BYTES` | +| `list` | `ARRAY` | JSON array | JSON array | `LIST` | `ARRAY` | +| `timestamp` | `timestamptz` | `DATETIME` | `datetime()` | `TIMESTAMPTZ` | `TIMESTAMP` | +| `duration` | `INTERVAL` | `INTERVAL` | string modifier | `INTERVAL` | `INTERVAL` | ## Dynamic Schema Loading -Load table schemas directly from your PostgreSQL database: +Load table schemas directly from your database at runtime instead of defining them manually. Each dialect provider supports introspecting table schemas from a live database connection. 
+ +### PostgreSQL ```go -// Connect to database and load schema +import "github.com/spandigital/cel2sql/v3/pg" + +// PostgreSQL accepts a connection string and manages its own connection pool provider, _ := pg.NewTypeProviderWithConnection(ctx, "postgres://user:pass@localhost/db") defer provider.Close() -// Load table schema dynamically provider.LoadTableSchema(ctx, "users") -// Use with CEL env, _ := cel.NewEnv( cel.CustomTypeProvider(provider), cel.Variable("user", cel.ObjectType("users")), ) ``` +### MySQL + +```go +import ( + "database/sql" + _ "github.com/go-sql-driver/mysql" + "github.com/spandigital/cel2sql/v3/mysql" +) + +// MySQL accepts a *sql.DB — you own the connection +db, _ := sql.Open("mysql", "user:pass@tcp(localhost:3306)/mydb?parseTime=true") +defer db.Close() + +provider, _ := mysql.NewTypeProviderWithConnection(ctx, db) +provider.LoadTableSchema(ctx, "users") + +env, _ := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("user", cel.ObjectType("users")), +) + +sql, _ := cel2sql.Convert(ast, cel2sql.WithDialect(mysqlDialect.New()), + cel2sql.WithSchemas(provider.GetSchemas())) +``` + +### SQLite + +```go +import ( + "database/sql" + _ "modernc.org/sqlite" + "github.com/spandigital/cel2sql/v3/sqlite" +) + +db, _ := sql.Open("sqlite", "mydb.sqlite") +defer db.Close() + +provider, _ := sqlite.NewTypeProviderWithConnection(ctx, db) +provider.LoadTableSchema(ctx, "users") + +env, _ := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("user", cel.ObjectType("users")), +) + +sql, _ := cel2sql.Convert(ast, cel2sql.WithDialect(sqliteDialect.New()), + cel2sql.WithSchemas(provider.GetSchemas())) +``` + +### DuckDB + +```go +import ( + "database/sql" + "github.com/spandigital/cel2sql/v3/duckdb" +) + +// DuckDB accepts *sql.DB — works with any DuckDB driver (requires CGO) +db, _ := sql.Open("duckdb", "mydb.duckdb") +defer db.Close() + +provider, _ := duckdb.NewTypeProviderWithConnection(ctx, db) +provider.LoadTableSchema(ctx, 
"users") + +env, _ := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("user", cel.ObjectType("users")), +) + +sql, _ := cel2sql.Convert(ast, cel2sql.WithDialect(duckdbDialect.New()), + cel2sql.WithSchemas(provider.GetSchemas())) +``` + +### BigQuery + +```go +import ( + "cloud.google.com/go/bigquery" + bqprovider "github.com/spandigital/cel2sql/v3/bigquery" +) + +// BigQuery uses the BigQuery client API (not database/sql) +client, _ := bigquery.NewClient(ctx, "my-project") +defer client.Close() + +provider, _ := bqprovider.NewTypeProviderWithClient(ctx, client, "my_dataset") +provider.LoadTableSchema(ctx, "users") + +env, _ := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("user", cel.ObjectType("users")), +) + +sql, _ := cel2sql.Convert(ast, cel2sql.WithDialect(bigqueryDialect.New()), + cel2sql.WithSchemas(provider.GetSchemas())) +``` + +### Notes + +- **PostgreSQL** manages its own connection pool via `pgxpool` — call `provider.Close()` when done. +- **MySQL, SQLite, DuckDB** accept a `*sql.DB` you provide — you own the connection lifecycle. `Close()` is a no-op. +- **BigQuery** accepts a `*bigquery.Client` + dataset ID — you own the client lifecycle. `Close()` is a no-op. +- All providers also support pre-defined schemas via `NewTypeProvider(schemas)` if you don't need runtime introspection. + See [Getting Started Guide](docs/getting-started.md) for more details. ## Requirements - Go 1.24 or higher -- PostgreSQL 17 (also compatible with PostgreSQL 15+) + +### CGO Requirement (DuckDB only) + +The DuckDB dialect's `LoadTableSchema` requires a DuckDB Go driver (e.g., `github.com/marcboeker/go-duckdb`) which depends on **CGO** and a C/C++ compiler. 
This means: + +- You must have `CGO_ENABLED=1` (the Go default on most platforms) +- A C/C++ compiler must be installed (GCC, Clang, or MSVC) +- Cross-compilation requires a C cross-compiler for the target platform + +**All other dialects (PostgreSQL, MySQL, SQLite, BigQuery) use pure Go drivers and do not require CGO.** + +If you only use DuckDB with pre-defined schemas via `duckdb.NewTypeProvider()` (no live database connection), CGO is **not** required. ## Contributing diff --git a/analysis.go b/analysis.go index e1bbd46..3dbbfeb 100644 --- a/analysis.go +++ b/analysis.go @@ -4,16 +4,18 @@ package cel2sql import ( "fmt" "log/slog" - "strings" "time" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/dialect/postgres" ) -// Index type constants for recommendations +// Index type constants for recommendations (kept for backward compatibility). 
const ( // IndexTypeBTree represents a B-tree index for efficient range queries and equality checks IndexTypeBTree = "BTREE" @@ -29,10 +31,10 @@ type IndexRecommendation struct { // Column is the database column that should be indexed Column string - // IndexType specifies the PostgreSQL index type (e.g., "BTREE", "GIN", "GIST") + // IndexType specifies the index type (e.g., "BTREE", "GIN", "ART", "CLUSTERING") IndexType string - // Expression is the complete CREATE INDEX statement that can be executed directly + // Expression is the complete DDL statement that can be executed directly Expression string // Reason explains why this index is recommended and what query patterns it optimizes @@ -44,21 +46,25 @@ type analysisConverter struct { *converter recommendations map[string]*IndexRecommendation // Key: column name, Value: recommendation visitedColumns map[string]bool // Track which columns have been accessed + advisor dialect.IndexAdvisor // Dialect-specific index advisor } -// AnalyzeQuery converts a CEL AST to PostgreSQL SQL and provides index recommendations. +// AnalyzeQuery converts a CEL AST to SQL and provides dialect-specific index recommendations. // It analyzes the query patterns to suggest indexes that would optimize performance. // // The function detects patterns that benefit from specific index types: -// - JSON/JSONB path operations (->>, ?) 
→ GIN indexes -// - Array operations (UNNEST, comprehensions) → GIN indexes -// - Regex matching (matches()) → GIN indexes with pg_trgm extension -// - Frequently accessed fields in comparisons → B-tree indexes +// - JSON/JSONB path operations → GIN indexes (PostgreSQL), functional indexes (MySQL), search indexes (BigQuery) +// - Array operations → GIN indexes (PostgreSQL), ART indexes (DuckDB) +// - Regex matching → GIN indexes with pg_trgm (PostgreSQL), FULLTEXT indexes (MySQL) +// - Comparison operations → B-tree indexes (PostgreSQL/MySQL/SQLite), ART (DuckDB), clustering (BigQuery) +// +// Use WithDialect() to get dialect-specific index recommendations. Defaults to PostgreSQL. // // Example: // // sql, recommendations, err := cel2sql.AnalyzeQuery(ast, -// cel2sql.WithSchemas(schemas)) +// cel2sql.WithSchemas(schemas), +// cel2sql.WithDialect(mysql.New())) // if err != nil { // return err // } @@ -84,6 +90,19 @@ func AnalyzeQuery(ast *cel.Ast, opts ...ConvertOption) (string, []IndexRecommend opt(options) } + // Default to PostgreSQL dialect if none specified + d := options.dialect + if d == nil { + d = postgres.New() + } + + // Get the IndexAdvisor for the dialect (all built-in dialects implement it) + advisor, hasAdvisor := dialect.GetIndexAdvisor(d) + if !hasAdvisor { + // Fallback: use PostgreSQL advisor for backward compatibility + advisor = postgres.New() + } + // Convert AST to CheckedExpr checkedExpr, err := cel.AstToCheckedExpr(ast) if err != nil { @@ -104,6 +123,7 @@ func AnalyzeQuery(ast *cel.Ast, opts ...ConvertOption) (string, []IndexRecommend converter: baseConverter, recommendations: make(map[string]*IndexRecommendation), visitedColumns: make(map[string]bool), + advisor: advisor, } // Analyze the expression tree to collect index patterns @@ -121,6 +141,7 @@ func AnalyzeQuery(ast *cel.Ast, opts ...ConvertOption) (string, []IndexRecommend if options.logger != nil { options.logger.Debug("query analysis completed", "sql", sql, + "dialect", d.Name(), 
"recommendation_count", len(analyzer.recommendations), "duration", duration) } @@ -225,24 +246,18 @@ func (a *analysisConverter) analyzeCall(expr *exprpb.Expr) error { switch fun { case overloads.Matches: - // Regex matching benefits from GIN index with pg_trgm extension - if err := a.recommendRegexIndex(expr); err != nil { - return err - } + // Regex matching benefits from dialect-specific indexes + a.recommendRegexIndex(expr) case operators.Equals, operators.NotEquals, operators.Greater, operators.GreaterEquals, operators.Less, operators.LessEquals: - // Comparison operations benefit from B-tree indexes - if err := a.recommendComparisonIndex(expr); err != nil { - return err - } + // Comparison operations benefit from indexes + a.recommendComparisonIndex(expr) case operators.In: - // IN operations on arrays benefit from GIN indexes - if err := a.recommendArrayIndex(expr); err != nil { - return err - } + // IN operations on arrays benefit from indexes + a.recommendArrayIndex(expr) } return nil @@ -259,26 +274,13 @@ func (a *analysisConverter) analyzeComprehension(expr *exprpb.Expr) error { // Check if this is a JSON array comprehension if a.isJSONArrayField(iterRange) { - // Extract the column name from the iter range if column := a.extractColumnName(iterRange); column != "" { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN (%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("JSONB array comprehension on '%s' benefits from GIN index for efficient array element access", column), - }) + a.recommendForPattern(column, dialect.PatternJSONArrayComprehension) } } else { // Regular array comprehension if column := a.extractColumnName(iterRange); column != "" { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN 
(%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("Array comprehension on '%s' benefits from GIN index for efficient array operations", column), - }) + a.recommendForPattern(column, dialect.PatternArrayComprehension) } } @@ -299,13 +301,7 @@ func (a *analysisConverter) analyzeSelect(expr *exprpb.Expr) error { // Check if the parent field is JSON if a.isFieldJSON(tableName, operandField) { column := tableName + "." + operandField - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN (%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("JSON path operations on '%s' benefit from GIN index for efficient nested field access", column), - }) + a.recommendForPattern(column, dialect.PatternJSONAccess) } } } @@ -315,15 +311,9 @@ func (a *analysisConverter) analyzeSelect(expr *exprpb.Expr) error { tableName := identExpr.GetName() if a.isFieldJSON(tableName, fieldName) { column := tableName + "." + fieldName - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN (%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("JSON field '%s' benefits from GIN index for efficient access", column), - }) + a.recommendForPattern(column, dialect.PatternJSONAccess) } - // Track column access for potential B-tree indexes + // Track column access for potential indexes fullColumn := tableName + "." 
+ fieldName a.visitedColumns[fullColumn] = true } @@ -331,8 +321,8 @@ func (a *analysisConverter) analyzeSelect(expr *exprpb.Expr) error { return nil } -// recommendRegexIndex recommends a GIN index with pg_trgm for regex operations -func (a *analysisConverter) recommendRegexIndex(expr *exprpb.Expr) error { +// recommendRegexIndex recommends an index for regex operations +func (a *analysisConverter) recommendRegexIndex(expr *exprpb.Expr) { c := expr.GetCallExpr() target := c.GetTarget() @@ -342,54 +332,38 @@ func (a *analysisConverter) recommendRegexIndex(expr *exprpb.Expr) error { if target != nil { if column := a.extractColumnName(target); column != "" { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin_trgm ON table_name USING GIN (%s gin_trgm_ops);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("Regex matching on '%s' benefits from GIN index with pg_trgm extension for pattern matching", column), - }) + a.recommendForPattern(column, dialect.PatternRegexMatch) } } - - return nil } -// recommendComparisonIndex recommends a B-tree index for comparison operations -func (a *analysisConverter) recommendComparisonIndex(expr *exprpb.Expr) error { +// recommendComparisonIndex recommends an index for comparison operations +func (a *analysisConverter) recommendComparisonIndex(expr *exprpb.Expr) { c := expr.GetCallExpr() args := c.GetArgs() if len(args) < 2 { - return nil + return } lhs := args[0] // Extract column from left-hand side if column := a.extractColumnName(lhs); column != "" { - // Check if this is a JSON field (skip B-tree recommendation for JSON) + // Check if this is a JSON field (skip comparison recommendation for JSON) if !a.isJSONField(lhs) { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeBTree, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_btree ON table_name (%s);", - sanitizeIndexName(column), 
column), - Reason: fmt.Sprintf("Comparison operations on '%s' benefit from B-tree index for efficient range queries and equality checks", column), - }) + a.recommendForPattern(column, dialect.PatternComparison) } } - - return nil } -// recommendArrayIndex recommends a GIN index for array containment operations -func (a *analysisConverter) recommendArrayIndex(expr *exprpb.Expr) error { +// recommendArrayIndex recommends an index for array containment operations +func (a *analysisConverter) recommendArrayIndex(expr *exprpb.Expr) { c := expr.GetCallExpr() args := c.GetArgs() if len(args) < 2 { - return nil + return } rhs := args[1] @@ -397,17 +371,26 @@ func (a *analysisConverter) recommendArrayIndex(expr *exprpb.Expr) error { // Check if the right-hand side is an array field if a.isFieldArray(a.extractTableName(rhs), a.extractFieldName(rhs)) { if column := a.extractColumnName(rhs); column != "" { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN (%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("Array membership tests on '%s' benefit from GIN index for efficient element lookups", column), - }) + a.recommendForPattern(column, dialect.PatternArrayMembership) } } +} - return nil +// recommendForPattern asks the dialect's IndexAdvisor for a recommendation and stores it. 
+func (a *analysisConverter) recommendForPattern(column string, pattern dialect.PatternType) { + rec := a.advisor.RecommendIndex(dialect.IndexPattern{ + Column: column, + Pattern: pattern, + }) + if rec == nil { + return + } + a.addRecommendation(column, &IndexRecommendation{ + Column: rec.Column, + IndexType: rec.IndexType, + Expression: rec.Expression, + Reason: rec.Reason, + }) } // extractColumnName extracts the full column name (table.column) from an expression @@ -458,34 +441,25 @@ func (a *analysisConverter) isJSONField(expr *exprpb.Expr) bool { return false } -// addRecommendation adds or updates an index recommendation +// addRecommendation adds or updates an index recommendation. +// When a more specialized recommendation exists for a column, it takes priority. func (a *analysisConverter) addRecommendation(column string, rec *IndexRecommendation) { // Only add if we don't already have a recommendation for this column - // or if the new recommendation is more specific (e.g., GIN over BTREE) + // or if the new recommendation is more specific existing, exists := a.recommendations[column] if !exists { a.recommendations[column] = rec return } - // GIN indexes are more versatile than BTREE for JSON/array operations - // If we already have a BTREE recommendation and we're suggesting GIN, upgrade it - if existing.IndexType == IndexTypeBTree && rec.IndexType == IndexTypeGIN { + // More specialized index types take priority over basic B-tree/comparison indexes + if isBasicIndexType(existing.IndexType) && !isBasicIndexType(rec.IndexType) { a.recommendations[column] = rec } } -// sanitizeIndexName creates a safe index name from a column name -func sanitizeIndexName(column string) string { - // Replace dots and special characters with underscores - sanitized := strings.ReplaceAll(column, ".", "_") - sanitized = strings.ReplaceAll(sanitized, " ", "_") - sanitized = strings.ReplaceAll(sanitized, "-", "_") - - // PostgreSQL index names are limited to 63 characters - if 
len(sanitized) > 50 { - sanitized = sanitized[:50] - } - - return sanitized +// isBasicIndexType returns true if the index type is a basic comparison index +// that should be upgraded when a more specialized recommendation is available. +func isBasicIndexType(indexType string) bool { + return indexType == IndexTypeBTree || indexType == "ART" || indexType == "CLUSTERING" } diff --git a/analysis_test.go b/analysis_test.go index 212ae7c..5081f47 100644 --- a/analysis_test.go +++ b/analysis_test.go @@ -6,9 +6,21 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/spandigital/cel2sql/v3/dialect" + dialectbq "github.com/spandigital/cel2sql/v3/dialect/bigquery" + dialectduckdb "github.com/spandigital/cel2sql/v3/dialect/duckdb" + dialectmysql "github.com/spandigital/cel2sql/v3/dialect/mysql" + dialectpg "github.com/spandigital/cel2sql/v3/dialect/postgres" + dialectsqlite "github.com/spandigital/cel2sql/v3/dialect/sqlite" "github.com/spandigital/cel2sql/v3/pg" ) +// Test column name constants to avoid repetition. 
+const ( + colPersonEmail = "person.email" + colPersonMetadata = "person.metadata" +) + func TestAnalyzeQuery_JSONPathOperations(t *testing.T) { schema := pg.NewSchema([]pg.FieldSchema{ {Name: "id", Type: "text"}, @@ -34,14 +46,14 @@ func TestAnalyzeQuery_JSONPathOperations(t *testing.T) { { name: "simple JSON path access", expression: `person.metadata.name == "John"`, - expectedColumn: "person.metadata", + expectedColumn: colPersonMetadata, expectedType: "GIN", expectReason: "JSON path operations", }, { name: "nested JSON path access", expression: `person.metadata.profile.age > 18`, - expectedColumn: "person.metadata", + expectedColumn: colPersonMetadata, expectedType: "GIN", expectReason: "JSON path operations", }, @@ -120,7 +132,7 @@ func TestAnalyzeQuery_RegexOperations(t *testing.T) { // Check that we got a GIN index recommendation with pg_trgm found := false for _, rec := range recommendations { - if rec.Column == "person.email" && rec.IndexType == IndexTypeGIN { + if rec.Column == colPersonEmail && rec.IndexType == IndexTypeGIN { found = true if !strings.Contains(rec.Reason, "Regex matching") { t.Errorf("expected reason to mention regex matching, got %q", rec.Reason) @@ -370,9 +382,9 @@ func TestAnalyzeQuery_MultipleRecommendations(t *testing.T) { switch rec.Column { case "person.age": foundAge = rec.IndexType == IndexTypeBTree - case "person.email": + case colPersonEmail: foundEmail = rec.IndexType == IndexTypeGIN - case "person.metadata": + case colPersonMetadata: foundMetadata = rec.IndexType == IndexTypeGIN } } @@ -494,10 +506,330 @@ func TestAnalyzeQuery_IndexRecommendationPriority(t *testing.T) { // We should get a GIN recommendation for metadata, not BTREE for _, rec := range recommendations { - if rec.Column == "person.metadata" { + if rec.Column == colPersonMetadata { if rec.IndexType != IndexTypeGIN { t.Errorf("expected GIN index for JSON field, got %s", rec.IndexType) } } } } + +func TestAnalyzeQuery_WithDialect(t *testing.T) { + // Test that each 
dialect produces its own appropriate index types and DDL + schema := pg.NewSchema([]pg.FieldSchema{ + {Name: "id", Type: "bigint"}, + {Name: "age", Type: "integer"}, + {Name: "email", Type: "text"}, + {Name: "tags", Type: "text", Repeated: true}, + {Name: "metadata", Type: "jsonb", IsJSON: true, IsJSONB: true}, + }) + provider := pg.NewTypeProvider(map[string]pg.Schema{"person": schema}) + + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("person", cel.ObjectType("person")), + ) + if err != nil { + t.Fatalf("failed to create CEL environment: %v", err) + } + + type dialectTestCase struct { + name string + dialect dialect.Dialect + // Per-pattern expected results + comparisonType string // Expected IndexType for comparisons + comparisonContain string // Substring expected in Expression + jsonType string // Expected IndexType for JSON access + jsonContain string // Substring expected in Expression + } + + dialects := []dialectTestCase{ + { + name: "PostgreSQL", + dialect: dialectpg.New(), + comparisonType: "BTREE", + comparisonContain: "CREATE INDEX", + jsonType: "GIN", + jsonContain: "USING GIN", + }, + { + name: "MySQL", + dialect: dialectmysql.New(), + comparisonType: "BTREE", + comparisonContain: "CREATE INDEX", + jsonType: "BTREE", + jsonContain: "CAST", + }, + { + name: "SQLite", + dialect: dialectsqlite.New(), + comparisonType: "BTREE", + comparisonContain: "CREATE INDEX", + jsonType: "", // SQLite doesn't support JSON indexes + jsonContain: "", + }, + { + name: "DuckDB", + dialect: dialectduckdb.New(), + comparisonType: "ART", + comparisonContain: "CREATE INDEX", + jsonType: "ART", + jsonContain: "CREATE INDEX", + }, + { + name: "BigQuery", + dialect: dialectbq.New(), + comparisonType: "CLUSTERING", + comparisonContain: "clustering_columns", + jsonType: "SEARCH_INDEX", + jsonContain: "SEARCH INDEX", + }, + } + + for _, dt := range dialects { + t.Run(dt.name+"_comparison", func(t *testing.T) { + ast, issues := env.Compile(`person.age 
> 18`) + if issues != nil && issues.Err() != nil { + t.Fatalf("failed to compile expression: %v", issues.Err()) + } + + _, recommendations, err := AnalyzeQuery(ast, + WithSchemas(provider.GetSchemas()), + WithDialect(dt.dialect)) + if err != nil { + t.Fatalf("AnalyzeQuery failed: %v", err) + } + + found := false + for _, rec := range recommendations { + if rec.Column == "person.age" { + found = true + if rec.IndexType != dt.comparisonType { + t.Errorf("expected index type %q, got %q", dt.comparisonType, rec.IndexType) + } + if !strings.Contains(rec.Expression, dt.comparisonContain) { + t.Errorf("expected expression to contain %q, got %q", dt.comparisonContain, rec.Expression) + } + } + } + if !found { + t.Errorf("expected recommendation for person.age, got: %+v", recommendations) + } + }) + + t.Run(dt.name+"_json", func(t *testing.T) { + ast, issues := env.Compile(`person.metadata.verified == true`) + if issues != nil && issues.Err() != nil { + t.Fatalf("failed to compile expression: %v", issues.Err()) + } + + _, recommendations, err := AnalyzeQuery(ast, + WithSchemas(provider.GetSchemas()), + WithDialect(dt.dialect)) + if err != nil { + t.Fatalf("AnalyzeQuery failed: %v", err) + } + + if dt.jsonType == "" { + // This dialect doesn't recommend JSON indexes; verify none present for metadata + for _, rec := range recommendations { + if rec.Column == colPersonMetadata { + t.Errorf("expected no recommendation for JSON on %s, got: %+v", dt.name, rec) + } + } + return + } + + found := false + for _, rec := range recommendations { + if rec.Column == colPersonMetadata { + found = true + if rec.IndexType != dt.jsonType { + t.Errorf("expected index type %q, got %q", dt.jsonType, rec.IndexType) + } + if !strings.Contains(rec.Expression, dt.jsonContain) { + t.Errorf("expected expression to contain %q, got %q", dt.jsonContain, rec.Expression) + } + } + } + if !found { + t.Errorf("expected JSON recommendation for person.metadata on %s, got: %+v", dt.name, recommendations) + } + 
})
+	}
+}
+
+func TestAnalyzeQuery_UnsupportedPatternReturnsNil(t *testing.T) {
+	// SQLite should not produce recommendations for regex patterns
+	schema := pg.NewSchema([]pg.FieldSchema{
+		{Name: "id", Type: "text"},
+		{Name: "email", Type: "text"},
+	})
+	provider := pg.NewTypeProvider(map[string]pg.Schema{"person": schema})
+
+	env, err := cel.NewEnv(
+		cel.CustomTypeProvider(provider),
+		cel.Variable("person", cel.ObjectType("person")),
+	)
+	if err != nil {
+		t.Fatalf("failed to create CEL environment: %v", err)
+	}
+
+	// Note: We use person.email == "test@example.com" rather than matches() because SQLite
+	// doesn't support regex in SQL generation. We test the advisor directly instead.
+	advisor := dialectsqlite.New()
+	rec := advisor.RecommendIndex(dialect.IndexPattern{
+		Column:  colPersonEmail,
+		Pattern: dialect.PatternRegexMatch,
+	})
+	if rec != nil {
+		t.Errorf("expected nil recommendation for regex on SQLite, got: %+v", rec)
+	}
+
+	// Also verify SQLite returns nil for array patterns
+	rec = advisor.RecommendIndex(dialect.IndexPattern{
+		Column:  "person.tags",
+		Pattern: dialect.PatternArrayMembership,
+	})
+	if rec != nil {
+		t.Errorf("expected nil recommendation for array membership on SQLite, got: %+v", rec)
+	}
+
+	// But comparisons should still work
+	ast, issues := env.Compile(`person.email == "test@example.com"`)
+	if issues != nil && issues.Err() != nil {
+		t.Fatalf("failed to compile expression: %v", issues.Err())
+	}
+
+	_, recommendations, err := AnalyzeQuery(ast,
+		WithSchemas(provider.GetSchemas()),
+		WithDialect(dialectsqlite.New()))
+	if err != nil {
+		t.Fatalf("AnalyzeQuery failed: %v", err)
+	}
+
+	found := false
+	for _, rec := range recommendations {
+		if rec.Column == colPersonEmail && rec.IndexType == "BTREE" {
+			found = true
+		}
+	}
+	if !found {
+		t.Errorf("expected BTREE recommendation for person.email on SQLite, got: %+v", recommendations)
+	}
+}
+
+func TestAnalyzeQuery_AllDialectsSupportsIndexAnalysis(t *testing.T) {
+	// Verify 
that all built-in dialects report SupportsIndexAnalysis() = true + dialects := []dialect.Dialect{ + dialectpg.New(), + dialectmysql.New(), + dialectsqlite.New(), + dialectduckdb.New(), + dialectbq.New(), + } + + for _, d := range dialects { + t.Run(string(d.Name()), func(t *testing.T) { + if !d.SupportsIndexAnalysis() { + t.Errorf("%s should support index analysis", d.Name()) + } + + // Also verify the dialect implements IndexAdvisor + advisor, ok := dialect.GetIndexAdvisor(d) + if !ok { + t.Fatalf("%s does not implement IndexAdvisor", d.Name()) + } + + patterns := advisor.SupportedPatterns() + if len(patterns) == 0 { + t.Errorf("%s reports no supported patterns", d.Name()) + } + }) + } +} + +func TestAnalyzeQuery_IndexAdvisorSupportedPatterns(t *testing.T) { + tests := []struct { + name string + dialect dialect.Dialect + expectedPatterns []dialect.PatternType + }{ + { + name: "PostgreSQL supports all patterns", + dialect: dialectpg.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternRegexMatch, + dialect.PatternArrayMembership, + dialect.PatternArrayComprehension, + dialect.PatternJSONArrayComprehension, + }, + }, + { + name: "MySQL supports comparison, JSON, regex, JSON array", + dialect: dialectmysql.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternRegexMatch, + dialect.PatternJSONArrayComprehension, + }, + }, + { + name: "SQLite supports only comparison", + dialect: dialectsqlite.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + }, + }, + { + name: "DuckDB supports comparison, JSON, arrays", + dialect: dialectduckdb.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternArrayMembership, + dialect.PatternArrayComprehension, + dialect.PatternJSONArrayComprehension, + }, + }, + { + name: "BigQuery supports comparison, 
JSON, JSON array", + dialect: dialectbq.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternJSONArrayComprehension, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + advisor, ok := dialect.GetIndexAdvisor(tt.dialect) + if !ok { + t.Fatalf("dialect does not implement IndexAdvisor") + } + + patterns := advisor.SupportedPatterns() + if len(patterns) != len(tt.expectedPatterns) { + t.Errorf("expected %d patterns, got %d: %v", len(tt.expectedPatterns), len(patterns), patterns) + } + + for _, expected := range tt.expectedPatterns { + found := false + for _, actual := range patterns { + if actual == expected { + found = true + break + } + } + if !found { + t.Errorf("expected pattern %d not found in supported patterns", expected) + } + } + }) + } +} diff --git a/bigquery/provider.go b/bigquery/provider.go new file mode 100644 index 0000000..ce6dd57 --- /dev/null +++ b/bigquery/provider.go @@ -0,0 +1,228 @@ +// Package bigquery provides BigQuery type provider for CEL type system integration. +package bigquery + +import ( + "context" + "errors" + "fmt" + "strings" + + bq "cloud.google.com/go/bigquery" + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/schema" +) + +// Sentinel errors for the bigquery package. +var ( + // ErrInvalidSchema indicates a problem with the provided schema or database introspection. + ErrInvalidSchema = errors.New("invalid schema") +) + +// FieldSchema is an alias for schema.FieldSchema. +type FieldSchema = schema.FieldSchema + +// Schema is an alias for schema.Schema. +type Schema = schema.Schema + +// NewSchema creates a new Schema. This is an alias for schema.NewSchema. 
+var NewSchema = schema.NewSchema + +// TypeProvider interface for BigQuery type providers. +type TypeProvider interface { + types.Provider + LoadTableSchema(ctx context.Context, tableName string) error + GetSchemas() map[string]Schema + Close() +} + +type typeProvider struct { + schemas map[string]Schema + client *bq.Client + datasetID string +} + +// NewTypeProvider creates a new BigQuery type provider with pre-defined schemas. +func NewTypeProvider(schemas map[string]Schema) TypeProvider { + return &typeProvider{schemas: schemas} +} + +// NewTypeProviderWithClient creates a new BigQuery type provider that can introspect database schemas. +// The caller owns the *bigquery.Client and is responsible for closing it. +func NewTypeProviderWithClient(_ context.Context, client *bq.Client, datasetID string) (TypeProvider, error) { + if client == nil { + return nil, fmt.Errorf("%w: BigQuery client must not be nil", ErrInvalidSchema) + } + if datasetID == "" { + return nil, fmt.Errorf("%w: dataset ID must not be empty", ErrInvalidSchema) + } + + return &typeProvider{ + schemas: make(map[string]Schema), + client: client, + datasetID: datasetID, + }, nil +} + +// LoadTableSchema loads schema information for a table from BigQuery using the client API. +func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) error { + if tp.client == nil { + return fmt.Errorf("%w: no BigQuery client available", ErrInvalidSchema) + } + + meta, err := tp.client.Dataset(tp.datasetID).Table(tableName).Metadata(ctx) + if err != nil { + return fmt.Errorf("%w: failed to get table metadata", ErrInvalidSchema) + } + + fields := bigquerySchemaToFieldSchemas(meta.Schema) + if len(fields) == 0 { + return fmt.Errorf("%w: table %q has no columns", ErrInvalidSchema, tableName) + } + + tp.schemas[tableName] = NewSchema(fields) + return nil +} + +// bigquerySchemaToFieldSchemas converts a BigQuery schema to a slice of FieldSchemas. 
+func bigquerySchemaToFieldSchemas(bqSchema bq.Schema) []FieldSchema { + fields := make([]FieldSchema, 0, len(bqSchema)) + for _, f := range bqSchema { + fields = append(fields, bigqueryFieldToFieldSchema(f)) + } + return fields +} + +// bigqueryFieldToFieldSchema converts a BigQuery FieldSchema to our FieldSchema. +func bigqueryFieldToFieldSchema(f *bq.FieldSchema) FieldSchema { + typeName := bigqueryFieldTypeToString(f.Type) + isJSON := f.Type == bq.JSONFieldType + repeated := f.Repeated + + field := FieldSchema{ + Name: f.Name, + Type: typeName, + Repeated: repeated, + IsJSON: isJSON, + } + + // Handle nested RECORD types recursively + if f.Type == bq.RecordFieldType && len(f.Schema) > 0 { + field.Schema = bigquerySchemaToFieldSchemas(f.Schema) + } + + if repeated { + field.Dimensions = 1 + field.ElementType = typeName + } + + return field +} + +// bigqueryFieldTypeToString converts a BigQuery FieldType to a string type name. +func bigqueryFieldTypeToString(ft bq.FieldType) string { + return strings.ToLower(string(ft)) +} + +// Close is a no-op since we don't own the *bigquery.Client. +func (tp *typeProvider) Close() { + // No-op: caller owns the *bigquery.Client connection +} + +// GetSchemas returns the schemas known to this type provider. +func (tp *typeProvider) GetSchemas() map[string]Schema { + return tp.schemas +} + +// EnumValue implements types.Provider. +func (tp *typeProvider) EnumValue(_ string) ref.Val { + return types.NewErr("unknown enum value") +} + +// FindIdent implements types.Provider. +func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { + return nil, false +} + +// FindStructType implements types.Provider. +func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { + if _, ok := tp.schemas[structType]; ok { + return types.NewObjectType(structType), true + } + return nil, false +} + +// FindStructFieldNames implements types.Provider. 
+func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + fields := s.Fields() + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + return names, true +} + +// FindStructFieldType implements types.Provider. +func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + field, found := s.FindField(fieldName) + if !found { + return nil, false + } + + exprType := bigqueryTypeToCELExprType(field) + celType, err := types.ExprTypeToType(exprType) + if err != nil { + return nil, false + } + + return &types.FieldType{ + Type: celType, + }, true +} + +// NewValue implements types.Provider. +func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { + return types.NewErr("unknown type in schema") +} + +// bigqueryTypeToCELExprType converts a BigQuery field schema to a CEL expression type. +func bigqueryTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { + baseType := bigqueryBaseTypeToCEL(field.Type) + if field.Repeated { + return decls.NewListType(baseType) + } + return baseType +} + +// bigqueryBaseTypeToCEL converts a BigQuery type name to a CEL expression type. 
+func bigqueryBaseTypeToCEL(typeName string) *exprpb.Type { + switch typeName { + case "STRING", "string": + return decls.String + case "INT64", "int64", "INTEGER", "integer": + return decls.Int + case "FLOAT64", "float64", "FLOAT", "float", "NUMERIC", "numeric": + return decls.Double + case "BOOL", "bool", "BOOLEAN", "boolean": + return decls.Bool + case "BYTES", "bytes": + return decls.Bytes + case "JSON", "json": + return decls.Dyn + case "TIMESTAMP", "timestamp": + return decls.Timestamp + default: + return decls.Dyn + } +} diff --git a/bigquery/provider_test.go b/bigquery/provider_test.go new file mode 100644 index 0000000..0583d24 --- /dev/null +++ b/bigquery/provider_test.go @@ -0,0 +1,274 @@ +package bigquery_test + +import ( + "bytes" + "context" + _ "embed" + "runtime" + "strings" + "testing" + + bq "cloud.google.com/go/bigquery" + "github.com/google/cel-go/common/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + tcbigquery "github.com/testcontainers/testcontainers-go/modules/gcloud/bigquery" + "google.golang.org/api/option" + "google.golang.org/api/option/internaloption" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/spandigital/cel2sql/v3/bigquery" +) + +//go:embed testdata/provider_seed.yaml +var providerSeedYAML []byte + +const ( + testProjectID = "test-project" + testDataset = "testdataset" +) + +func TestNewTypeProvider(t *testing.T) { + schemas := map[string]bigquery.Schema{ + "users": bigquery.NewSchema([]bigquery.FieldSchema{ + {Name: "id", Type: "INTEGER"}, + {Name: "name", Type: "STRING"}, + }), + } + + provider := bigquery.NewTypeProvider(schemas) + require.NotNil(t, provider) + + got := provider.GetSchemas() + assert.Len(t, got, 1) + assert.Contains(t, got, "users") +} + +func TestNewTypeProviderWithClient_NilClient(t *testing.T) { + _, err := bigquery.NewTypeProviderWithClient(context.Background(), nil, "dataset") 
+	require.Error(t, err)
+	assert.ErrorIs(t, err, bigquery.ErrInvalidSchema)
+}
+
+func TestNewTypeProviderWithClient_EmptyDataset(t *testing.T) {
+	// We can't construct a real client without credentials, so pass a nil
+	// client together with an empty dataset; the constructor must reject it.
+	_, err := bigquery.NewTypeProviderWithClient(context.Background(), nil, "")
+	require.Error(t, err)
+	assert.ErrorIs(t, err, bigquery.ErrInvalidSchema)
+}
+
+func TestLoadTableSchema_NoClient(t *testing.T) {
+	provider := bigquery.NewTypeProvider(nil)
+	err := provider.LoadTableSchema(context.Background(), "test")
+	require.Error(t, err)
+	assert.ErrorIs(t, err, bigquery.ErrInvalidSchema)
+}
+
+func TestTypeProvider_FindStructType(t *testing.T) {
+	schemas := map[string]bigquery.Schema{
+		"users": bigquery.NewSchema([]bigquery.FieldSchema{
+			{Name: "id", Type: "INTEGER"},
+			{Name: "name", Type: "STRING"},
+		}),
+	}
+	provider := bigquery.NewTypeProvider(schemas)
+
+	typ, found := provider.FindStructType("users")
+	assert.True(t, found)
+	assert.NotNil(t, typ)
+
+	_, found = provider.FindStructType("nonexistent")
+	assert.False(t, found)
+}
+
+func TestTypeProvider_FindStructFieldNames(t *testing.T) {
+	schemas := map[string]bigquery.Schema{
+		"users": bigquery.NewSchema([]bigquery.FieldSchema{
+			{Name: "id", Type: "INTEGER"},
+			{Name: "name", Type: "STRING"},
+			{Name: "email", Type: "STRING"},
+		}),
+	}
+	provider := bigquery.NewTypeProvider(schemas)
+
+	names, found := provider.FindStructFieldNames("users")
+	assert.True(t, found)
+	assert.ElementsMatch(t, []string{"id", "name", "email"}, names)
+
+	_, found = provider.FindStructFieldNames("nonexistent")
+	assert.False(t, found)
+}
+
+func TestTypeProvider_FindStructFieldType(t *testing.T) {
+	schemas := map[string]bigquery.Schema{
+		"test_table": bigquery.NewSchema([]bigquery.FieldSchema{
+			{Name: "str_field", Type: "STRING"},
+			{Name: "int_field", Type: "INTEGER"},
+			{Name: "int64_field", Type: "INT64"},
+			{Name: "float_field", Type: "FLOAT64"},
+			{Name: "bool_field", Type: "BOOL"},
+			{Name:
"bytes_field", Type: "BYTES"}, + {Name: "json_field", Type: "JSON"}, + {Name: "ts_field", Type: "TIMESTAMP"}, + {Name: "str_lower", Type: "string"}, + {Name: "int_lower", Type: "integer"}, + }), + } + provider := bigquery.NewTypeProvider(schemas) + + tests := []struct { + fieldName string + wantType *types.Type + wantFound bool + }{ + {"str_field", types.StringType, true}, + {"int_field", types.IntType, true}, + {"int64_field", types.IntType, true}, + {"float_field", types.DoubleType, true}, + {"bool_field", types.BoolType, true}, + {"bytes_field", types.BytesType, true}, + {"json_field", types.DynType, true}, + {"ts_field", types.TimestampType, true}, + {"str_lower", types.StringType, true}, + {"int_lower", types.IntType, true}, + {"nonexistent", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("test_table", tt.fieldName) + assert.Equal(t, tt.wantFound, found) + if found { + assert.Equal(t, tt.wantType, got.Type) + } + }) + } +} + +func TestTypeProvider_Close(_ *testing.T) { + provider := bigquery.NewTypeProvider(nil) + // Close should not panic + provider.Close() +} + +// setupBigQueryContainer starts a BigQuery emulator container and returns a client. 
+func setupBigQueryContainer(ctx context.Context, t *testing.T) (*tcbigquery.Container, *bq.Client) { + t.Helper() + + container, err := tcbigquery.Run(ctx, + "ghcr.io/goccy/bigquery-emulator:0.6.6", + tcbigquery.WithProjectID(testProjectID), + tcbigquery.WithDataYAML(bytes.NewReader(providerSeedYAML)), + testcontainers.WithImagePlatform("linux/amd64"), + ) + if err != nil { + if runtime.GOARCH == "arm64" || strings.Contains(err.Error(), "no image found") || strings.Contains(err.Error(), "container exited") { + t.Skipf("Skipping BigQuery integration test: emulator not available on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + t.Fatalf("Failed to start BigQuery emulator container: %v", err) + } + + opts := []option.ClientOption{ + option.WithEndpoint(container.URI()), + option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), + option.WithoutAuthentication(), + internaloption.SkipDialSettingsValidation(), + } + + client, err := bq.NewClient(ctx, container.ProjectID(), opts...) 
+ if err != nil { + if termErr := container.Terminate(ctx); termErr != nil { + t.Logf("failed to terminate container: %v", termErr) + } + t.Fatalf("Failed to create BigQuery client: %v", err) + } + + return container, client +} + +func TestLoadTableSchema_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, client := setupBigQueryContainer(ctx, t) + defer func() { + if closeErr := client.Close(); closeErr != nil { + t.Logf("failed to close client: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + provider, err := bigquery.NewTypeProviderWithClient(ctx, client, testDataset) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "test_data") + require.NoError(t, err) + + // Verify schema was loaded + schemas := provider.GetSchemas() + assert.Contains(t, schemas, "test_data") + + // Verify FindStructType + typ, found := provider.FindStructType("test_data") + assert.True(t, found) + assert.NotNil(t, typ) + + // Verify FindStructFieldNames + names, found := provider.FindStructFieldNames("test_data") + assert.True(t, found) + assert.Contains(t, names, "id") + assert.Contains(t, names, "text_val") + assert.Contains(t, names, "int_val") + + // Verify type mappings from loaded schema + tests := []struct { + fieldName string + wantType *types.Type + }{ + {"id", types.IntType}, + {"text_val", types.StringType}, + {"int_val", types.IntType}, + {"float_val", types.DoubleType}, + {"bool_val", types.BoolType}, + } + + for _, tt := range tests { + t.Run("type_"+tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("test_data", tt.fieldName) + assert.True(t, found, "field %q should be found", tt.fieldName) + if found { + assert.Equal(t, tt.wantType, got.Type, "field %q type mismatch", tt.fieldName) + } + }) + } +} + +func TestLoadTableSchema_NonexistentTable(t 
*testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, client := setupBigQueryContainer(ctx, t) + defer func() { + if closeErr := client.Close(); closeErr != nil { + t.Logf("failed to close client: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + provider, err := bigquery.NewTypeProviderWithClient(ctx, client, testDataset) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "nonexistent_table") + require.Error(t, err) + assert.ErrorIs(t, err, bigquery.ErrInvalidSchema) +} diff --git a/bigquery/testdata/provider_seed.yaml b/bigquery/testdata/provider_seed.yaml new file mode 100644 index 0000000..8afab3f --- /dev/null +++ b/bigquery/testdata/provider_seed.yaml @@ -0,0 +1,41 @@ +projects: + - id: test-project + datasets: + - id: testdataset + tables: + - id: test_data + columns: + - name: id + type: INT64 + - name: text_val + type: STRING + - name: int_val + type: INT64 + - name: float_val + type: FLOAT64 + - name: bool_val + type: BOOL + - name: nullable_text + type: STRING + - name: nullable_int + type: INT64 + data: + - id: 1 + text_val: "hello" + int_val: 10 + float_val: 10.5 + bool_val: true + nullable_text: "present" + nullable_int: 100 + - id: 2 + text_val: "world" + int_val: 20 + float_val: 20.5 + bool_val: false + - id: 3 + text_val: "test" + int_val: 30 + float_val: 30.5 + bool_val: true + nullable_text: "here" + nullable_int: 200 diff --git a/bigquery_integration_test.go b/bigquery_integration_test.go new file mode 100644 index 0000000..1c87b12 --- /dev/null +++ b/bigquery_integration_test.go @@ -0,0 +1,441 @@ +package cel2sql_test + +import ( + "bytes" + "context" + _ "embed" + "fmt" + "runtime" + "strings" + "testing" + + "cloud.google.com/go/bigquery" + "github.com/google/cel-go/cel" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + 
tcbigquery "github.com/testcontainers/testcontainers-go/modules/gcloud/bigquery"
+	"google.golang.org/api/option"
+	"google.golang.org/api/option/internaloption"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+
+	"github.com/spandigital/cel2sql/v3"
+	bigqueryDialect "github.com/spandigital/cel2sql/v3/dialect/bigquery"
+	"github.com/spandigital/cel2sql/v3/pg"
+)
+
+//go:embed testdata/bigquery_seed.yaml
+var bigQuerySeedYAML []byte
+
+const (
+	bigQueryProjectID = "test-project"
+	bigQueryDataset   = "testdataset"
+)
+
+// setupBigQueryContainer starts a BigQuery emulator container and returns a client.
+// If the emulator cannot start (e.g., on arm64), it skips or fails the test via
+// t.Skipf/t.Fatalf, so on return both values are always non-nil.
+func setupBigQueryContainer(ctx context.Context, t *testing.T) (*tcbigquery.Container, *bigquery.Client) {
+	t.Helper()
+
+	container, err := tcbigquery.Run(ctx,
+		"ghcr.io/goccy/bigquery-emulator:0.6.6",
+		tcbigquery.WithProjectID(bigQueryProjectID),
+		tcbigquery.WithDataYAML(bytes.NewReader(bigQuerySeedYAML)),
+		testcontainers.WithImagePlatform("linux/amd64"),
+	)
+	if err != nil {
+		// The BigQuery emulator only provides amd64 images. On arm64 (Apple Silicon),
+		// it crashes under QEMU emulation due to Go runtime lfstack.push issues.
+		if runtime.GOARCH == "arm64" || strings.Contains(err.Error(), "no image found") || strings.Contains(err.Error(), "container exited") {
+			t.Skipf("Skipping BigQuery integration test: emulator not available on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
+		}
+		t.Fatalf("Failed to start BigQuery emulator container: %v", err)
+	}
+
+	opts := []option.ClientOption{
+		option.WithEndpoint(container.URI()),
+		option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())),
+		option.WithoutAuthentication(),
+		internaloption.SkipDialSettingsValidation(),
+	}
+
+	client, err := bigquery.NewClient(ctx, container.ProjectID(), opts...)
+ if err != nil { + if termErr := container.Terminate(ctx); termErr != nil { + t.Logf("failed to terminate container: %v", termErr) + } + t.Fatalf("Failed to create BigQuery client: %v", err) + } + + return container, client +} + +// bigQueryCount executes a count query and returns the result. +func bigQueryCount(ctx context.Context, t *testing.T, client *bigquery.Client, query string) int { + t.Helper() + + q := client.Query(query) + it, err := q.Read(ctx) + require.NoError(t, err, "Failed to execute query: %s", query) + + var row []bigquery.Value + err = it.Next(&row) + require.NoError(t, err, "Failed to read query result: %s", query) + require.Len(t, row, 1, "Expected exactly one column in COUNT(*) result") + + switch v := row[0].(type) { + case int64: + return int(v) + case float64: + return int(v) + default: + t.Fatalf("Unexpected type %T for COUNT(*) result: %v", row[0], row[0]) + return 0 + } +} + +// TestBigQueryOperatorsIntegration validates operator conversions against a BigQuery emulator. 
+func TestBigQueryOperatorsIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, client := setupBigQueryContainer(ctx, t) + defer func() { + if closeErr := client.Close(); closeErr != nil { + t.Logf("failed to close client: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + env, err := cel.NewEnv( + cel.Variable("id", cel.IntType), + cel.Variable("text_val", cel.StringType), + cel.Variable("int_val", cel.IntType), + cel.Variable("float_val", cel.DoubleType), + cel.Variable("bool_val", cel.BoolType), + cel.Variable("nullable_text", cel.StringType), + cel.Variable("nullable_int", cel.IntType), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(bigqueryDialect.New()) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + // Comparison operators + { + name: "Equality string", + celExpr: `text_val == "hello"`, + expectedRows: 1, + description: "String equality comparison", + }, + { + name: "Equality integer", + celExpr: `int_val == 20`, + expectedRows: 1, + description: "Integer equality comparison", + }, + { + name: "Equality float", + celExpr: `float_val == 10.5`, + expectedRows: 1, + description: "Float equality comparison", + }, + { + name: "Equality boolean", + celExpr: `bool_val == true`, + expectedRows: 3, + description: "Boolean equality comparison", + }, + { + name: "Not equal", + celExpr: `text_val != "hello"`, + expectedRows: 4, + description: "Not equal comparison", + }, + { + name: "Less than", + celExpr: `int_val < 15`, + expectedRows: 2, // 10, 5 + description: "Less than comparison", + }, + { + name: "Less than or equal", + celExpr: `int_val <= 15`, + expectedRows: 3, // 10, 5, 15 + description: "Less than or equal comparison", + }, + { + name: "Greater than", + celExpr: `int_val > 15`, + expectedRows: 2, // 20, 
30 + description: "Greater than comparison", + }, + { + name: "Greater than or equal", + celExpr: `int_val >= 15`, + expectedRows: 3, // 20, 30, 15 + description: "Greater than or equal comparison", + }, + + // Logical operators + { + name: "Logical AND", + celExpr: `int_val > 10 && bool_val == true`, + expectedRows: 2, // rows 3 (30,true) and 5 (15,true) + description: "Logical AND operator", + }, + { + name: "Logical OR", + celExpr: `int_val < 10 || bool_val == false`, + expectedRows: 2, // rows 2 (20,false) and 4 (5,false) + description: "Logical OR operator", + }, + { + name: "Logical NOT", + celExpr: `!bool_val`, + expectedRows: 2, // rows 2 and 4 + description: "Logical NOT operator", + }, + { + name: "Complex logical expression", + celExpr: `(int_val > 10 && bool_val) || int_val < 10`, + expectedRows: 3, // rows 3, 5, 4 + description: "Complex nested logical operators", + }, + + // Arithmetic operators + { + name: "Addition", + celExpr: `int_val + 10 == 20`, + expectedRows: 1, // 10 + 10 = 20 + description: "Addition operator", + }, + { + name: "Subtraction", + celExpr: `int_val - 5 == 15`, + expectedRows: 1, // 20 - 5 = 15 + description: "Subtraction operator", + }, + { + name: "Multiplication", + celExpr: `int_val * 2 == 20`, + expectedRows: 1, // 10 * 2 = 20 + description: "Multiplication operator", + }, + { + name: "Division", + celExpr: `int_val / 2 == 10`, + expectedRows: 1, // 20 / 2 = 10 + description: "Division operator", + }, + { + name: "Modulo", + celExpr: `int_val % 10 == 0`, + expectedRows: 3, // 10, 20, 30 + description: "Modulo operator", + }, + { + name: "Complex arithmetic", + celExpr: `(int_val * 2) + 5 > 30`, + expectedRows: 3, // (20*2)+5=45, (30*2)+5=65, (15*2)+5=35 + description: "Complex arithmetic expression", + }, + + // String operators + { + name: "String concatenation", + celExpr: `text_val + "!" 
== "hello!"`, + expectedRows: 1, + description: "String concatenation (||)", + }, + { + name: "String contains", + celExpr: `text_val.contains("world")`, + expectedRows: 2, // "world", "hello world" + description: "String contains function (STRPOS)", + }, + { + name: "String startsWith", + celExpr: `text_val.startsWith("hello")`, + expectedRows: 2, // "hello", "hello world" + description: "String startsWith function (LIKE)", + }, + { + name: "String endsWith", + celExpr: `text_val.endsWith("world")`, + expectedRows: 2, // "world", "hello world" + description: "String endsWith function (LIKE)", + }, + + // Regex (BigQuery uses REGEXP_CONTAINS with RE2) + { + name: "Regex match", + celExpr: `text_val.matches(r"^hello")`, + expectedRows: 2, // "hello", "hello world" + description: "Regex match (REGEXP_CONTAINS)", + }, + { + name: "Regex simple pattern", + celExpr: `text_val.matches(r"test")`, + expectedRows: 2, // "test", "testing" + description: "Regex simple pattern", + }, + + // Complex combined operators + { + name: "Complex multi-operator expression", + celExpr: `int_val > 10 && bool_val && text_val.contains("test")`, + expectedRows: 2, // rows 3 and 5 + description: "Complex expression with multiple operator types", + }, + { + name: "Nested parenthesized operators", + celExpr: `((int_val + 5) * 2 > 30) && (text_val.contains("test") || bool_val)`, + expectedRows: 2, // rows 3 and 5 + description: "Deeply nested operators with parentheses", + }, + { + name: "Triple negation", + celExpr: `!!!bool_val`, + expectedRows: 2, + description: "Multiple NOT operators", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + sqlCondition, err := cel2sql.Convert(ast, dialectOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + 
t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := fmt.Sprintf("SELECT COUNT(*) FROM `%s.test_data` WHERE %s", bigQueryDataset, sqlCondition) + t.Logf("Full SQL Query: %s", query) + + actualRows := bigQueryCount(ctx, t, client, query) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s (expected %d rows, got %d rows)", + tt.description, tt.expectedRows, actualRows) + }) + } +} + +// TestBigQueryJSONIntegration validates JSON operations against a BigQuery emulator. +func TestBigQueryJSONIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, client := setupBigQueryContainer(ctx, t) + defer func() { + if closeErr := client.Close(); closeErr != nil { + t.Logf("failed to close client: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + // Set up CEL environment with schema for JSON detection + productSchema := pg.NewSchema([]pg.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + {Name: "price", Type: "double precision"}, + {Name: "metadata", Type: "json", IsJSON: true}, + }) + + schemas := map[string]pg.Schema{ + "product": productSchema, + } + + provider := pg.NewTypeProvider(schemas) + + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("product", cel.ObjectType("product")), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(bigqueryDialect.New()) + schemaOpt := cel2sql.WithSchemas(schemas) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + { + name: "JSON field access", + celExpr: `product.metadata.brand == "Acme"`, + expectedRows: 2, + 
description: "JSON field access with JSON_VALUE", + }, + { + name: "JSON field access different value", + celExpr: `product.metadata.color == "blue"`, + expectedRows: 1, + description: "JSON field access with different value", + }, + { + name: "JSON with regular field", + celExpr: `product.metadata.brand == "Acme" && product.price > 30.0`, + expectedRows: 1, // Doohickey (Acme, 39.99) + description: "JSON field combined with regular field comparison", + }, + { + name: "JSON field existence", + celExpr: `has(product.metadata.brand)`, + expectedRows: 3, // All rows have 'brand' + description: "JSON field existence check", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + sqlCondition, err := cel2sql.Convert(ast, dialectOpt, schemaOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := fmt.Sprintf("SELECT COUNT(*) FROM `%s.products` product WHERE %s", bigQueryDataset, sqlCondition) + t.Logf("Full SQL Query: %s", query) + + actualRows := bigQueryCount(ctx, t, client, query) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s", tt.description) + }) + } +} diff --git a/cel2sql.go b/cel2sql.go index 5827061..9c70b0e 100644 --- a/cel2sql.go +++ b/cel2sql.go @@ -1,13 +1,12 @@ -// Package cel2sql converts CEL (Common Expression Language) expressions to PostgreSQL SQL conditions. +// Package cel2sql converts CEL (Common Expression Language) expressions to SQL conditions. 
+// It supports multiple SQL dialects through the dialect interface, with PostgreSQL as the default. package cel2sql import ( "context" - "encoding/hex" "fmt" "log/slog" "math" - "regexp" "slices" "strconv" "strings" @@ -18,26 +17,16 @@ import ( "github.com/google/cel-go/common/overloads" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - "github.com/spandigital/cel2sql/v3/pg" + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/dialect/postgres" + "github.com/spandigital/cel2sql/v3/schema" ) // Implementations based on `google/cel-go`'s unparser // https://github.com/google/cel-go/blob/master/parser/unparser.go -// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). +// Resource limit constants. const ( - // maxRegexPatternLength is the maximum allowed length for regex patterns - // to prevent processing extremely long patterns that could cause DoS. - maxRegexPatternLength = 500 - - // maxRegexGroups is the maximum number of capture groups allowed in a pattern - // to prevent memory exhaustion and slow matching. - maxRegexGroups = 20 - - // maxRegexNestingDepth is the maximum nesting depth for groups and quantifiers - // to prevent catastrophic backtracking. - maxRegexNestingDepth = 10 - // defaultMaxRecursionDepth is the default maximum recursion depth for visit() // to prevent stack overflow from deeply nested expressions (CWE-674: Uncontrolled Recursion). defaultMaxRecursionDepth = 100 @@ -62,11 +51,26 @@ type ConvertOption func(*convertOptions) // convertOptions holds configuration options for the Convert function. 
type convertOptions struct { - schemas map[string]pg.Schema + schemas map[string]schema.Schema ctx context.Context logger *slog.Logger - maxDepth int // Maximum recursion depth (0 = use default) - maxOutputLen int // Maximum SQL output length (0 = use default) + maxDepth int // Maximum recursion depth (0 = use default) + maxOutputLen int // Maximum SQL output length (0 = use default) + dialect dialect.Dialect // SQL dialect (nil = PostgreSQL default) +} + +// WithDialect sets the SQL dialect for conversion. +// If not provided, PostgreSQL is used as the default dialect. +// +// Example: +// +// import "github.com/spandigital/cel2sql/v3/dialect/mysql" +// +// sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(mysql.New())) +func WithDialect(d dialect.Dialect) ConvertOption { + return func(o *convertOptions) { + o.dialect = d + } } // WithSchemas provides schema information for proper JSON/JSONB field handling. @@ -76,7 +80,7 @@ type convertOptions struct { // // schemas := provider.GetSchemas() // sql, err := cel2sql.Convert(ast, cel2sql.WithSchemas(schemas)) -func WithSchemas(schemas map[string]pg.Schema) ConvertOption { +func WithSchemas(schemas map[string]schema.Schema) ConvertOption { return func(o *convertOptions) { o.schemas = schemas } @@ -177,16 +181,20 @@ type Result struct { Parameters []any // Parameter values in order ($1, $2, etc.) } -// Convert converts a CEL AST to a PostgreSQL SQL WHERE clause condition. -// Options can be provided to configure the conversion behavior. +// Convert converts a CEL AST to a SQL WHERE clause condition. +// By default, PostgreSQL SQL is generated. Use WithDialect to select a different dialect. 
// -// Example without options: +// Example without options (PostgreSQL): // // sql, err := cel2sql.Convert(ast) // // Example with schema information for JSON/JSONB support: // // sql, err := cel2sql.Convert(ast, cel2sql.WithSchemas(schemas)) +// +// Example with a different dialect: +// +// sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(mysql.New())) func Convert(ast *cel.Ast, opts ...ConvertOption) (string, error) { start := time.Now() @@ -199,6 +207,11 @@ func Convert(ast *cel.Ast, opts ...ConvertOption) (string, error) { opt(options) } + // Default to PostgreSQL dialect if none specified + if options.dialect == nil { + options.dialect = postgres.New() + } + options.logger.Debug("starting CEL to SQL conversion") checkedExpr, err := cel.AstToCheckedExpr(ast) @@ -212,6 +225,7 @@ func Convert(ast *cel.Ast, opts ...ConvertOption) (string, error) { schemas: options.schemas, ctx: options.ctx, logger: options.logger, + dialect: options.dialect, maxDepth: options.maxDepth, maxOutputLen: options.maxOutputLen, } @@ -227,15 +241,16 @@ func Convert(ast *cel.Ast, opts ...ConvertOption) (string, error) { options.logger.LogAttrs(context.Background(), slog.LevelDebug, "conversion completed", slog.String("sql", result), + slog.String("dialect", string(options.dialect.Name())), slog.Duration("duration", duration), ) return result, nil } -// ConvertParameterized converts a CEL AST to a parameterized PostgreSQL SQL WHERE clause. -// Returns both the SQL string with placeholders ($1, $2, etc.) and the parameter values. -// This enables query plan caching and provides additional SQL injection protection. +// ConvertParameterized converts a CEL AST to a parameterized SQL WHERE clause. +// Returns both the SQL string with placeholders and the parameter values. +// By default uses PostgreSQL ($1, $2). Use WithDialect for other placeholder styles. 
// // Constants that are parameterized: // - String literals: 'John' → $1 @@ -268,6 +283,11 @@ func ConvertParameterized(ast *cel.Ast, opts ...ConvertOption) (*Result, error) opt(options) } + // Default to PostgreSQL dialect if none specified + if options.dialect == nil { + options.dialect = postgres.New() + } + options.logger.Debug("starting parameterized CEL to SQL conversion") checkedExpr, err := cel.AstToCheckedExpr(ast) @@ -281,6 +301,7 @@ func ConvertParameterized(ast *cel.Ast, opts ...ConvertOption) (*Result, error) schemas: options.schemas, ctx: options.ctx, logger: options.logger, + dialect: options.dialect, maxDepth: options.maxDepth, maxOutputLen: options.maxOutputLen, parameterize: true, // Enable parameterization @@ -310,16 +331,17 @@ func ConvertParameterized(ast *cel.Ast, opts ...ConvertOption) (*Result, error) type converter struct { str strings.Builder typeMap map[int64]*exprpb.Type - schemas map[string]pg.Schema + schemas map[string]schema.Schema ctx context.Context logger *slog.Logger + dialect dialect.Dialect depth int // Current recursion depth maxDepth int // Maximum allowed recursion depth maxOutputLen int // Maximum allowed SQL output length comprehensionDepth int // Current comprehension nesting depth parameterize bool // Enable parameterized output parameters []any // Collected parameters for parameterized queries - paramCount int // Parameter counter for placeholders ($1, $2, etc.) + paramCount int // Parameter counter for placeholders } // checkContext checks if the context has been cancelled or expired. @@ -599,6 +621,30 @@ func (con *converter) visitCallBinary(expr *exprpb.Expr) error { rhsParen = isSamePrecedence(fun, rhs) } + // Handle string concatenation via dialect before writing LHS. + // This allows MySQL to use CONCAT() instead of ||. 
+ if fun == operators.Add && + ((lhsType.GetPrimitive() == exprpb.Type_STRING && rhsType.GetPrimitive() == exprpb.Type_STRING) || + (isStringLiteral(lhs) || isStringLiteral(rhs))) { + return con.dialect.WriteStringConcat(&con.str, + func() error { return con.visitMaybeNested(lhs, lhsParen) }, + func() error { return con.visitMaybeNested(rhs, rhsParen) }, + ) + } + + // Handle array membership (IN operator with list) via dialect before writing LHS. + // This allows dialects like SQLite to use a fundamentally different pattern + // (e.g., "elem IN (SELECT value FROM json_each(array))") instead of "elem = ANY(array)". + if fun == operators.In && isListType(rhsType) { + // Non-JSON list membership + if !isFieldAccessExpression(rhs) || !con.isJSONArrayField(rhs) { + return con.dialect.WriteArrayMembership(&con.str, + func() error { return con.visitMaybeNested(lhs, lhsParen) }, + func() error { return con.visitMaybeNested(rhs, rhsParen) }, + ) + } + } + // Check if we need numeric casting for JSON text extraction needsNumericCasting := false if con.isJSONTextExtraction(lhs) && isNumericComparison(fun) && isNumericType(rhsType) { @@ -611,7 +657,8 @@ func (con *converter) visitCallBinary(expr *exprpb.Expr) error { } if needsNumericCasting { - con.str.WriteString(")::numeric") + con.str.WriteString(")") + con.dialect.WriteCastToNumeric(&con.str) } var operator string if fun == operators.Add && (lhsType.GetPrimitive() == exprpb.Type_STRING && rhsType.GetPrimitive() == exprpb.Type_STRING) { @@ -655,28 +702,25 @@ func (con *converter) visitCallBinary(expr *exprpb.Expr) error { if fun == operators.In && (isListType(rhsType) || isFieldAccessExpression(rhs)) { // Check if we're dealing with a JSON array if isFieldAccessExpression(rhs) && con.isJSONArrayField(rhs) { - // For JSON arrays, use jsonb_array_elements with ANY + // For JSON arrays, use dialect-specific JSON array membership jsonFunc := con.getJSONArrayFunction(rhs) - con.str.WriteString("ANY(ARRAY(SELECT ") // For 
nested JSON access like settings.permissions, we need to handle differently if con.isNestedJSONAccess(rhs) { - // Use text extraction for the array elements - con.str.WriteString("jsonb_array_elements_text(") - // Generate the JSON path with -> instead of ->> to preserve JSONB type - if err := con.visitNestedJSONForArray(rhs); err != nil { + // Use dialect-specific nested JSON array membership + if err := con.dialect.WriteNestedJSONArrayMembership(&con.str, func() error { + return con.visitNestedJSONForArray(rhs) + }); err != nil { return err } - con.str.WriteString(")))") return nil } // For direct JSON array access - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visitMaybeNested(rhs, rhsParen); err != nil { + if err := con.dialect.WriteJSONArrayMembership(&con.str, jsonFunc, func() error { + return con.visitMaybeNested(rhs, rhsParen) + }); err != nil { return err } - con.str.WriteString(")))") return nil } con.str.WriteString("ANY(") @@ -728,27 +772,27 @@ func (con *converter) callContains(target *exprpb.Expr, args []*exprpb.Expr) err return nil } - // For regular strings, use POSITION - con.str.WriteString("POSITION(") - for i, arg := range args { - err := con.visit(arg) - if err != nil { - return err - } - if i < len(args)-1 { - con.str.WriteString(" IN ") - } - } - if target != nil { - con.str.WriteString(" IN ") - nested := isBinaryOrTernaryOperator(target) - err := con.visitMaybeNested(target, nested) - if err != nil { - return err - } - } - con.str.WriteString(") > 0") - return nil + // For regular strings, use dialect-specific contains + return con.dialect.WriteContains(&con.str, + func() error { + if target != nil { + nested := isBinaryOrTernaryOperator(target) + return con.visitMaybeNested(target, nested) + } + return nil + }, + func() error { + for i, arg := range args { + if err := con.visit(arg); err != nil { + return err + } + if i < len(args)-1 { + con.str.WriteString(", ") + } + } + return nil + }, + ) } func (con 
*converter) callStartsWith(target *exprpb.Expr, args []*exprpb.Expr) error { @@ -780,14 +824,16 @@ func (con *converter) callStartsWith(target *exprpb.Expr, args []*exprpb.Expr) e escaped := escapeLikePattern(prefix) con.str.WriteString("'") con.str.WriteString(escaped) - con.str.WriteString("%' ESCAPE E'\\\\'") + con.str.WriteString("%'") + con.dialect.WriteLikeEscape(&con.str) } else { // For non-literal patterns, escape special characters at runtime and concatenate with % con.str.WriteString("REPLACE(REPLACE(REPLACE(") if err := con.visit(args[0]); err != nil { return err } - con.str.WriteString(", '\\\\', '\\\\\\\\'), '%', '\\%'), '_', '\\_') || '%' ESCAPE E'\\\\'") + con.str.WriteString(", '\\\\', '\\\\\\\\'), '%', '\\%'), '_', '\\_') || '%'") + con.dialect.WriteLikeEscape(&con.str) } return nil @@ -795,8 +841,7 @@ func (con *converter) callStartsWith(target *exprpb.Expr, args []*exprpb.Expr) e func (con *converter) callEndsWith(target *exprpb.Expr, args []*exprpb.Expr) error { // CEL endsWith function: string.endsWith(suffix) - // Convert to PostgreSQL: string LIKE '%suffix' - // or for more robust handling: RIGHT(string, LENGTH(suffix)) = suffix + // Convert to SQL: string LIKE '%suffix' if target == nil || len(args) == 0 { return fmt.Errorf("%w: endsWith function requires both string and suffix arguments", ErrInvalidArguments) @@ -822,14 +867,16 @@ func (con *converter) callEndsWith(target *exprpb.Expr, args []*exprpb.Expr) err escaped := escapeLikePattern(suffix) con.str.WriteString("'%") con.str.WriteString(escaped) - con.str.WriteString("' ESCAPE E'\\\\'") + con.str.WriteString("'") + con.dialect.WriteLikeEscape(&con.str) } else { // For non-literal patterns, escape special characters at runtime and concatenate with % con.str.WriteString("'%' || REPLACE(REPLACE(REPLACE(") if err := con.visit(args[0]); err != nil { return err } - con.str.WriteString(", '\\\\', '\\\\\\\\'), '%', '\\%'), '_', '\\_') ESCAPE E'\\\\'") + con.str.WriteString(", '\\\\', 
'\\\\\\\\'), '%', '\\%'), '_', '\\_')") + con.dialect.WriteLikeEscape(&con.str) } return nil @@ -841,40 +888,44 @@ func (con *converter) callCasting(function string, _ *exprpb.Expr, args []*exprp } arg := args[0] if function == overloads.TypeConvertInt && isTimestampType(con.getType(arg)) { - con.str.WriteString("EXTRACT(EPOCH FROM ") - if err := con.visit(arg); err != nil { - return err - } - con.str.WriteString(")::bigint") - return nil + return con.dialect.WriteEpochExtract(&con.str, func() error { + return con.visit(arg) + }) } con.str.WriteString("CAST(") if err := con.visit(arg); err != nil { return err } con.str.WriteString(" AS ") + // Map CEL type conversion function to dialect-specific type name + var celTypeName string switch function { case overloads.TypeConvertBool: - con.str.WriteString("BOOLEAN") + celTypeName = "bool" case overloads.TypeConvertBytes: - con.str.WriteString("BYTEA") + celTypeName = "bytes" case overloads.TypeConvertDouble: - con.str.WriteString("DOUBLE PRECISION") + celTypeName = "double" case overloads.TypeConvertInt: - con.str.WriteString("BIGINT") + celTypeName = "int" case overloads.TypeConvertString: - con.str.WriteString("TEXT") + celTypeName = "string" case overloads.TypeConvertUint: - con.str.WriteString("BIGINT") + celTypeName = "uint" } + con.dialect.WriteTypeName(&con.str, celTypeName) con.str.WriteString(")") return nil } -// callMatches handles CEL matches() function with RE2 to POSIX regex conversion +// callMatches handles CEL matches() function with regex conversion func (con *converter) callMatches(target *exprpb.Expr, args []*exprpb.Expr) error { // CEL matches function: string.matches(pattern) or matches(string, pattern) - // Convert to PostgreSQL: string ~ 'posix_pattern' + + // Check if the dialect supports regex + if !con.dialect.SupportsRegex() { + return fmt.Errorf("%w: regex matching is not supported by %s dialect", ErrUnsupportedDialectFeature, con.dialect.Name()) + } // Get the string to match against var 
stringExpr *exprpb.Expr @@ -896,22 +947,16 @@ func (con *converter) callMatches(target *exprpb.Expr, args []*exprpb.Expr) erro return fmt.Errorf("%w: matches function requires both string and pattern arguments", ErrInvalidArguments) } - // Visit the string expression - if err := con.visit(stringExpr); err != nil { - return err - } - - // Visit the pattern expression and convert from RE2 to POSIX if it's a string literal + // Visit the pattern expression and convert if it's a string literal if constExpr := patternExpr.GetConstExpr(); constExpr != nil && constExpr.GetStringValue() != "" { - // Convert RE2 pattern to POSIX re2Pattern := constExpr.GetStringValue() // Reject patterns containing null bytes if strings.Contains(re2Pattern, "\x00") { return fmt.Errorf("%w: regex patterns cannot contain null bytes", ErrInvalidRegexPattern) } - // Convert RE2 to POSIX with security validation - posixPattern, caseInsensitive, err := convertRE2ToPOSIX(re2Pattern) + // Convert RE2 to dialect-native format with security validation + convertedPattern, caseInsensitive, err := con.dialect.ConvertRegex(re2Pattern) if err != nil { return fmt.Errorf("%w: %w", ErrInvalidRegexPattern, err) } @@ -919,32 +964,23 @@ func (con *converter) callMatches(target *exprpb.Expr, args []*exprpb.Expr) erro con.logger.LogAttrs(context.Background(), slog.LevelDebug, "regex pattern conversion", slog.String("original_pattern", re2Pattern), - slog.String("converted_pattern", posixPattern), + slog.String("converted_pattern", convertedPattern), slog.Bool("case_insensitive", caseInsensitive), + slog.String("dialect", string(con.dialect.Name())), ) - // Use ~* for case-insensitive matching, ~ for case-sensitive - if caseInsensitive { - con.str.WriteString(" ~* ") - } else { - con.str.WriteString(" ~ ") - } - - // Write the converted pattern as a string literal - escaped := strings.ReplaceAll(posixPattern, "'", "''") - con.str.WriteString("'") - con.str.WriteString(escaped) - con.str.WriteString("'") - } else { 
- // For non-literal patterns, we can't convert at compile time - // Just use the pattern as-is with case-sensitive operator - con.str.WriteString(" ~ ") - if err := con.visit(patternExpr); err != nil { - return err - } + // Use dialect-specific regex match writing + return con.dialect.WriteRegexMatch(&con.str, func() error { + return con.visit(stringExpr) + }, convertedPattern, caseInsensitive) } - - return nil + // For non-literal patterns, we can't convert at compile time + // Visit the string, then write regex operator, then visit the pattern + if err := con.visit(stringExpr); err != nil { + return err + } + con.str.WriteString(" ~ ") + return con.visit(patternExpr) } // callLowerASCII handles CEL lowerAscii() string function @@ -1469,53 +1505,36 @@ func (con *converter) callSplit(target *exprpb.Expr, args []*exprpb.Expr) error } // Generate SQL based on limit value + writeStr := func() error { + nested := isBinaryOrTernaryOperator(stringExpr) + return con.visitMaybeNested(stringExpr, nested) + } + writeDelim := func() error { + return con.visit(delimiterExpr) + } + switch { case limit == 0: // Empty array - con.str.WriteString("ARRAY[]::text[]") + con.dialect.WriteEmptyTypedArray(&con.str, "text") return nil case limit == 1: // Return original string as single-element array - con.str.WriteString("ARRAY[") - nested := isBinaryOrTernaryOperator(stringExpr) - if err := con.visitMaybeNested(stringExpr, nested); err != nil { + con.dialect.WriteArrayLiteralOpen(&con.str) + if err := writeStr(); err != nil { return err } - con.str.WriteString("]") + con.dialect.WriteArrayLiteralClose(&con.str) return nil case limit == -1: - // Unlimited splits (default PostgreSQL behavior) - con.str.WriteString("STRING_TO_ARRAY(") - nested := isBinaryOrTernaryOperator(stringExpr) - if err := con.visitMaybeNested(stringExpr, nested); err != nil { - return err - } - con.str.WriteString(", ") - if err := con.visit(delimiterExpr); err != nil { - return err - } - con.str.WriteString(")") 
- return nil + // Unlimited splits + return con.dialect.WriteSplit(&con.str, writeStr, writeDelim) case limit > 1: - // Arbitrary positive limit - use array slicing with REGEXP_SPLIT_TO_ARRAY - // REGEXP_SPLIT_TO_ARRAY is more powerful and allows us to limit splits - // Result: (REGEXP_SPLIT_TO_ARRAY(string, delimiter))[1:limit] - con.str.WriteString("(STRING_TO_ARRAY(") - nested := isBinaryOrTernaryOperator(stringExpr) - if err := con.visitMaybeNested(stringExpr, nested); err != nil { - return err - } - con.str.WriteString(", ") - if err := con.visit(delimiterExpr); err != nil { - return err - } - con.str.WriteString("))[1:") - con.str.WriteString(strconv.FormatInt(limit, 10)) - con.str.WriteString("]") - return nil + // Positive limit - use dialect-specific split with limit + return con.dialect.WriteSplitWithLimit(&con.str, writeStr, writeDelim, limit) default: // Negative limits other than -1 are not supported @@ -1559,26 +1578,18 @@ func (con *converter) callJoin(target *exprpb.Expr, args []*exprpb.Expr) error { } } - // Generate SQL - con.str.WriteString("ARRAY_TO_STRING(") - nested := isBinaryOrTernaryOperator(arrayExpr) - if err := con.visitMaybeNested(arrayExpr, nested); err != nil { - return err + // Generate SQL using dialect-specific join + writeArray := func() error { + nested := isBinaryOrTernaryOperator(arrayExpr) + return con.visitMaybeNested(arrayExpr, nested) } - con.str.WriteString(", ") - - // Use provided delimiter or empty string default + var writeDelim func() error if delimiterExpr != nil { - if err := con.visit(delimiterExpr); err != nil { - return err + writeDelim = func() error { + return con.visit(delimiterExpr) } - } else { - con.str.WriteString("''") } - - // Third parameter: null_string (use empty string to replace nulls) - con.str.WriteString(", '')") - return nil + return con.dialect.WriteJoin(&con.str, writeArray, writeDelim) } // callFormat handles CEL format() function @@ -1810,28 +1821,17 @@ func (con *converter) 
visitCallFunc(expr *exprpb.Expr) error { case isListType(argType): // Check if this is a JSON array field if con.isJSONArrayField(argExpr) { - // For JSON arrays, use jsonb_array_length wrapped in COALESCE - con.str.WriteString("COALESCE(jsonb_array_length(") - err := con.visit(argExpr) - if err != nil { - return err - } - con.str.WriteString("), 0)") - return nil + // For JSON arrays, use dialect-specific JSON array length + return con.dialect.WriteJSONArrayLength(&con.str, func() error { + return con.visit(argExpr) + }) } - // For PostgreSQL, we need to specify the array dimension - // Detect the dimension from schema if available, otherwise default to 1 + // For native arrays, use dialect-specific array length dimension := con.getArrayDimension(argExpr) - - // Wrap in COALESCE to handle NULL arrays (ARRAY_LENGTH returns NULL for NULL input) - con.str.WriteString("COALESCE(ARRAY_LENGTH(") - nested := isBinaryOrTernaryOperator(argExpr) - err := con.visitMaybeNested(argExpr, nested) - if err != nil { - return err - } - fmt.Fprintf(&con.str, ", %d), 0)", dimension) - return nil + return con.dialect.WriteArrayLength(&con.str, dimension, func() error { + nested := isBinaryOrTernaryOperator(argExpr) + return con.visitMaybeNested(argExpr, nested) + }) default: return newConversionErrorf(errMsgUnsupportedType, "size() argument type: %s", argType.String()) } @@ -1900,13 +1900,9 @@ func (con *converter) visitCallListIndex(expr *exprpb.Expr) error { return fmt.Errorf("%w: list index operator requires list and index arguments", ErrInvalidArguments) } l := args[0] - nested := isBinaryOrTernaryOperator(l) - if err := con.visitMaybeNested(l, nested); err != nil { - return err - } - con.str.WriteString("[") index := args[1] - // PostgreSQL arrays are 1-indexed, CEL is 0-indexed, so add 1 + + // Check for constant index if constExpr := index.GetConstExpr(); constExpr != nil { idx := constExpr.GetInt64Value() if idx == math.MaxInt64 { @@ -1915,15 +1911,19 @@ func (con *converter) 
visitCallListIndex(expr *exprpb.Expr) error { if idx < 0 { return fmt.Errorf("%w: negative array index %d is not supported", ErrInvalidArguments, idx) } - con.str.WriteString(strconv.FormatInt(idx+1, 10)) - } else { - if err := con.visit(index); err != nil { - return err - } - con.str.WriteString(" + 1") + return con.dialect.WriteListIndexConst(&con.str, func() error { + nested := isBinaryOrTernaryOperator(l) + return con.visitMaybeNested(l, nested) + }, idx) } - con.str.WriteString("]") - return nil + + // Dynamic index + return con.dialect.WriteListIndex(&con.str, func() error { + nested := isBinaryOrTernaryOperator(l) + return con.visitMaybeNested(l, nested) + }, func() error { + return con.visit(index) + }) } func (con *converter) visitCallUnary(expr *exprpb.Expr) error { @@ -1992,36 +1992,17 @@ func (con *converter) visitComprehension(expr *exprpb.Expr) error { // Comprehension visit functions - Phase 1 placeholder implementations func (con *converter) visitAllComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for ALL comprehension: all elements must satisfy the predicate - // Pattern: NOT EXISTS (SELECT 1 FROM UNNEST(array) AS item WHERE NOT predicate) - // For JSON arrays: NOT EXISTS (SELECT 1 FROM jsonb_array_elements(json_field) AS item WHERE NOT predicate) - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (ALL)") } iterRange := comprehension.GetIterRange() - isJSONArray := con.isJSONArrayField(iterRange) con.str.WriteString("NOT EXISTS (SELECT 1 FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in ALL comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := 
con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in ALL comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in ALL comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) @@ -2038,36 +2019,17 @@ func (con *converter) visitAllComprehension(expr *exprpb.Expr, info *Comprehensi } func (con *converter) visitExistsComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for EXISTS comprehension: at least one element satisfies the predicate - // Pattern: EXISTS (SELECT 1 FROM UNNEST(array) AS item WHERE predicate) - // For JSON arrays: EXISTS (SELECT 1 FROM jsonb_array_elements(json_field) AS item WHERE predicate) - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (EXISTS)") } iterRange := comprehension.GetIterRange() - isJSONArray := con.isJSONArrayField(iterRange) con.str.WriteString("EXISTS (SELECT 1 FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in EXISTS comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in EXISTS comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in EXISTS comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) @@ -2083,36 +2045,17 @@ func (con *converter) visitExistsComprehension(expr *exprpb.Expr, info *Comprehe } func (con *converter) 
visitExistsOneComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for EXISTS_ONE comprehension: exactly one element satisfies the predicate - // Pattern: (SELECT COUNT(*) FROM UNNEST(array) AS item WHERE predicate) = 1 - // For JSON arrays: (SELECT COUNT(*) FROM jsonb_array_elements(json_field) AS item WHERE predicate) = 1 - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (EXISTS_ONE)") } iterRange := comprehension.GetIterRange() - isJSONArray := con.isJSONArrayField(iterRange) con.str.WriteString("(SELECT COUNT(*) FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in EXISTS_ONE comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in EXISTS_ONE comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in EXISTS_ONE comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) @@ -2128,52 +2071,29 @@ func (con *converter) visitExistsOneComprehension(expr *exprpb.Expr, info *Compr } func (con *converter) visitMapComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for MAP comprehension: transform elements using the transform expression - // Pattern: ARRAY(SELECT transform FROM UNNEST(array) AS item [WHERE filter]) - // For JSON arrays: ARRAY(SELECT transform FROM jsonb_array_elements(json_field) AS item [WHERE filter]) - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return 
newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (MAP)") } iterRange := comprehension.GetIterRange() - isJSONArray := con.isJSONArrayField(iterRange) - - con.str.WriteString("ARRAY(SELECT ") - // Visit the transform expression + con.dialect.WriteArraySubqueryOpen(&con.str) if info.Transform != nil { if err := con.visit(info.Transform); err != nil { return wrapConversionError(err, "visiting transform in MAP comprehension") } } else { - // If no transform, just return the variable itself con.str.WriteString(info.IterVar) } - + con.dialect.WriteArraySubqueryExprClose(&con.str) con.str.WriteString(" FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in MAP comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in MAP comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in MAP comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) - // Add filter condition if present (for map with filter) if info.Filter != nil { con.str.WriteString(" WHERE ") if err := con.visit(info.Filter); err != nil { @@ -2186,38 +2106,20 @@ func (con *converter) visitMapComprehension(expr *exprpb.Expr, info *Comprehensi } func (con *converter) visitFilterComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for FILTER comprehension: return elements that satisfy the predicate - // Pattern: ARRAY(SELECT item FROM UNNEST(array) AS item WHERE predicate) - // For JSON arrays: ARRAY(SELECT item FROM jsonb_array_elements(json_field) AS item WHERE predicate) - comprehension := 
expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (FILTER)") } iterRange := comprehension.GetIterRange() - isJSONArray := con.isJSONArrayField(iterRange) - con.str.WriteString("ARRAY(SELECT ") + con.dialect.WriteArraySubqueryOpen(&con.str) con.str.WriteString(info.IterVar) + con.dialect.WriteArraySubqueryExprClose(&con.str) con.str.WriteString(" FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in FILTER comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in FILTER comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in FILTER comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) @@ -2233,37 +2135,27 @@ func (con *converter) visitFilterComprehension(expr *exprpb.Expr, info *Comprehe } func (con *converter) visitTransformListComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for TRANSFORM_LIST comprehension: similar to MAP but may have different semantics - // Pattern: ARRAY(SELECT transform FROM UNNEST(array) AS item [WHERE filter]) - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (TRANSFORM_LIST)") } - con.str.WriteString("ARRAY(SELECT ") - - // Visit the transform expression + con.dialect.WriteArraySubqueryOpen(&con.str) if info.Transform != nil { if err := con.visit(info.Transform); err != nil { return wrapConversionError(err, "visiting transform in 
TRANSFORM_LIST comprehension") } } else { - // If no transform, just return the variable itself con.str.WriteString(info.IterVar) } - - con.str.WriteString(" FROM UNNEST(") - - // Visit the iterable range (the array/list being comprehended over) - if err := con.visit(comprehension.GetIterRange()); err != nil { + con.dialect.WriteArraySubqueryExprClose(&con.str) + con.str.WriteString(" FROM ") + if err := con.writeComprehensionSource(comprehension.GetIterRange()); err != nil { return wrapConversionError(err, "visiting iter range in TRANSFORM_LIST comprehension") } - - con.str.WriteString(") AS ") + con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) - // Add filter condition if present if info.Filter != nil { con.str.WriteString(" WHERE ") if err := con.visit(info.Filter); err != nil { @@ -2305,7 +2197,7 @@ func (con *converter) visitConst(expr *exprpb.Expr) error { case *exprpb.Constant_Int64Value: if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, c.GetInt64Value()) } else { i := strconv.FormatInt(c.GetInt64Value(), 10) @@ -2314,7 +2206,7 @@ func (con *converter) visitConst(expr *exprpb.Expr) error { case *exprpb.Constant_Uint64Value: if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, c.GetUint64Value()) } else { ui := strconv.FormatUint(c.GetUint64Value(), 10) @@ -2323,7 +2215,7 @@ func (con *converter) visitConst(expr *exprpb.Expr) error { case *exprpb.Constant_DoubleValue: if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, c.GetDoubleValue()) } else { d := strconv.FormatFloat(c.GetDoubleValue(), 'g', -1, 64) @@ -2338,31 +2230,26 @@ func (con 
*converter) visitConst(expr *exprpb.Expr) error { if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, str) } else { - // Use single quotes for PostgreSQL string literals - // Escape single quotes by doubling them - escaped := strings.ReplaceAll(str, "'", "''") - con.str.WriteString("'") - con.str.WriteString(escaped) - con.str.WriteString("'") + con.dialect.WriteStringLiteral(&con.str, str) } case *exprpb.Constant_BytesValue: b := c.GetBytesValue() if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, b) } else { // Validate byte array length to prevent resource exhaustion (CWE-400) if len(b) > maxByteArrayLength { return fmt.Errorf("%w: %d bytes exceeds limit of %d bytes", ErrInvalidByteArrayLength, len(b), maxByteArrayLength) } - con.str.WriteString("'\\x") - con.str.WriteString(hex.EncodeToString(b)) - con.str.WriteString("'") + if err := con.dialect.WriteBytesLiteral(&con.str, b); err != nil { + return err + } } default: return newConversionErrorf(errMsgUnsupportedExpression, "constant type: %T", c.ConstantKind) @@ -2374,7 +2261,7 @@ func (con *converter) visitIdent(expr *exprpb.Expr) error { identName := expr.GetIdentExpr().GetName() // Validate identifier name for security (prevent SQL injection) - if err := validateFieldName(identName); err != nil { + if err := con.dialect.ValidateFieldName(identName); err != nil { return fmt.Errorf("%w: %w", ErrInvalidFieldName, err) } @@ -2382,7 +2269,8 @@ func (con *converter) visitIdent(expr *exprpb.Expr) error { if con.needsNumericCasting(identName) { con.str.WriteString("(") con.str.WriteString(identName) - con.str.WriteString(")::numeric") + con.str.WriteString(")") + con.dialect.WriteCastToNumeric(&con.str) } else { con.str.WriteString(identName) } @@ 
-2392,7 +2280,7 @@ func (con *converter) visitIdent(expr *exprpb.Expr) error { func (con *converter) visitList(expr *exprpb.Expr) error { l := expr.GetListExpr() elems := l.GetElements() - con.str.WriteString("ARRAY[") + con.dialect.WriteArrayLiteralOpen(&con.str) for i, elem := range elems { err := con.visit(elem) if err != nil { @@ -2402,7 +2290,7 @@ func (con *converter) visitList(expr *exprpb.Expr) error { con.str.WriteString(", ") } } - con.str.WriteString("]") + con.dialect.WriteArrayLiteralClose(&con.str) return nil } @@ -2411,7 +2299,7 @@ func (con *converter) visitSelect(expr *exprpb.Expr) error { // Validate field name for security (prevent SQL injection) fieldName := sel.GetField() - if err := validateFieldName(fieldName); err != nil { + if err := con.dialect.ValidateFieldName(fieldName); err != nil { return fmt.Errorf("%w: %w", ErrInvalidFieldName, err) } @@ -2433,34 +2321,35 @@ func (con *converter) visitSelect(expr *exprpb.Expr) error { nested := !sel.GetTestOnly() && isBinaryOrTernaryOperator(sel.GetOperand()) - if useJSONObjectAccess && con.isNumericJSONField(fieldName) { - // For numeric JSON fields, wrap in parentheses for casting - con.str.WriteString("(") - } - - err := con.visitMaybeNested(sel.GetOperand(), nested) - if err != nil { - return err + writeBase := func() error { + return con.visitMaybeNested(sel.GetOperand(), nested) } switch { case useJSONPath: - // Use ->> for text extraction - con.str.WriteString("->>") - con.str.WriteString("'") - con.str.WriteString(escapeJSONFieldName(fieldName)) - con.str.WriteString("'") + // Use dialect-specific JSON field access (text extraction) + if err := con.dialect.WriteJSONFieldAccess(&con.str, writeBase, fieldName, true); err != nil { + return err + } case useJSONObjectAccess: - // Use -> for JSON object field access in comprehensions - con.str.WriteString("->>'") - con.str.WriteString(escapeJSONFieldName(fieldName)) - con.str.WriteString("'") - if con.isNumericJSONField(fieldName) { + // Use 
dialect-specific JSON object field access in comprehensions + isNumeric := con.isNumericJSONField(fieldName) + if isNumeric { + con.str.WriteString("(") + } + if err := con.dialect.WriteJSONFieldAccess(&con.str, writeBase, fieldName, true); err != nil { + return err + } + if isNumeric { // Close parentheses and add numeric cast - con.str.WriteString(")::numeric") + con.str.WriteString(")") + con.dialect.WriteCastToNumeric(&con.str) } default: // Regular field selection + if err := writeBase(); err != nil { + return err + } con.str.WriteString(".") con.str.WriteString(fieldName) } @@ -2476,25 +2365,10 @@ func (con *converter) visitHasFunction(expr *exprpb.Expr) error { // Check if this is a direct JSON field access (e.g., table.json_column.key) if con.isDirectJSONFieldAccess(operand, field) { - // For direct JSON field access, use the appropriate existence operator - err := con.visitMaybeNested(operand, isBinaryOrTernaryOperator(operand)) - if err != nil { - return err - } - - // Check if this is a JSONB field - if con.isJSONBField(operand) { - // Use JSONB's ? operator for existence check - con.str.WriteString(" ? 
'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - } else { - // For JSON fields, check if the field is not null - con.str.WriteString("->'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("' IS NOT NULL") - } - return nil + isJSONB := con.isJSONBField(operand) + return con.dialect.WriteJSONExistence(&con.str, isJSONB, field, func() error { + return con.visitMaybeNested(operand, isBinaryOrTernaryOperator(operand)) + }) } // Check if this is a nested JSON path (e.g., table.json_column.key.subkey) @@ -2532,27 +2406,12 @@ func (con *converter) isDirectJSONFieldAccess(operand *exprpb.Expr, _ string) bo // visitNestedJSONHas handles has() for deeply nested JSON paths func (con *converter) visitNestedJSONHas(expr *exprpb.Expr) error { - // For nested JSON paths, we use jsonb_extract_path_text and check for NOT NULL - // This is more reliable than trying to use ? operator on nested paths - con.str.WriteString("jsonb_extract_path_text(") - // Get the root JSON column and remaining path segments rootColumn, pathSegments := con.getJSONRootAndPath(expr) - // Visit the root column without adding JSON access operators - if err := con.visitJSONColumnReference(rootColumn); err != nil { - return err - } - - // Add path segments as arguments - for _, segment := range pathSegments { - con.str.WriteString(", '") - con.str.WriteString(escapeJSONFieldName(segment)) - con.str.WriteString("'") - } - - con.str.WriteString(") IS NOT NULL") - return nil + return con.dialect.WriteJSONExtractPath(&con.str, pathSegments, func() error { + return con.visitJSONColumnReference(rootColumn) + }) } // visitJSONColumnReference visits a JSON column reference without adding JSON access operators @@ -2676,7 +2535,7 @@ func (con *converter) visitStructMsg(expr *exprpb.Expr) error { func (con *converter) visitStructMap(expr *exprpb.Expr) error { m := expr.GetStructExpr() entries := m.GetEntries() - con.str.WriteString("ROW(") + 
con.dialect.WriteStructOpen(&con.str) for i, entry := range entries { v := entry.GetValue() if err := con.visit(v); err != nil { @@ -2686,10 +2545,27 @@ func (con *converter) visitStructMap(expr *exprpb.Expr) error { con.str.WriteString(", ") } } - con.str.WriteString(")") + con.dialect.WriteStructClose(&con.str) return nil } +// writeComprehensionSource writes the source expression for a comprehension (UNNEST or JSON function). +func (con *converter) writeComprehensionSource(iterRange *exprpb.Expr) error { + isJSONArray := con.isJSONArrayField(iterRange) + if isJSONArray { + jsonFunc := con.getJSONArrayFunction(iterRange) + isJSONB := con.isJSONBField(iterRange) + // Determine if we need text extraction or object extraction + asText := strings.HasSuffix(jsonFunc, "_text") + return con.dialect.WriteJSONArrayElements(&con.str, isJSONB, asText, func() error { + return con.visit(iterRange) + }) + } + return con.dialect.WriteUnnest(&con.str, func() error { + return con.visit(iterRange) + }) +} + func (con *converter) visitMaybeNested(expr *exprpb.Expr, nested bool) error { if nested { con.str.WriteString("(") @@ -2767,181 +2643,3 @@ func isBinaryOrTernaryOperator(expr *exprpb.Expr) bool { _, isBinaryOp := operators.FindReverseBinaryOperator(expr.GetCallExpr().GetFunction()) return isBinaryOp || isSamePrecedence(operators.Conditional, expr) } - -// convertRE2ToPOSIX converts an RE2 regex pattern to POSIX ERE format for PostgreSQL. -// It performs security validation to prevent ReDoS attacks (CWE-1333). -// Returns: (posixPattern, caseInsensitive, error) -// Note: This is a basic conversion for common patterns. Full RE2 to POSIX conversion is complex. -func convertRE2ToPOSIX(re2Pattern string) (string, bool, error) { - // 1. 
Check pattern length to prevent processing extremely long patterns - if len(re2Pattern) > maxRegexPatternLength { - return "", false, fmt.Errorf("%w: pattern length %d exceeds limit of %d characters", ErrInvalidRegexPattern, len(re2Pattern), maxRegexPatternLength) - } - - // 2. Extract case-insensitive flag if present - caseInsensitive := false - if strings.HasPrefix(re2Pattern, "(?i)") { - caseInsensitive = true - re2Pattern = strings.TrimPrefix(re2Pattern, "(?i)") - } - - // 3. Detect unsupported RE2 features and return errors - // Lookahead assertions - if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { - return "", false, fmt.Errorf("%w: lookahead assertions (?=...), (?!...) are not supported in PostgreSQL POSIX regex", ErrInvalidRegexPattern) - } - // Lookbehind assertions - if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?<!") { - return "", false, fmt.Errorf("%w: lookbehind assertions (?<=...), (?<!...) are not supported in PostgreSQL POSIX regex", ErrInvalidRegexPattern) - } - // Other inline flags (after we've already handled (?i)) - if strings.Contains(re2Pattern, "(?m") || strings.Contains(re2Pattern, "(?s") || strings.Contains(re2Pattern, "(?-") { - return "", false, fmt.Errorf("%w: inline flags other than (?i) are not supported in PostgreSQL POSIX regex", ErrInvalidRegexPattern) - } - - // 4. Detect catastrophic nested quantifiers that cause exponential backtracking - // Patterns like (a+)+, (a*)*, (x+x+)+, ((a)+b)+, etc. are extremely dangerous - - // Check for doubled quantifiers - if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { - return "", false, fmt.Errorf("%w: regex contains catastrophic nested quantifiers that could cause ReDoS", ErrInvalidRegexPattern) - } - - // Check for groups that contain quantifiers and are themselves quantified - // This catches patterns like (a+)+, ((a)+b)+, (a*b*)*, etc. 
- // We need to check if any opening paren eventually leads to a closing paren followed by a quantifier, - // and if there are quantifiers between those parens. - depth := 0 - groupHasQuantifier := make([]bool, 0) - - for i := 0; i < len(re2Pattern); i++ { - char := re2Pattern[i] - - // Skip escaped characters - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - - switch char { - case '(': - depth++ - groupHasQuantifier = append(groupHasQuantifier, false) - case ')': - if depth > 0 { - depth-- - // Check if the closing paren is followed by a quantifier - if i+1 < len(re2Pattern) { - nextChar := re2Pattern[i+1] - if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { - // This group is quantified. Check if it contains quantifiers - if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { - return "", false, fmt.Errorf("%w: regex contains catastrophic nested quantifiers that could cause ReDoS", ErrInvalidRegexPattern) - } - } - } - if len(groupHasQuantifier) > 0 { - // Pop the last group - if len(groupHasQuantifier) > 1 { - // If inner group had quantifier, mark outer group as having quantifier too - if groupHasQuantifier[len(groupHasQuantifier)-1] { - groupHasQuantifier[len(groupHasQuantifier)-2] = true - } - } - groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] - } - } - case '*', '+', '?': - // Mark that current group contains a quantifier - if len(groupHasQuantifier) > 0 { - groupHasQuantifier[len(groupHasQuantifier)-1] = true - } - case '{': - // Brace quantifier {n,m} - if len(groupHasQuantifier) > 0 { - groupHasQuantifier[len(groupHasQuantifier)-1] = true - } - } - } - - // 5. 
Count and limit capture groups to prevent memory exhaustion - groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, `\(`) - if groupCount > maxRegexGroups { - return "", false, fmt.Errorf("%w: regex contains %d capture groups, exceeds limit of %d", ErrInvalidRegexPattern, groupCount, maxRegexGroups) - } - - // 6. Detect exponential alternation patterns like (a|a)*b or (a|ab)* - alternationPattern := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) - if alternationPattern.MatchString(re2Pattern) { - // Check if alternation has overlapping branches (more dangerous) - // This is a simple heuristic - full analysis would be more complex - return "", false, fmt.Errorf("%w: regex contains quantified alternation that could cause ReDoS", ErrInvalidRegexPattern) - } - - // 7. Check nesting depth to prevent deeply nested patterns - maxDepth := 0 - currentDepth := 0 - for _, char := range re2Pattern { - if char == '(' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { - currentDepth++ - if currentDepth > maxDepth { - maxDepth = currentDepth - } - } else if char == ')' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { - currentDepth-- - } - } - if maxDepth > maxRegexNestingDepth { - return "", false, fmt.Errorf("%w: nesting depth %d exceeds limit of %d", ErrInvalidRegexPattern, maxDepth, maxRegexNestingDepth) - } - - // Passed all security checks - proceed with conversion - posixPattern := re2Pattern - - // Basic conversions for common differences between RE2 and POSIX: - - // 1. Word boundaries: \b -> [[:<:]] and [[:<:]] (PostgreSQL extension) - // Note: PostgreSQL supports \y for word boundaries in some contexts - posixPattern = strings.ReplaceAll(posixPattern, `\b`, `\y`) - - // 2. Non-word boundaries: \B -> [^[:alnum:]_] (approximate) - // This is a simplification; exact conversion is complex - posixPattern = strings.ReplaceAll(posixPattern, `\B`, `[^[:alnum:]_]`) - - // 3. 
Digit shortcuts: \d -> [[:digit:]] or [0-9] - posixPattern = strings.ReplaceAll(posixPattern, `\d`, `[[:digit:]]`) - - // 4. Non-digit shortcuts: \D -> [^[:digit:]] or [^0-9] - posixPattern = strings.ReplaceAll(posixPattern, `\D`, `[^[:digit:]]`) - - // 5. Word character shortcuts: \w -> [[:alnum:]_] - posixPattern = strings.ReplaceAll(posixPattern, `\w`, `[[:alnum:]_]`) - - // 6. Non-word character shortcuts: \W -> [^[:alnum:]_] - posixPattern = strings.ReplaceAll(posixPattern, `\W`, `[^[:alnum:]_]`) - - // 7. Whitespace shortcuts: \s -> [[:space:]] - posixPattern = strings.ReplaceAll(posixPattern, `\s`, `[[:space:]]`) - - // 8. Non-whitespace shortcuts: \S -> [^[:space:]] - posixPattern = strings.ReplaceAll(posixPattern, `\S`, `[^[:space:]]`) - - // 9. Non-capturing groups: (?:...) -> (...) - // POSIX ERE doesn't have non-capturing groups, so convert to regular groups - posixPattern = strings.ReplaceAll(posixPattern, `(?:`, `(`) - - // Note: Unsupported RE2 features that are now validated and return errors: - // - Lookahead/lookbehind assertions (?=...), (?!...), (?<=...), (?<!...) - ERROR - // - Case-insensitive flag (?i) - CONVERTED (returned as separate boolean) - // - Other inline flags (?m), (?s) - ERROR - // - // Converted features: - // - Non-capturing groups (?:...) - Converted to regular groups (...) - // - Character class shortcuts (\d, \w, \s, etc.) - Converted to POSIX equivalents - - return posixPattern, caseInsensitive, nil -} diff --git a/dialect/bigquery/dialect.go b/dialect/bigquery/dialect.go new file mode 100644 index 0000000..5314237 --- /dev/null +++ b/dialect/bigquery/dialect.go @@ -0,0 +1,504 @@ +// Package bigquery implements the BigQuery SQL dialect for cel2sql. +package bigquery + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// Dialect implements dialect.Dialect for BigQuery. +type Dialect struct{} + +// New creates a new BigQuery dialect. 
+func New() *Dialect { + return &Dialect{} +} + +func init() { + dialect.Register(dialect.BigQuery, func() dialect.Dialect { return New() }) +} + +// Ensure Dialect implements dialect.Dialect at compile time. +var _ dialect.Dialect = (*Dialect)(nil) + +// Name returns the dialect name. +func (d *Dialect) Name() dialect.Name { return dialect.BigQuery } + +// --- Literals --- + +// WriteStringLiteral writes a BigQuery string literal with ” escaping. +func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) { + escaped := strings.ReplaceAll(value, "'", "\\'") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") +} + +// WriteBytesLiteral writes a BigQuery octal-encoded byte literal (b"..."). +func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error { + w.WriteString("b\"") + for _, b := range value { + fmt.Fprintf(w, "\\%03o", b) + } + w.WriteString("\"") + return nil +} + +// WriteParamPlaceholder writes a BigQuery named parameter (@p1, @p2, ...). +func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, paramIndex int) { + fmt.Fprintf(w, "@p%d", paramIndex) +} + +// --- Operators --- + +// WriteStringConcat writes BigQuery string concatenation using ||. +func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error { + if err := writeLHS(); err != nil { + return err + } + w.WriteString(" || ") + return writeRHS() +} + +// WriteRegexMatch writes a BigQuery regex match using REGEXP_CONTAINS. +func (d *Dialect) WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, _ bool) error { + w.WriteString("REGEXP_CONTAINS(") + if err := writeTarget(); err != nil { + return err + } + w.WriteString(", '") + escaped := strings.ReplaceAll(pattern, "'", "\\'") + w.WriteString(escaped) + w.WriteString("')") + return nil +} + +// WriteLikeEscape is a no-op for BigQuery. 
+// BigQuery uses backslash as the default escape character in LIKE patterns +// and does not support the ESCAPE keyword. +func (d *Dialect) WriteLikeEscape(_ *strings.Builder) { +} + +// WriteArrayMembership writes a BigQuery array membership test using IN UNNEST(). +func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + if err := writeElem(); err != nil { + return err + } + w.WriteString(" IN UNNEST(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Type Casting --- + +// WriteCastToNumeric writes a BigQuery numeric cast (CAST(... AS FLOAT64)). +func (d *Dialect) WriteCastToNumeric(w *strings.Builder) { + // BigQuery doesn't have a ::type cast syntax; this is used after expressions. + // For BigQuery, the converter should use CAST(expr AS FLOAT64) instead. + w.WriteString("::FLOAT64") +} + +// WriteTypeName writes a BigQuery type name for CAST expressions. +func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) { + switch celTypeName { + case "bool": + w.WriteString("BOOL") + case "bytes": + w.WriteString("BYTES") + case "double": + w.WriteString("FLOAT64") + case "int": + w.WriteString("INT64") + case "string": + w.WriteString("STRING") + case "uint": + w.WriteString("INT64") + default: + w.WriteString(strings.ToUpper(celTypeName)) + } +} + +// WriteEpochExtract writes UNIX_SECONDS(expr) for BigQuery. +func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error { + w.WriteString("UNIX_SECONDS(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteTimestampCast writes a BigQuery CAST to TIMESTAMP. 
+func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error { + w.WriteString("CAST(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(" AS TIMESTAMP)") + return nil +} + +// --- Arrays --- + +// WriteArrayLiteralOpen writes the BigQuery array literal opening ([). +func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) { + w.WriteString("[") +} + +// WriteArrayLiteralClose writes the BigQuery array literal closing (]). +func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) { + w.WriteString("]") +} + +// WriteArrayLength writes ARRAY_LENGTH(expr) for BigQuery. +func (d *Dialect) WriteArrayLength(w *strings.Builder, _ int, writeExpr func() error) error { + w.WriteString("ARRAY_LENGTH(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteListIndex writes BigQuery 0-indexed array access using OFFSET. +func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error { + if err := writeArray(); err != nil { + return err + } + w.WriteString("[OFFSET(") + if err := writeIndex(); err != nil { + return err + } + w.WriteString(")]") + return nil +} + +// WriteListIndexConst writes BigQuery constant array index access using OFFSET. +func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error { + if err := writeArray(); err != nil { + return err + } + fmt.Fprintf(w, "[OFFSET(%d)]", index) + return nil +} + +// WriteEmptyTypedArray writes an empty BigQuery typed array. +func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, typeName string) { + w.WriteString("ARRAY<") + w.WriteString(bigqueryTypeName(typeName)) + w.WriteString(">[]") +} + +// --- JSON --- + +// WriteJSONFieldAccess writes BigQuery JSON field access using JSON_VALUE. 
+func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error { + escaped := escapeJSONFieldName(fieldName) + if isFinal { + w.WriteString("JSON_VALUE(") + } else { + w.WriteString("JSON_QUERY(") + } + if err := writeBase(); err != nil { + return err + } + w.WriteString(", '$.") + w.WriteString(escaped) + w.WriteString("')") + return nil +} + +// WriteJSONExistence writes a BigQuery JSON key existence check. +func (d *Dialect) WriteJSONExistence(w *strings.Builder, _ bool, fieldName string, writeBase func() error) error { + escaped := escapeJSONFieldName(fieldName) + w.WriteString("JSON_VALUE(") + if err := writeBase(); err != nil { + return err + } + w.WriteString(", '$.") + w.WriteString(escaped) + w.WriteString("') IS NOT NULL") + return nil +} + +// WriteJSONArrayElements writes BigQuery JSON array expansion. +func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, _ bool, _ bool, writeExpr func() error) error { + w.WriteString("UNNEST(JSON_QUERY_ARRAY(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// WriteJSONArrayLength writes ARRAY_LENGTH(JSON_QUERY_ARRAY(expr)) for BigQuery. +func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error { + w.WriteString("ARRAY_LENGTH(JSON_QUERY_ARRAY(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// WriteJSONExtractPath writes BigQuery JSON path existence using JSON_VALUE. +func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error { + w.WriteString("JSON_VALUE(") + if err := writeRoot(); err != nil { + return err + } + w.WriteString(", '$") + for _, segment := range pathSegments { + w.WriteString(".") + w.WriteString(escapeJSONFieldName(segment)) + } + w.WriteString("') IS NOT NULL") + return nil +} + +// WriteJSONArrayMembership writes BigQuery JSON array membership. 
+func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { + w.WriteString("UNNEST(JSON_VALUE_ARRAY(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// WriteNestedJSONArrayMembership writes BigQuery nested JSON array membership. +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { + w.WriteString("UNNEST(JSON_VALUE_ARRAY(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// --- Timestamps --- + +// WriteDuration writes a BigQuery INTERVAL literal. +func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) { + fmt.Fprintf(w, "INTERVAL %d %s", value, unit) +} + +// WriteInterval writes a BigQuery INTERVAL expression. +func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error { + w.WriteString("INTERVAL ") + if err := writeValue(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(unit) + return nil +} + +// WriteExtract writes a BigQuery EXTRACT expression. +// BigQuery uses DAYOFWEEK (1=Sunday) instead of DOW (0=Sunday). +func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error { + isDOW := part == "DOW" + bqPart := part + if isDOW { + bqPart = "DAYOFWEEK" + w.WriteString("(") + } + w.WriteString("EXTRACT(") + w.WriteString(bqPart) + w.WriteString(" FROM ") + if err := writeExpr(); err != nil { + return err + } + if writeTZ != nil { + w.WriteString(" AT TIME ZONE ") + if err := writeTZ(); err != nil { + return err + } + } + w.WriteString(")") + if isDOW { + // BigQuery DAYOFWEEK: 1=Sunday, 2=Monday, ..., 7=Saturday + // CEL getDayOfWeek: 0=Sunday, 1=Monday, ..., 6=Saturday + w.WriteString(" - 1)") + } + return nil +} + +// WriteTimestampArithmetic writes BigQuery timestamp arithmetic using functions. 
+// BigQuery uses TIMESTAMP_ADD/TIMESTAMP_SUB instead of + / - operators. +func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error { + if op == "+" { + w.WriteString("TIMESTAMP_ADD(") + } else { + w.WriteString("TIMESTAMP_SUB(") + } + if err := writeTS(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDur(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- String Functions --- + +// WriteContains writes STRPOS(haystack, needle) > 0 for BigQuery. +func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error { + w.WriteString("STRPOS(") + if err := writeHaystack(); err != nil { + return err + } + w.WriteString(", ") + if err := writeNeedle(); err != nil { + return err + } + w.WriteString(") > 0") + return nil +} + +// WriteSplit writes BigQuery string split using SPLIT. +func (d *Dialect) WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error { + w.WriteString("SPLIT(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteSplitWithLimit writes BigQuery SPLIT with array slice. +func (d *Dialect) WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, limit int64) error { + w.WriteString("ARRAY(SELECT x FROM UNNEST(SPLIT(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + fmt.Fprintf(w, ")) AS x WITH OFFSET WHERE OFFSET < %d)", limit) + return nil +} + +// WriteJoin writes BigQuery array join using ARRAY_TO_STRING. 
+func (d *Dialect) WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error { + w.WriteString("ARRAY_TO_STRING(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Comprehensions --- + +// WriteUnnest writes BigQuery UNNEST for array unnesting. +func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error { + w.WriteString("UNNEST(") + if err := writeSource(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteArraySubqueryOpen writes ARRAY(SELECT for BigQuery. +func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) { + w.WriteString("ARRAY(SELECT ") +} + +// WriteArraySubqueryExprClose is a no-op for BigQuery. +func (d *Dialect) WriteArraySubqueryExprClose(_ *strings.Builder) { +} + +// --- Struct --- + +// WriteStructOpen writes the BigQuery struct literal opening. +func (d *Dialect) WriteStructOpen(w *strings.Builder) { + w.WriteString("STRUCT(") +} + +// WriteStructClose writes the BigQuery struct literal closing. +func (d *Dialect) WriteStructClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Validation --- + +// MaxIdentifierLength returns 300 for BigQuery. +func (d *Dialect) MaxIdentifierLength() int { + return 300 +} + +// ValidateFieldName validates a field name against BigQuery naming rules. +func (d *Dialect) ValidateFieldName(name string) error { + return validateFieldName(name) +} + +// ReservedKeywords returns the set of reserved SQL keywords for BigQuery. +func (d *Dialect) ReservedKeywords() map[string]bool { + return reservedSQLKeywords +} + +// --- Regex --- + +// ConvertRegex converts an RE2 regex pattern for BigQuery. +// BigQuery uses RE2 natively, so minimal conversion is needed. 
+func (d *Dialect) ConvertRegex(re2Pattern string) (string, bool, error) { + return convertRE2ToBigQuery(re2Pattern) +} + +// SupportsRegex returns true as BigQuery supports RE2 regex natively. +func (d *Dialect) SupportsRegex() bool { return true } + +// --- Capabilities --- + +// SupportsNativeArrays returns true as BigQuery has native array types. +func (d *Dialect) SupportsNativeArrays() bool { return true } + +// SupportsJSONB returns false as BigQuery has a single JSON type. +func (d *Dialect) SupportsJSONB() bool { return false } + +// SupportsIndexAnalysis returns true as BigQuery index analysis is supported. +func (d *Dialect) SupportsIndexAnalysis() bool { return true } + +// --- Internal helpers --- + +// escapeJSONFieldName escapes special characters in JSON field names for BigQuery. +func escapeJSONFieldName(fieldName string) string { + return strings.ReplaceAll(fieldName, "'", "\\'") +} + +// bigqueryTypeName converts a CEL/common type name to a BigQuery type name. +func bigqueryTypeName(typeName string) string { + switch strings.ToLower(typeName) { + case "text", "string", "varchar": + return "STRING" + case "int", "integer", "bigint", "int64": + return "INT64" + case "double", "float", "real", "float64": + return "FLOAT64" + case "boolean", "bool": + return "BOOL" + case "bytes", "bytea", "blob": + return "BYTES" + default: + return strings.ToUpper(typeName) + } +} diff --git a/dialect/bigquery/index_advisor.go b/dialect/bigquery/index_advisor.go new file mode 100644 index 0000000..bbc9bce --- /dev/null +++ b/dialect/bigquery/index_advisor.go @@ -0,0 +1,86 @@ +package bigquery + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// BigQuery index type constants. +const ( + IndexTypeClustering = "CLUSTERING" + IndexTypeSearchIndex = "SEARCH_INDEX" +) + +// RecommendIndex generates a BigQuery-specific index recommendation for the given pattern. +// BigQuery uses clustering keys and search indexes. 
Returns nil for unsupported patterns. +func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation { + table := pattern.TableHint + if table == "" { + table = "table_name" + } + col := pattern.Column + safeName := sanitizeIndexName(col) + + switch pattern.Pattern { + case dialect.PatternComparison: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeClustering, + Expression: fmt.Sprintf("ALTER TABLE %s SET OPTIONS (clustering_columns=['%s']);", + table, col), + Reason: fmt.Sprintf("Comparison operations on '%s' benefit from clustering for efficient partition pruning and range scans", col), + } + + case dialect.PatternJSONAccess: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeSearchIndex, + Expression: fmt.Sprintf("CREATE SEARCH INDEX idx_%s ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSON field access on '%s' benefits from a search index for efficient nested field lookups", col), + } + + case dialect.PatternRegexMatch: + // BigQuery does not have specialized regex indexes + return nil + + case dialect.PatternArrayMembership, dialect.PatternArrayComprehension: + // BigQuery arrays do not benefit from standalone indexes + return nil + + case dialect.PatternJSONArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeSearchIndex, + Expression: fmt.Sprintf("CREATE SEARCH INDEX idx_%s ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSON array operations on '%s' may benefit from a search index", col), + } + } + + return nil +} + +// SupportedPatterns returns the pattern types supported by BigQuery. +func (d *Dialect) SupportedPatterns() []dialect.PatternType { + return []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternJSONArrayComprehension, + } +} + +// sanitizeIndexName creates a safe index name from a column name. 
+func sanitizeIndexName(column string) string { + sanitized := strings.ReplaceAll(column, ".", "_") + sanitized = strings.ReplaceAll(sanitized, " ", "_") + sanitized = strings.ReplaceAll(sanitized, "-", "_") + + if len(sanitized) > 50 { + sanitized = sanitized[:50] + } + + return sanitized +} diff --git a/dialect/bigquery/regex.go b/dialect/bigquery/regex.go new file mode 100644 index 0000000..2fcfbf7 --- /dev/null +++ b/dialect/bigquery/regex.go @@ -0,0 +1,137 @@ +package bigquery + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). +const ( + maxRegexPatternLength = 500 + maxRegexGroups = 20 + maxRegexNestingDepth = 10 +) + +// convertRE2ToBigQuery converts an RE2 regex pattern to BigQuery-compatible format. +// BigQuery uses RE2 natively, so most patterns pass through unchanged. +// Returns the converted pattern, whether it's case-insensitive, and any error. +func convertRE2ToBigQuery(re2Pattern string) (string, bool, error) { + // 1. Pattern length validation + if len(re2Pattern) > maxRegexPatternLength { + return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) + } + + // 2. Validate pattern compiles + if _, err := regexp.Compile(re2Pattern); err != nil { + return "", false, fmt.Errorf("invalid regex pattern: %w", err) + } + + // 3. Detect unsupported features (lookahead/lookbehind not in RE2 anyway) + if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { + return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported") + } + if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?<!") { + return "", false, errors.New("lookbehind assertions (?<=...), (?<!...) are not supported") + } + + // 4. Detect catastrophic nested quantifiers + if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + + // 5. 
Check for nested quantifiers in groups + depth := 0 + groupHasQuantifier := make([]bool, 0) + for i := 0; i < len(re2Pattern); i++ { + char := re2Pattern[i] + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch char { + case '(': + depth++ + groupHasQuantifier = append(groupHasQuantifier, false) + case ')': + if depth > 0 { + depth-- + if i+1 < len(re2Pattern) { + nextChar := re2Pattern[i+1] + if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { + if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + } + } + if len(groupHasQuantifier) > 0 { + groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] + } + } + case '*', '+', '?', '{': + for j := range groupHasQuantifier { + groupHasQuantifier[j] = true + } + } + } + + // 6. Check group count limit + groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(") + if groupCount > maxRegexGroups { + return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) + } + + // 7. Check for quantified alternation + quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) + if quantifiedAlternation.MatchString(re2Pattern) { + return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") + } + + // 8. 
Check nesting depth + maxDepthVal := 0 + currentDepth := 0 + for i := 0; i < len(re2Pattern); i++ { + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch re2Pattern[i] { + case '(': + currentDepth++ + if currentDepth > maxDepthVal { + maxDepthVal = currentDepth + } + case ')': + if currentDepth > 0 { + currentDepth-- + } + } + } + if maxDepthVal > maxRegexNestingDepth { + return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) + } + + // Process pattern: BigQuery uses RE2 natively, so minimal conversion needed + caseInsensitive := false + pattern := re2Pattern + + // Handle (?i) flag - BigQuery REGEXP_CONTAINS embeds the flag in the pattern + if strings.HasPrefix(pattern, "(?i)") { + caseInsensitive = true + pattern = pattern[4:] + } + + // Handle inline flags other than (?i) at start + if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { + return "", false, errors.New("inline flags other than (?i) are not supported in BigQuery regex") + } + + // Convert non-capturing groups (?:...) to regular groups (...) + pattern = strings.ReplaceAll(pattern, "(?:", "(") + + // BigQuery RE2 supports \d, \w, \s, \b natively - no conversion needed + + return pattern, caseInsensitive, nil +} diff --git a/dialect/bigquery/validation.go b/dialect/bigquery/validation.go new file mode 100644 index 0000000..0ae982d --- /dev/null +++ b/dialect/bigquery/validation.go @@ -0,0 +1,56 @@ +package bigquery + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +var ( + // fieldNameRegexp validates BigQuery identifier format. + fieldNameRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + + // reservedSQLKeywords contains BigQuery reserved keywords. 
+ reservedSQLKeywords = map[string]bool{ + "all": true, "and": true, "any": true, "array": true, "as": true, + "asc": true, "assert_rows_modified": true, "at": true, "between": true, + "by": true, "case": true, "cast": true, "collate": true, "contains": true, + "create": true, "cross": true, "cube": true, "current": true, + "default": true, "define": true, "desc": true, "distinct": true, + "else": true, "end": true, "enum": true, "escape": true, "except": true, + "exclude": true, "exists": true, "extract": true, "false": true, + "fetch": true, "following": true, "for": true, "from": true, "full": true, + "group": true, "grouping": true, "groups": true, "hash": true, + "having": true, "if": true, "ignore": true, "in": true, "inner": true, + "intersect": true, "interval": true, "into": true, "is": true, + "join": true, "lateral": true, "left": true, "like": true, "limit": true, + "lookup": true, "merge": true, "natural": true, "new": true, "no": true, + "not": true, "null": true, "nulls": true, "of": true, "on": true, + "or": true, "order": true, "outer": true, "over": true, + "partition": true, "preceding": true, "proto": true, "range": true, + "recursive": true, "respect": true, "right": true, "rollup": true, + "rows": true, "select": true, "set": true, "some": true, "struct": true, + "tablesample": true, "then": true, "to": true, "treat": true, + "true": true, "unbounded": true, "union": true, "unnest": true, + "using": true, "when": true, "where": true, "window": true, + "with": true, "within": true, + } +) + +// validateFieldName validates that a field name follows BigQuery naming conventions. 
+func validateFieldName(name string) error { + if len(name) == 0 { + return errors.New("field name cannot be empty") + } + + if !fieldNameRegexp.MatchString(name) { + return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) + } + + if reservedSQLKeywords[strings.ToLower(name)] { + return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) + } + + return nil +} diff --git a/dialect/dialect.go b/dialect/dialect.go new file mode 100644 index 0000000..7b0e30b --- /dev/null +++ b/dialect/dialect.go @@ -0,0 +1,234 @@ +// Package dialect defines the interface for SQL dialect-specific code generation. +// Each supported database implements this interface to produce correct SQL syntax. +package dialect + +import ( + "errors" + "strings" +) + +// Name represents a SQL dialect name. +type Name string + +// Supported SQL dialect names. +const ( + PostgreSQL Name = "postgresql" + MySQL Name = "mysql" + SQLite Name = "sqlite" + DuckDB Name = "duckdb" + BigQuery Name = "bigquery" +) + +// ErrUnsupportedFeature indicates that the requested feature is not supported by this dialect. +var ErrUnsupportedFeature = errors.New("unsupported dialect feature") + +// Dialect defines the interface for SQL dialect-specific code generation. +// The converter calls these methods at every point where SQL syntax diverges +// between databases. Methods receive a *strings.Builder that shares the +// converter's output buffer, and callback functions for writing sub-expressions. +type Dialect interface { + // Name returns the dialect name. + Name() Name + + // --- Literals --- + + // WriteStringLiteral writes a string literal in the dialect's syntax. + // For PostgreSQL: 'value' with '' escaping. + WriteStringLiteral(w *strings.Builder, value string) + + // WriteBytesLiteral writes a byte array literal in the dialect's syntax. + // For PostgreSQL: '\xDEADBEEF'. 
+ WriteBytesLiteral(w *strings.Builder, value []byte) error + + // WriteParamPlaceholder writes a parameter placeholder. + // For PostgreSQL: $1, $2. For MySQL: ?, ?. For BigQuery: @p1, @p2. + WriteParamPlaceholder(w *strings.Builder, paramIndex int) + + // --- Operators --- + + // WriteStringConcat writes a string concatenation expression. + // For PostgreSQL: lhs || rhs. For MySQL: CONCAT(lhs, rhs). + WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error + + // WriteRegexMatch writes a regex match expression. + // For PostgreSQL: expr ~ 'pattern' or expr ~* 'pattern'. + WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, caseInsensitive bool) error + + // WriteLikeEscape writes the LIKE escape clause. + // For PostgreSQL: ESCAPE E'\\'. For MySQL: ESCAPE '\\'. + WriteLikeEscape(w *strings.Builder) + + // WriteArrayMembership writes an array membership test. + // For PostgreSQL: elem = ANY(array). For MySQL: JSON_CONTAINS(array, elem). + WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error + + // --- Type Casting --- + + // WriteCastToNumeric writes a cast to numeric type. + // For PostgreSQL: ::numeric. For MySQL: CAST(... AS DECIMAL). + WriteCastToNumeric(w *strings.Builder) + + // WriteTypeName writes a type name for CAST expressions. + // For PostgreSQL: BOOLEAN, BYTEA, DOUBLE PRECISION, BIGINT, TEXT. + WriteTypeName(w *strings.Builder, celTypeName string) + + // WriteEpochExtract writes extraction of epoch from a timestamp. + // For PostgreSQL: EXTRACT(EPOCH FROM expr)::bigint. + WriteEpochExtract(w *strings.Builder, writeExpr func() error) error + + // WriteTimestampCast writes a cast to timestamp type. + // For PostgreSQL: CAST(expr AS TIMESTAMP WITH TIME ZONE). + WriteTimestampCast(w *strings.Builder, writeExpr func() error) error + + // --- Arrays --- + + // WriteArrayLiteralOpen writes the opening of an array literal. + // For PostgreSQL: ARRAY[. For DuckDB: [. 
+ WriteArrayLiteralOpen(w *strings.Builder) + + // WriteArrayLiteralClose writes the closing of an array literal. + // For PostgreSQL: ]. For DuckDB: ]. + WriteArrayLiteralClose(w *strings.Builder) + + // WriteArrayLength writes an array length expression. + // For PostgreSQL: COALESCE(ARRAY_LENGTH(expr, dimension), 0). + WriteArrayLength(w *strings.Builder, dimension int, writeExpr func() error) error + + // WriteListIndex writes a list index expression. + // For PostgreSQL: array[index + 1] (1-indexed). + WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error + + // WriteListIndexConst writes a constant list index. + // For PostgreSQL: array[idx+1] (converts 0-indexed to 1-indexed). + WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error + + // WriteEmptyTypedArray writes an empty typed array literal. + // For PostgreSQL: ARRAY[]::text[]. + WriteEmptyTypedArray(w *strings.Builder, typeName string) + + // --- JSON --- + + // WriteJSONFieldAccess writes JSON field access. + // For PostgreSQL: base->>'field' (text) or base->'field' (json). + // For SQLite: json_extract(base, '$.field'). + // writeBase writes the base expression; the dialect wraps or appends as needed. + WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error + + // WriteJSONExistence writes a JSON key existence check. + // For PostgreSQL (JSONB): ? 'key'. For PostgreSQL (JSON): ->'key' IS NOT NULL. + WriteJSONExistence(w *strings.Builder, isJSONB bool, fieldName string, writeBase func() error) error + + // WriteJSONArrayElements writes a call to extract JSON array elements. + // For PostgreSQL: jsonb_array_elements(expr) or json_array_elements(expr). + WriteJSONArrayElements(w *strings.Builder, isJSONB bool, asText bool, writeExpr func() error) error + + // WriteJSONArrayLength writes a JSON array length expression. + // For PostgreSQL: COALESCE(jsonb_array_length(expr), 0). 
+ WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error + + // WriteJSONExtractPath writes a JSON path extraction function. + // For PostgreSQL: jsonb_extract_path_text(root, 'seg1', 'seg2'). + WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error + + // WriteJSONArrayMembership writes a JSON array membership test for the IN operator. + // For PostgreSQL: ANY(ARRAY(SELECT jsonb_func(expr))). + WriteJSONArrayMembership(w *strings.Builder, jsonFunc string, writeExpr func() error) error + + // WriteNestedJSONArrayMembership writes a nested JSON array membership test. + // For PostgreSQL: ANY(ARRAY(SELECT jsonb_array_elements_text(expr))). + WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error + + // --- Timestamps --- + + // WriteDuration writes a duration/interval literal. + // For PostgreSQL: INTERVAL N UNIT. + WriteDuration(w *strings.Builder, value int64, unit string) + + // WriteInterval writes an INTERVAL expression from a variable. + // For PostgreSQL: INTERVAL expr UNIT. + WriteInterval(w *strings.Builder, writeValue func() error, unit string) error + + // WriteExtract writes a timestamp field extraction expression. + // Handles DOW conversion, Month/DOY adjustment, and timezone support. + WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error + + // WriteTimestampArithmetic writes timestamp arithmetic. + // For PostgreSQL: timestamp +/- interval. + WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error + + // --- String Functions --- + + // WriteContains writes a string contains expression. + // For PostgreSQL: POSITION(needle IN haystack) > 0. + WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error + + // WriteSplit writes a string split expression. + // For PostgreSQL: STRING_TO_ARRAY(string, delimiter). 
+ WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error + + // WriteSplitWithLimit writes a string split expression with a limit. + // For PostgreSQL: (STRING_TO_ARRAY(string, delimiter))[1:limit]. + WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, limit int64) error + + // WriteJoin writes an array join expression. + // For PostgreSQL: ARRAY_TO_STRING(array, delimiter, ''). + WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error + + // --- Comprehensions --- + + // WriteUnnest writes the UNNEST source for comprehensions. + // For PostgreSQL: UNNEST(array). For MySQL: JSON_TABLE(...). + WriteUnnest(w *strings.Builder, writeSource func() error) error + + // WriteArraySubqueryOpen writes the prefix before the transform expression + // in an array-building subquery. + // For PostgreSQL: "ARRAY(SELECT ". For SQLite: "(SELECT json_group_array(". + WriteArraySubqueryOpen(w *strings.Builder) + + // WriteArraySubqueryExprClose writes the suffix after the transform expression + // and before FROM in an array-building subquery. + // For PostgreSQL: "" (nothing). For SQLite: ")". + WriteArraySubqueryExprClose(w *strings.Builder) + + // --- Struct --- + + // WriteStructOpen writes the opening of a struct/row literal. + // For PostgreSQL: ROW(. For BigQuery: STRUCT(. + WriteStructOpen(w *strings.Builder) + + // WriteStructClose writes the closing of a struct/row literal. + // For PostgreSQL: ). For BigQuery: ). + WriteStructClose(w *strings.Builder) + + // --- Validation --- + + // MaxIdentifierLength returns the maximum identifier length for this dialect. + // For PostgreSQL: 63. For MySQL: 64. For SQLite: unlimited (0). + MaxIdentifierLength() int + + // ValidateFieldName validates a field name for this dialect. + ValidateFieldName(name string) error + + // ReservedKeywords returns the set of reserved SQL keywords for this dialect. 
+ ReservedKeywords() map[string]bool + + // --- Regex --- + + // ConvertRegex converts an RE2 regex pattern to the dialect's native format. + // Returns: (convertedPattern, caseInsensitive, error). + ConvertRegex(re2Pattern string) (pattern string, caseInsensitive bool, err error) + + // SupportsRegex indicates whether this dialect supports regex matching. + SupportsRegex() bool + + // --- Capabilities --- + + // SupportsNativeArrays indicates whether this dialect has native array types. + SupportsNativeArrays() bool + + // SupportsJSONB indicates whether this dialect has a distinct JSONB type. + SupportsJSONB() bool + + // SupportsIndexAnalysis indicates whether index analysis is supported. + SupportsIndexAnalysis() bool +} diff --git a/dialect/duckdb/dialect.go b/dialect/duckdb/dialect.go new file mode 100644 index 0000000..5a22235 --- /dev/null +++ b/dialect/duckdb/dialect.go @@ -0,0 +1,473 @@ +// Package duckdb implements the DuckDB SQL dialect for cel2sql. +package duckdb + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// Dialect implements dialect.Dialect for DuckDB. +type Dialect struct{} + +// New creates a new DuckDB dialect. +func New() *Dialect { + return &Dialect{} +} + +func init() { + dialect.Register(dialect.DuckDB, func() dialect.Dialect { return New() }) +} + +// Ensure Dialect implements dialect.Dialect at compile time. +var _ dialect.Dialect = (*Dialect)(nil) + +// Name returns the dialect name. +func (d *Dialect) Name() dialect.Name { return dialect.DuckDB } + +// --- Literals --- + +// WriteStringLiteral writes a DuckDB string literal with ” escaping. +func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) { + escaped := strings.ReplaceAll(value, "'", "''") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") +} + +// WriteBytesLiteral writes a DuckDB hex-encoded byte literal ('\x...'). 
+func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error {
+	w.WriteString("'\\x")
+	for _, b := range value {
+		fmt.Fprintf(w, "%02x", b)
+	}
+	w.WriteString("'")
+	return nil
+}
+
+// WriteParamPlaceholder writes a DuckDB positional parameter ($1, $2, ...).
+func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, paramIndex int) {
+	fmt.Fprintf(w, "$%d", paramIndex)
+}
+
+// --- Operators ---
+
+// WriteStringConcat writes DuckDB string concatenation using ||.
+func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error {
+	if err := writeLHS(); err != nil {
+		return err
+	}
+	w.WriteString(" || ")
+	return writeRHS()
+}
+
+// WriteRegexMatch writes a DuckDB regex match using regexp_matches().
+// NOTE(review): DuckDB's ~ operator is an alias for regexp_full_match
+// (anchored full-string match), which does not match CEL's RE2 partial-match
+// semantics, and ~* is not a DuckDB operator at all. The regexp_matches()
+// function performs a partial match and takes 'i' as a case-insensitivity
+// option, so it is used for both cases.
+func (d *Dialect) WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, caseInsensitive bool) error {
+	w.WriteString("regexp_matches(")
+	if err := writeTarget(); err != nil {
+		return err
+	}
+	escaped := strings.ReplaceAll(pattern, "'", "''")
+	w.WriteString(", '")
+	w.WriteString(escaped)
+	if caseInsensitive {
+		// Pass the case-insensitivity flag via the options argument.
+		w.WriteString("', 'i')")
+	} else {
+		w.WriteString("')")
+	}
+	return nil
+}
+
+// WriteLikeEscape writes the DuckDB LIKE escape clause.
+// DuckDB uses standard-conforming string literals (backslash is an ordinary
+// character inside '...'), so the ESCAPE literal must contain exactly one
+// backslash; ESCAPE '\\' would be a two-character escape string and is
+// rejected. (The doubled form is only correct for MySQL, whose string
+// literals are backslash-escaped.)
+func (d *Dialect) WriteLikeEscape(w *strings.Builder) {
+	w.WriteString(" ESCAPE '\\'")
+}
+
+// WriteArrayMembership writes a DuckDB array membership test using = ANY().
+func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error {
+	if err := writeElem(); err != nil {
+		return err
+	}
+	w.WriteString(" = ANY(")
+	if err := writeArray(); err != nil {
+		return err
+	}
+	w.WriteString(")")
+	return nil
+}
+
+// --- Type Casting ---
+
+// WriteCastToNumeric writes a DuckDB numeric cast (::DOUBLE).
+func (d *Dialect) WriteCastToNumeric(w *strings.Builder) {
+	w.WriteString("::DOUBLE")
+}
+
+// WriteTypeName writes a DuckDB type name for CAST expressions.
+func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) { + switch celTypeName { + case "bool": + w.WriteString("BOOLEAN") + case "bytes": + w.WriteString("BLOB") + case "double": + w.WriteString("DOUBLE") + case "int": + w.WriteString("BIGINT") + case "string": + w.WriteString("VARCHAR") + case "uint": + w.WriteString("UBIGINT") + default: + w.WriteString(strings.ToUpper(celTypeName)) + } +} + +// WriteEpochExtract writes EXTRACT(EPOCH FROM expr)::BIGINT for DuckDB. +func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error { + w.WriteString("EXTRACT(EPOCH FROM ") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")::BIGINT") + return nil +} + +// WriteTimestampCast writes a DuckDB CAST to TIMESTAMPTZ. +func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error { + w.WriteString("CAST(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(" AS TIMESTAMPTZ)") + return nil +} + +// --- Arrays --- + +// WriteArrayLiteralOpen writes the DuckDB array literal opening ([). +func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) { + w.WriteString("[") +} + +// WriteArrayLiteralClose writes the DuckDB array literal closing (]). +func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) { + w.WriteString("]") +} + +// WriteArrayLength writes COALESCE(array_length(expr), 0) for DuckDB. +func (d *Dialect) WriteArrayLength(w *strings.Builder, _ int, writeExpr func() error) error { + w.WriteString("COALESCE(array_length(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteListIndex writes DuckDB 1-indexed array access. 
+func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error { + if err := writeArray(); err != nil { + return err + } + w.WriteString("[") + if err := writeIndex(); err != nil { + return err + } + w.WriteString(" + 1]") + return nil +} + +// WriteListIndexConst writes DuckDB constant array index access (0-indexed to 1-indexed). +func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error { + if err := writeArray(); err != nil { + return err + } + fmt.Fprintf(w, "[%d]", index+1) + return nil +} + +// WriteEmptyTypedArray writes an empty DuckDB typed array. +func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, typeName string) { + w.WriteString("[]::") //nolint:gocritic + w.WriteString(typeName) + w.WriteString("[]") +} + +// --- JSON --- + +// WriteJSONFieldAccess writes DuckDB JSON field access using -> or ->> operators. +func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error { + if err := writeBase(); err != nil { + return err + } + escaped := escapeJSONFieldName(fieldName) + if isFinal { + w.WriteString("->>'") + } else { + w.WriteString("->'") + } + w.WriteString(escaped) + w.WriteString("'") + return nil +} + +// WriteJSONExistence writes a DuckDB JSON key existence check using json_exists. +func (d *Dialect) WriteJSONExistence(w *strings.Builder, _ bool, fieldName string, writeBase func() error) error { + w.WriteString("json_exists(") + if err := writeBase(); err != nil { + return err + } + escaped := escapeJSONFieldName(fieldName) + w.WriteString(", '$.") + w.WriteString(escaped) + w.WriteString("')") + return nil +} + +// WriteJSONArrayElements writes DuckDB JSON array expansion using json_each. 
+func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, _ bool, _ bool, writeExpr func() error) error { + w.WriteString("json_each(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteJSONArrayLength writes COALESCE(json_array_length(expr), 0) for DuckDB. +func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error { + w.WriteString("COALESCE(json_array_length(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteJSONExtractPath writes DuckDB JSON path existence using json_exists. +func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error { + w.WriteString("json_exists(") + if err := writeRoot(); err != nil { + return err + } + w.WriteString(", '$") + for _, segment := range pathSegments { + w.WriteString(".") + w.WriteString(escapeJSONFieldName(segment)) + } + w.WriteString("')") + return nil +} + +// WriteJSONArrayMembership writes DuckDB JSON array membership using json_each. +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { + w.WriteString("(SELECT value FROM json_each(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// WriteNestedJSONArrayMembership writes DuckDB nested JSON array membership. +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { + w.WriteString("(SELECT value FROM json_each(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// --- Timestamps --- + +// WriteDuration writes a DuckDB INTERVAL literal. +func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) { + fmt.Fprintf(w, "INTERVAL %d %s", value, unit) +} + +// WriteInterval writes a DuckDB INTERVAL expression. 
+func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error { + w.WriteString("INTERVAL ") + if err := writeValue(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(unit) + return nil +} + +// WriteExtract writes a DuckDB EXTRACT expression with DOW conversion. +func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error { + isDOW := part == "DOW" + if isDOW { + w.WriteString("(") + } + w.WriteString("EXTRACT(") + w.WriteString(part) + w.WriteString(" FROM ") + if err := writeExpr(); err != nil { + return err + } + if writeTZ != nil { + w.WriteString(" AT TIME ZONE ") + if err := writeTZ(); err != nil { + return err + } + } + w.WriteString(")") + if isDOW { + w.WriteString(" + 6) % 7") + } + return nil +} + +// WriteTimestampArithmetic writes DuckDB timestamp arithmetic. +func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error { + if err := writeTS(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(op) + w.WriteString(" ") + return writeDur() +} + +// --- String Functions --- + +// WriteContains writes CONTAINS(haystack, needle) for DuckDB. +func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error { + w.WriteString("CONTAINS(") + if err := writeHaystack(); err != nil { + return err + } + w.WriteString(", ") + if err := writeNeedle(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteSplit writes DuckDB string split using STRING_SPLIT. +func (d *Dialect) WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error { + w.WriteString("STRING_SPLIT(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteSplitWithLimit writes DuckDB string split with array slice. 
+func (d *Dialect) WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, limit int64) error { + w.WriteString("STRING_SPLIT(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + fmt.Fprintf(w, ")[1:%d]", limit) + return nil +} + +// WriteJoin writes DuckDB array join using ARRAY_TO_STRING. +func (d *Dialect) WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error { + w.WriteString("ARRAY_TO_STRING(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Comprehensions --- + +// WriteUnnest writes DuckDB UNNEST for array unnesting. +func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error { + w.WriteString("UNNEST(") + if err := writeSource(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteArraySubqueryOpen writes ARRAY(SELECT for DuckDB. +func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) { + w.WriteString("ARRAY(SELECT ") +} + +// WriteArraySubqueryExprClose is a no-op for DuckDB (no wrapper around the expression). +func (d *Dialect) WriteArraySubqueryExprClose(_ *strings.Builder) { +} + +// --- Struct --- + +// WriteStructOpen writes the DuckDB struct literal opening. +func (d *Dialect) WriteStructOpen(w *strings.Builder) { + w.WriteString("ROW(") +} + +// WriteStructClose writes the DuckDB struct literal closing. +func (d *Dialect) WriteStructClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Validation --- + +// MaxIdentifierLength returns 0 as DuckDB has no hard identifier length limit. +func (d *Dialect) MaxIdentifierLength() int { + return 0 +} + +// ValidateFieldName validates a field name against DuckDB naming rules. 
+func (d *Dialect) ValidateFieldName(name string) error { + return validateFieldName(name) +} + +// ReservedKeywords returns the set of reserved SQL keywords for DuckDB. +func (d *Dialect) ReservedKeywords() map[string]bool { + return reservedSQLKeywords +} + +// --- Regex --- + +// ConvertRegex converts an RE2 regex pattern to DuckDB-compatible format. +// DuckDB uses RE2 natively, so minimal conversion is needed. +func (d *Dialect) ConvertRegex(re2Pattern string) (string, bool, error) { + return convertRE2ToDuckDB(re2Pattern) +} + +// SupportsRegex returns true as DuckDB supports RE2 regex natively. +func (d *Dialect) SupportsRegex() bool { return true } + +// --- Capabilities --- + +// SupportsNativeArrays returns true as DuckDB has native array (LIST) types. +func (d *Dialect) SupportsNativeArrays() bool { return true } + +// SupportsJSONB returns false as DuckDB has a single JSON type. +func (d *Dialect) SupportsJSONB() bool { return false } + +// SupportsIndexAnalysis returns true as DuckDB index analysis is supported. +func (d *Dialect) SupportsIndexAnalysis() bool { return true } + +// --- Internal helpers --- + +// escapeJSONFieldName escapes special characters in JSON field names for DuckDB. +func escapeJSONFieldName(fieldName string) string { + return strings.ReplaceAll(fieldName, "'", "''") +} diff --git a/dialect/duckdb/index_advisor.go b/dialect/duckdb/index_advisor.go new file mode 100644 index 0000000..9906b32 --- /dev/null +++ b/dialect/duckdb/index_advisor.go @@ -0,0 +1,92 @@ +package duckdb + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// DuckDB index type constants. +const ( + IndexTypeART = "ART" +) + +// RecommendIndex generates a DuckDB-specific index recommendation for the given pattern. +// DuckDB uses ART (Adaptive Radix Tree) indexes by default. Returns nil for unsupported patterns. 
+func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation { + table := pattern.TableHint + if table == "" { + table = "table_name" + } + col := pattern.Column + safeName := sanitizeIndexName(col) + + switch pattern.Pattern { + case dialect.PatternComparison: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeART, + Expression: fmt.Sprintf("CREATE INDEX idx_%s ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Comparison operations on '%s' benefit from an ART index for efficient range queries and equality checks", col), + } + + case dialect.PatternJSONAccess: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeART, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_json ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSON field access on '%s' may benefit from an ART index", col), + } + + case dialect.PatternRegexMatch: + // DuckDB does not have specialized regex indexes + return nil + + case dialect.PatternArrayMembership, dialect.PatternArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeART, + Expression: fmt.Sprintf("CREATE INDEX idx_%s ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Array operations on '%s' may benefit from an ART index", col), + } + + case dialect.PatternJSONArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeART, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_json ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSON array comprehension on '%s' may benefit from an ART index", col), + } + } + + return nil +} + +// SupportedPatterns returns the pattern types supported by DuckDB. 
+func (d *Dialect) SupportedPatterns() []dialect.PatternType { + return []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternArrayMembership, + dialect.PatternArrayComprehension, + dialect.PatternJSONArrayComprehension, + } +} + +// sanitizeIndexName creates a safe index name from a column name. +func sanitizeIndexName(column string) string { + sanitized := strings.ReplaceAll(column, ".", "_") + sanitized = strings.ReplaceAll(sanitized, " ", "_") + sanitized = strings.ReplaceAll(sanitized, "-", "_") + + if len(sanitized) > 50 { + sanitized = sanitized[:50] + } + + return sanitized +} diff --git a/dialect/duckdb/regex.go b/dialect/duckdb/regex.go new file mode 100644 index 0000000..582f83d --- /dev/null +++ b/dialect/duckdb/regex.go @@ -0,0 +1,137 @@ +package duckdb + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). +const ( + maxRegexPatternLength = 500 + maxRegexGroups = 20 + maxRegexNestingDepth = 10 +) + +// convertRE2ToDuckDB converts an RE2 regex pattern to DuckDB-compatible format. +// DuckDB uses RE2 natively, so most patterns pass through unchanged. +// Returns the converted pattern, whether it's case-insensitive, and any error. +func convertRE2ToDuckDB(re2Pattern string) (string, bool, error) { + // 1. Pattern length validation + if len(re2Pattern) > maxRegexPatternLength { + return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) + } + + // 2. Validate pattern compiles + if _, err := regexp.Compile(re2Pattern); err != nil { + return "", false, fmt.Errorf("invalid regex pattern: %w", err) + } + + // 3. Detect unsupported features (lookahead/lookbehind not in RE2 anyway) + if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { + return "", false, errors.New("lookahead assertions (?=...), (?!...) 
 are not supported")
+	}
+	if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?<!") {
+		return "", false, errors.New("lookbehind assertions (?<=...), (?<!...) are not supported")
+	}
+
+	// 4. Detect catastrophic nested quantifiers
+	if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched {
+		return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS")
+	}
+
+	// 5. Check for nested quantifiers in groups
+	depth := 0
+	groupHasQuantifier := make([]bool, 0)
+	for i := 0; i < len(re2Pattern); i++ {
+		char := re2Pattern[i]
+		if i > 0 && re2Pattern[i-1] == '\\' {
+			continue
+		}
+		switch char {
+		case '(':
+			depth++
+			groupHasQuantifier = append(groupHasQuantifier, false)
+		case ')':
+			if depth > 0 {
+				depth--
+				if i+1 < len(re2Pattern) {
+					nextChar := re2Pattern[i+1]
+					if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' {
+						if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] {
+							return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS")
+						}
+					}
+				}
+				if len(groupHasQuantifier) > 0 {
+					groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1]
+				}
+			}
+		case '*', '+', '?', '{':
+			for j := range groupHasQuantifier {
+				groupHasQuantifier[j] = true
+			}
+		}
+	}
+
+	// 6. Check group count limit
+	groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(")
+	if groupCount > maxRegexGroups {
+		return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups)
+	}
+
+	// 7. Check for quantified alternation
+	quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`)
+	if quantifiedAlternation.MatchString(re2Pattern) {
+		return "", false, errors.New("regex contains quantified alternation that could cause ReDoS")
+	}
+
+	// 8.
Check nesting depth + maxDepthVal := 0 + currentDepth := 0 + for i := 0; i < len(re2Pattern); i++ { + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch re2Pattern[i] { + case '(': + currentDepth++ + if currentDepth > maxDepthVal { + maxDepthVal = currentDepth + } + case ')': + if currentDepth > 0 { + currentDepth-- + } + } + } + if maxDepthVal > maxRegexNestingDepth { + return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) + } + + // Process pattern: DuckDB uses RE2 natively, so minimal conversion needed + caseInsensitive := false + pattern := re2Pattern + + // Handle (?i) flag + if strings.HasPrefix(pattern, "(?i)") { + caseInsensitive = true + pattern = pattern[4:] + } + + // Handle inline flags other than (?i) at start + if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { + return "", false, errors.New("inline flags other than (?i) are not supported in DuckDB regex") + } + + // Convert non-capturing groups (?:...) to regular groups (...) + pattern = strings.ReplaceAll(pattern, "(?:", "(") + + // DuckDB RE2 supports \d, \w, \s, \b natively - no conversion needed + + return pattern, caseInsensitive, nil +} diff --git a/dialect/duckdb/validation.go b/dialect/duckdb/validation.go new file mode 100644 index 0000000..976e304 --- /dev/null +++ b/dialect/duckdb/validation.go @@ -0,0 +1,55 @@ +package duckdb + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +var ( + // fieldNameRegexp validates DuckDB identifier format. + fieldNameRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + + // reservedSQLKeywords contains DuckDB reserved keywords. 
+ reservedSQLKeywords = map[string]bool{ + "all": true, "alter": true, "analyse": true, "analyze": true, "and": true, + "any": true, "array": true, "as": true, "asc": true, "asymmetric": true, + "between": true, "both": true, "case": true, "cast": true, "check": true, + "collate": true, "column": true, "constraint": true, "create": true, + "cross": true, "current_catalog": true, "current_date": true, + "current_role": true, "current_schema": true, "current_time": true, + "current_timestamp": true, "current_user": true, "default": true, + "deferrable": true, "desc": true, "distinct": true, "do": true, + "else": true, "end": true, "except": true, "exists": true, "false": true, + "fetch": true, "for": true, "foreign": true, "from": true, "full": true, + "grant": true, "group": true, "having": true, "in": true, "initially": true, + "inner": true, "intersect": true, "into": true, "is": true, "isnull": true, + "join": true, "lateral": true, "leading": true, "left": true, "like": true, + "limit": true, "localtime": true, "localtimestamp": true, "natural": true, + "not": true, "notnull": true, "null": true, "offset": true, "on": true, + "only": true, "or": true, "order": true, "outer": true, "overlaps": true, + "placing": true, "primary": true, "references": true, "returning": true, + "right": true, "select": true, "session_user": true, "similar": true, + "some": true, "symmetric": true, "table": true, "then": true, "to": true, + "trailing": true, "true": true, "union": true, "unique": true, "using": true, + "variadic": true, "when": true, "where": true, "window": true, "with": true, + } +) + +// validateFieldName validates that a field name follows DuckDB naming conventions. 
+func validateFieldName(name string) error { + if len(name) == 0 { + return errors.New("field name cannot be empty") + } + + if !fieldNameRegexp.MatchString(name) { + return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) + } + + if reservedSQLKeywords[strings.ToLower(name)] { + return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) + } + + return nil +} diff --git a/dialect/index_advisor.go b/dialect/index_advisor.go new file mode 100644 index 0000000..210e62f --- /dev/null +++ b/dialect/index_advisor.go @@ -0,0 +1,61 @@ +// Package dialect defines the IndexAdvisor interface for dialect-specific index recommendations. +package dialect + +// PatternType enumerates detected index-worthy query patterns. +type PatternType int + +// Index-worthy pattern types detected during query analysis. +const ( + PatternComparison PatternType = iota // Equality/range comparisons (==, >, <, >=, <=) + PatternJSONAccess // JSON/JSONB field access + PatternRegexMatch // Regex pattern matching + PatternArrayMembership // Array IN/containment + PatternArrayComprehension // Array comprehension (all, exists, filter, map) + PatternJSONArrayComprehension // JSON array comprehension +) + +// IndexPattern describes a detected query pattern that could benefit from indexing. +type IndexPattern struct { + // Column is the full column name (e.g., "person.metadata"). + Column string + + // Pattern is the type of query pattern detected. + Pattern PatternType + + // TableHint is an optional table name hint for generating CREATE INDEX statements. + // If empty, "table_name" is used as the default placeholder. + TableHint string +} + +// IndexRecommendation represents a database index recommendation. +// It provides actionable guidance for optimizing query performance. +type IndexRecommendation struct { + // Column is the database column that should be indexed. 
+ Column string + + // IndexType specifies the index type (e.g., "BTREE", "GIN", "ART", "CLUSTERING"). + IndexType string + + // Expression is the complete DDL statement that can be executed directly. + Expression string + + // Reason explains why this index is recommended and what query patterns it optimizes. + Reason string +} + +// IndexAdvisor generates dialect-specific index recommendations. +// Dialects that support index analysis implement this interface on their Dialect struct. +type IndexAdvisor interface { + // RecommendIndex generates an IndexRecommendation for the given pattern, + // or returns nil if the dialect has no applicable index for this pattern. + RecommendIndex(pattern IndexPattern) *IndexRecommendation + + // SupportedPatterns returns which PatternTypes this advisor can handle. + SupportedPatterns() []PatternType +} + +// GetIndexAdvisor returns the IndexAdvisor for a dialect, if it implements the interface. +func GetIndexAdvisor(d Dialect) (IndexAdvisor, bool) { + advisor, ok := d.(IndexAdvisor) + return advisor, ok +} diff --git a/dialect/mysql/dialect.go b/dialect/mysql/dialect.go new file mode 100644 index 0000000..01cd97d --- /dev/null +++ b/dialect/mysql/dialect.go @@ -0,0 +1,475 @@ +// Package mysql implements the MySQL SQL dialect for cel2sql. +package mysql + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// Dialect implements dialect.Dialect for MySQL 8.0+. +type Dialect struct{} + +// New creates a new MySQL dialect. +func New() *Dialect { + return &Dialect{} +} + +func init() { + dialect.Register(dialect.MySQL, func() dialect.Dialect { return New() }) +} + +// Ensure Dialect implements dialect.Dialect at compile time. +var _ dialect.Dialect = (*Dialect)(nil) + +// Name returns the dialect name. +func (d *Dialect) Name() dialect.Name { return dialect.MySQL } + +// --- Literals --- + +// WriteStringLiteral writes a MySQL string literal with ” escaping. 
+func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) { + escaped := strings.ReplaceAll(value, "'", "''") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") +} + +// WriteBytesLiteral writes a MySQL hex-encoded byte literal (X'...'). +func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error { + w.WriteString("X'") + for _, b := range value { + fmt.Fprintf(w, "%02x", b) + } + w.WriteString("'") + return nil +} + +// WriteParamPlaceholder writes a MySQL positional parameter (?). +func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, _ int) { + w.WriteString("?") +} + +// --- Operators --- + +// WriteStringConcat writes MySQL string concatenation using CONCAT(). +func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error { + w.WriteString("CONCAT(") + if err := writeLHS(); err != nil { + return err + } + w.WriteString(", ") + if err := writeRHS(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteRegexMatch writes a MySQL REGEXP match expression. +func (d *Dialect) WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, _ bool) error { + if err := writeTarget(); err != nil { + return err + } + w.WriteString(" REGEXP ") + escaped := strings.ReplaceAll(pattern, "'", "''") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") + return nil +} + +// WriteLikeEscape writes the MySQL LIKE escape clause. +func (d *Dialect) WriteLikeEscape(w *strings.Builder) { + w.WriteString(" ESCAPE '\\\\'") +} + +// WriteArrayMembership writes a MySQL array membership test using JSON_CONTAINS. 
+func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error {
+	w.WriteString("JSON_CONTAINS(")
+	if err := writeArray(); err != nil {
+		return err
+	}
+	w.WriteString(", CAST(")
+	if err := writeElem(); err != nil {
+		return err
+	}
+	w.WriteString(" AS JSON))")
+	return nil
+}
+
+// --- Type Casting ---
+
+// WriteCastToNumeric writes MySQL's numeric coercion idiom (" + 0"), since MySQL
+// evaluates string operands in a numeric context when added to zero.
+func (d *Dialect) WriteCastToNumeric(w *strings.Builder) {
+	w.WriteString(" + 0")
+}
+
+// WriteTypeName writes a MySQL type name for CAST expressions.
+func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) {
+	switch celTypeName {
+	case "bool":
+		w.WriteString("UNSIGNED")
+	case "bytes":
+		w.WriteString("BINARY")
+	case "double":
+		w.WriteString("DECIMAL")
+	case "int":
+		w.WriteString("SIGNED")
+	case "string":
+		w.WriteString("CHAR")
+	case "uint":
+		w.WriteString("UNSIGNED")
+	default:
+		w.WriteString(strings.ToUpper(celTypeName))
+	}
+}
+
+// WriteEpochExtract writes UNIX_TIMESTAMP(expr) for MySQL.
+func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error {
+	w.WriteString("UNIX_TIMESTAMP(")
+	if err := writeExpr(); err != nil {
+		return err
+	}
+	w.WriteString(")")
+	return nil
+}
+
+// WriteTimestampCast writes a MySQL CAST to DATETIME.
+func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error {
+	w.WriteString("CAST(")
+	if err := writeExpr(); err != nil {
+		return err
+	}
+	w.WriteString(" AS DATETIME)")
+	return nil
+}
+
+// --- Arrays ---
+
+// WriteArrayLiteralOpen writes the MySQL JSON array literal opening.
+func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) {
+	w.WriteString("JSON_ARRAY(")
+}
+
+// WriteArrayLiteralClose writes the MySQL JSON array literal closing.
+func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) {
+	w.WriteString(")")
+}
+
+// WriteArrayLength writes COALESCE(JSON_LENGTH(expr), 0) for MySQL, so a NULL JSON value counts as an empty array. 
+func (d *Dialect) WriteArrayLength(w *strings.Builder, _ int, writeExpr func() error) error { + w.WriteString("COALESCE(JSON_LENGTH(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteListIndex writes MySQL JSON array index access. +func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error { + w.WriteString("JSON_EXTRACT(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", CONCAT('$[', ") + if err := writeIndex(); err != nil { + return err + } + w.WriteString(", ']'))") + return nil +} + +// WriteListIndexConst writes MySQL JSON constant array index access. +func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error { + w.WriteString("JSON_EXTRACT(") + if err := writeArray(); err != nil { + return err + } + fmt.Fprintf(w, ", '$[%d]')", index) + return nil +} + +// WriteEmptyTypedArray writes an empty MySQL JSON array. +func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, _ string) { + w.WriteString("JSON_ARRAY()") +} + +// --- JSON --- + +// WriteJSONFieldAccess writes MySQL JSON field access using JSON_EXTRACT/JSON_UNQUOTE. +func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error { + if err := writeBase(); err != nil { + return err + } + escaped := escapeJSONFieldName(fieldName) + if isFinal { + // For final access, we need text: use ->> which is JSON_UNQUOTE(JSON_EXTRACT(...)) + w.WriteString("->>'$.") + w.WriteString(escaped) + w.WriteString("'") + } else { + w.WriteString("->'$.") + w.WriteString(escaped) + w.WriteString("'") + } + return nil +} + +// WriteJSONExistence writes a MySQL JSON key existence check. 
+func (d *Dialect) WriteJSONExistence(w *strings.Builder, _ bool, fieldName string, writeBase func() error) error { + w.WriteString("JSON_CONTAINS_PATH(") + if err := writeBase(); err != nil { + return err + } + escaped := escapeJSONFieldName(fieldName) + w.WriteString(", 'one', '$.") + w.WriteString(escaped) + w.WriteString("')") + return nil +} + +// WriteJSONArrayElements writes MySQL JSON array expansion using JSON_TABLE. +func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, _ bool, _ bool, writeExpr func() error) error { + w.WriteString("JSON_TABLE(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(", '$[*]' COLUMNS(value TEXT PATH '$'))") + return nil +} + +// WriteJSONArrayLength writes COALESCE(JSON_LENGTH(expr), 0) for MySQL. +func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error { + w.WriteString("COALESCE(JSON_LENGTH(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteJSONExtractPath writes MySQL JSON path extraction using JSON_CONTAINS_PATH. +func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error { + w.WriteString("JSON_CONTAINS_PATH(") + if err := writeRoot(); err != nil { + return err + } + w.WriteString(", 'one', '$") + for _, segment := range pathSegments { + w.WriteString(".") + w.WriteString(escapeJSONFieldName(segment)) + } + w.WriteString("')") + return nil +} + +// WriteJSONArrayMembership writes MySQL JSON array membership using JSON_CONTAINS. +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { + w.WriteString("JSON_CONTAINS(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(", CAST(? AS JSON))") + return nil +} + +// WriteNestedJSONArrayMembership writes MySQL nested JSON array membership. 
+func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { + w.WriteString("JSON_CONTAINS(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(", CAST(? AS JSON))") + return nil +} + +// --- Timestamps --- + +// WriteDuration writes a MySQL INTERVAL literal. +func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) { + fmt.Fprintf(w, "INTERVAL %d %s", value, unit) +} + +// WriteInterval writes a MySQL INTERVAL expression. +func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error { + w.WriteString("INTERVAL ") + if err := writeValue(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(unit) + return nil +} + +// WriteExtract writes a MySQL EXTRACT expression with DOW conversion. +func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error { + isDOW := part == "DOW" + if isDOW { + // MySQL DAYOFWEEK: 1=Sunday, 2=Monday, ..., 7=Saturday + // CEL getDayOfWeek: 0=Monday, 1=Tuesday, ..., 6=Sunday (ISO 8601) + // Convert: (DAYOFWEEK(x) + 5) % 7 + w.WriteString("(DAYOFWEEK(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(") + 5) % 7") + return nil + } + + w.WriteString("EXTRACT(") + w.WriteString(part) + w.WriteString(" FROM ") + if err := writeExpr(); err != nil { + return err + } + if writeTZ != nil { + w.WriteString(" AT TIME ZONE ") + if err := writeTZ(); err != nil { + return err + } + } + w.WriteString(")") + return nil +} + +// WriteTimestampArithmetic writes MySQL timestamp arithmetic. +func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error { + if err := writeTS(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(op) + w.WriteString(" ") + return writeDur() +} + +// --- String Functions --- + +// WriteContains writes LOCATE(needle, haystack) > 0 for MySQL. 
+func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error { + w.WriteString("LOCATE(") + if err := writeNeedle(); err != nil { + return err + } + w.WriteString(", ") + if err := writeHaystack(); err != nil { + return err + } + w.WriteString(") > 0") + return nil +} + +// WriteSplit writes MySQL string split using SUBSTRING_INDEX pattern. +func (d *Dialect) WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error { + // MySQL doesn't have a direct STRING_TO_ARRAY equivalent. + // Use a JSON approach: convert to JSON array. + w.WriteString("JSON_ARRAY(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(")") + // Note: A full MySQL split implementation would require a more complex approach. + // This is a simplified version. + _ = writeDelim + return nil +} + +// WriteSplitWithLimit writes MySQL string split with limit. +func (d *Dialect) WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, _ int64) error { + // Simplified: delegate to WriteSplit + return d.WriteSplit(w, writeStr, writeDelim) +} + +// WriteJoin writes MySQL array join using JSON_UNQUOTE/GROUP_CONCAT pattern. +func (d *Dialect) WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error { + // MySQL doesn't have ARRAY_TO_STRING; simplified approach + w.WriteString("JSON_UNQUOTE(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(")") + _ = writeDelim + return nil +} + +// --- Comprehensions --- + +// WriteUnnest writes MySQL JSON_TABLE for array unnesting. +func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error { + w.WriteString("JSON_TABLE(") + if err := writeSource(); err != nil { + return err + } + w.WriteString(", '$[*]' COLUMNS(value TEXT PATH '$'))") + return nil +} + +// WriteArraySubqueryOpen writes (SELECT JSON_ARRAYAGG( for MySQL array subqueries. 
+func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) { + w.WriteString("(SELECT JSON_ARRAYAGG(") +} + +// WriteArraySubqueryExprClose closes the JSON_ARRAYAGG aggregate function for MySQL. +func (d *Dialect) WriteArraySubqueryExprClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Struct --- + +// WriteStructOpen writes the MySQL struct literal opening. +func (d *Dialect) WriteStructOpen(w *strings.Builder) { + w.WriteString("ROW(") +} + +// WriteStructClose writes the MySQL struct literal closing. +func (d *Dialect) WriteStructClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Validation --- + +// MaxIdentifierLength returns the MySQL maximum identifier length (64). +func (d *Dialect) MaxIdentifierLength() int { + return maxMySQLIdentifierLength +} + +// ValidateFieldName validates a field name against MySQL naming rules. +func (d *Dialect) ValidateFieldName(name string) error { + return validateFieldName(name) +} + +// ReservedKeywords returns the set of reserved SQL keywords for MySQL. +func (d *Dialect) ReservedKeywords() map[string]bool { + return reservedSQLKeywords +} + +// --- Regex --- + +// ConvertRegex converts an RE2 regex pattern to MySQL-compatible format. +func (d *Dialect) ConvertRegex(re2Pattern string) (string, bool, error) { + return convertRE2ToMySQL(re2Pattern) +} + +// SupportsRegex returns true as MySQL 8.0+ supports ICU regex. +func (d *Dialect) SupportsRegex() bool { return true } + +// --- Capabilities --- + +// SupportsNativeArrays returns false as MySQL uses JSON arrays. +func (d *Dialect) SupportsNativeArrays() bool { return false } + +// SupportsJSONB returns false as MySQL has a single JSON type. +func (d *Dialect) SupportsJSONB() bool { return false } + +// SupportsIndexAnalysis returns true as MySQL index analysis is supported. +func (d *Dialect) SupportsIndexAnalysis() bool { return true } + +// --- Internal helpers --- + +// escapeJSONFieldName escapes special characters in JSON field names for MySQL. 
+func escapeJSONFieldName(fieldName string) string { + return strings.ReplaceAll(fieldName, "'", "''") +} diff --git a/dialect/mysql/index_advisor.go b/dialect/mysql/index_advisor.go new file mode 100644 index 0000000..4ffe9df --- /dev/null +++ b/dialect/mysql/index_advisor.go @@ -0,0 +1,93 @@ +package mysql + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// MySQL index type constants. +const ( + IndexTypeBTree = "BTREE" + IndexTypeFullText = "FULLTEXT" +) + +// RecommendIndex generates a MySQL-specific index recommendation for the given pattern. +// Returns nil if no applicable index exists for this pattern. +func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation { + table := pattern.TableHint + if table == "" { + table = "table_name" + } + col := pattern.Column + safeName := sanitizeIndexName(col) + + switch pattern.Pattern { + case dialect.PatternComparison: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeBTree, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_btree ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Comparison operations on '%s' benefit from B-tree index for efficient range queries and equality checks", col), + } + + case dialect.PatternJSONAccess: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeBTree, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_json ON %s ((CAST(%s->>'$.path' AS CHAR(255))));", + safeName, table, col), + Reason: fmt.Sprintf("JSON field access on '%s' benefits from a functional B-tree index on extracted JSON paths", col), + } + + case dialect.PatternRegexMatch: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeFullText, + Expression: fmt.Sprintf("CREATE FULLTEXT INDEX idx_%s_fulltext ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Regex matching on '%s' may benefit from FULLTEXT index for text search patterns", col), + } + + case 
dialect.PatternArrayMembership, dialect.PatternArrayComprehension: + // MySQL does not have native array types; skip + return nil + + case dialect.PatternJSONArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeBTree, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_json ON %s ((CAST(%s->>'$.path' AS CHAR(255))));", + safeName, table, col), + Reason: fmt.Sprintf("JSON array operations on '%s' may benefit from a functional index on extracted JSON values", col), + } + } + + return nil +} + +// SupportedPatterns returns the pattern types supported by MySQL. +func (d *Dialect) SupportedPatterns() []dialect.PatternType { + return []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternRegexMatch, + dialect.PatternJSONArrayComprehension, + } +} + +// sanitizeIndexName creates a safe index name from a column name. +func sanitizeIndexName(column string) string { + sanitized := strings.ReplaceAll(column, ".", "_") + sanitized = strings.ReplaceAll(sanitized, " ", "_") + sanitized = strings.ReplaceAll(sanitized, "-", "_") + + // MySQL index names are limited to 64 characters + if len(sanitized) > 50 { + sanitized = sanitized[:50] + } + + return sanitized +} diff --git a/dialect/mysql/regex.go b/dialect/mysql/regex.go new file mode 100644 index 0000000..7965097 --- /dev/null +++ b/dialect/mysql/regex.go @@ -0,0 +1,139 @@ +package mysql + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). +const ( + maxRegexPatternLength = 500 + maxRegexGroups = 20 + maxRegexNestingDepth = 10 +) + +// convertRE2ToMySQL converts an RE2 regex pattern to MySQL-compatible format. +// MySQL 8.0+ uses ICU regex which supports most RE2 features. +// Returns the converted pattern, whether it's case-insensitive, and any error. +func convertRE2ToMySQL(re2Pattern string) (string, bool, error) { + // 1. 
Pattern length validation
+	if len(re2Pattern) > maxRegexPatternLength {
+		return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength)
+	}
+
+	// 2. Validate pattern compiles
+	if _, err := regexp.Compile(re2Pattern); err != nil {
+		return "", false, fmt.Errorf("invalid regex pattern: %w", err)
+	}
+
+	// 3. Detect unsupported RE2 features
+	if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") {
+		return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported in MySQL regex")
+	}
+	if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?<!") {
+		return "", false, errors.New("lookbehind assertions (?<=...), (?<!...) are not supported in MySQL regex")
+	}
+
+	// 4. Detect catastrophic nested quantifiers
+	if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched {
+		return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS")
+	}
+
+	// 5. Check for nested quantifiers in groups
+	depth := 0
+	groupHasQuantifier := make([]bool, 0)
+	for i := 0; i < len(re2Pattern); i++ {
+		char := re2Pattern[i]
+		if i > 0 && re2Pattern[i-1] == '\\' {
+			continue
+		}
+		switch char {
+		case '(':
+			depth++
+			groupHasQuantifier = append(groupHasQuantifier, false)
+		case ')':
+			if depth > 0 {
+				depth--
+				if i+1 < len(re2Pattern) {
+					nextChar := re2Pattern[i+1]
+					if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' {
+						if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] {
+							return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS")
+						}
+					}
+				}
+				if len(groupHasQuantifier) > 0 {
+					groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1]
+				}
+			}
+		case '*', '+', '?', '{':
+			for j := range groupHasQuantifier {
+				groupHasQuantifier[j] = true
+			}
+		}
+	}
+
+	// 6. 
Check group count limit + groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(") + if groupCount > maxRegexGroups { + return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) + } + + // 7. Check for quantified alternation + quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) + if quantifiedAlternation.MatchString(re2Pattern) { + return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") + } + + // 8. Check nesting depth + maxDepthVal := 0 + currentDepth := 0 + for i := 0; i < len(re2Pattern); i++ { + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch re2Pattern[i] { + case '(': + currentDepth++ + if currentDepth > maxDepthVal { + maxDepthVal = currentDepth + } + case ')': + if currentDepth > 0 { + currentDepth-- + } + } + } + if maxDepthVal > maxRegexNestingDepth { + return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) + } + + // Process pattern: extract case-insensitivity, convert features + caseInsensitive := false + pattern := re2Pattern + + // Handle (?i) flag + if strings.HasPrefix(pattern, "(?i)") { + caseInsensitive = true + pattern = pattern[4:] + } + + // Handle inline flags other than (?i) at start + if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { + return "", false, errors.New("inline flags other than (?i) are not supported in MySQL regex") + } + + // Convert non-capturing groups (?:...) to regular groups (...) 
+ pattern = strings.ReplaceAll(pattern, "(?:", "(") + + // MySQL ICU regex supports \d, \w, \s natively - no conversion needed + // Convert \b word boundary to MySQL's \b (same syntax in ICU) + // No conversion needed for MySQL 8.0+ + + return pattern, caseInsensitive, nil +} diff --git a/dialect/mysql/validation.go b/dialect/mysql/validation.go new file mode 100644 index 0000000..15a20d4 --- /dev/null +++ b/dialect/mysql/validation.go @@ -0,0 +1,91 @@ +package mysql + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +const ( + // maxMySQLIdentifierLength is the maximum length for MySQL identifiers. + maxMySQLIdentifierLength = 64 +) + +var ( + // fieldNameRegexp validates MySQL identifier format. + fieldNameRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + + // reservedSQLKeywords contains MySQL reserved keywords. + reservedSQLKeywords = map[string]bool{ + "accessible": true, "add": true, "all": true, "alter": true, "analyze": true, + "and": true, "as": true, "asc": true, "asensitive": true, "before": true, + "between": true, "bigint": true, "binary": true, "blob": true, "both": true, + "by": true, "call": true, "cascade": true, "case": true, "change": true, + "char": true, "character": true, "check": true, "collate": true, "column": true, + "condition": true, "constraint": true, "continue": true, "convert": true, + "create": true, "cross": true, "current_date": true, "current_time": true, + "current_timestamp": true, "current_user": true, "cursor": true, "database": true, + "databases": true, "day_hour": true, "day_microsecond": true, "day_minute": true, + "day_second": true, "dec": true, "decimal": true, "declare": true, "default": true, + "delayed": true, "delete": true, "desc": true, "describe": true, "deterministic": true, + "distinct": true, "distinctrow": true, "div": true, "double": true, "drop": true, + "dual": true, "each": true, "else": true, "elseif": true, "enclosed": true, + "escaped": true, "exists": true, "exit": true, "explain": 
true, "false": true, + "fetch": true, "float": true, "float4": true, "float8": true, "for": true, + "force": true, "foreign": true, "from": true, "fulltext": true, "grant": true, + "group": true, "having": true, "high_priority": true, "hour_microsecond": true, + "hour_minute": true, "hour_second": true, "if": true, "ignore": true, "in": true, + "index": true, "infile": true, "inner": true, "inout": true, "insensitive": true, + "insert": true, "int": true, "int1": true, "int2": true, "int3": true, + "int4": true, "int8": true, "integer": true, "interval": true, "into": true, + "is": true, "iterate": true, "join": true, "key": true, "keys": true, + "kill": true, "leading": true, "leave": true, "left": true, "like": true, + "limit": true, "linear": true, "lines": true, "load": true, "localtime": true, + "localtimestamp": true, "lock": true, "long": true, "longblob": true, + "longtext": true, "loop": true, "low_priority": true, "match": true, + "mediumblob": true, "mediumint": true, "mediumtext": true, "middleint": true, + "minute_microsecond": true, "minute_second": true, "mod": true, "modifies": true, + "natural": true, "not": true, "null": true, "numeric": true, "on": true, + "optimize": true, "option": true, "optionally": true, "or": true, "order": true, + "out": true, "outer": true, "outfile": true, "precision": true, "primary": true, + "procedure": true, "purge": true, "range": true, "read": true, "reads": true, + "real": true, "references": true, "regexp": true, "release": true, "rename": true, + "repeat": true, "replace": true, "require": true, "restrict": true, "return": true, + "revoke": true, "right": true, "rlike": true, "schema": true, "schemas": true, + "second_microsecond": true, "select": true, "sensitive": true, "separator": true, + "set": true, "show": true, "signal": true, "smallint": true, "spatial": true, + "specific": true, "sql": true, "sqlexception": true, "sqlstate": true, + "sqlwarning": true, "sql_big_result": true, "sql_calc_found_rows": 
true, + "sql_small_result": true, "ssl": true, "starting": true, "straight_join": true, + "table": true, "terminated": true, "then": true, "tinyblob": true, "tinyint": true, + "tinytext": true, "to": true, "trailing": true, "trigger": true, "true": true, + "undo": true, "union": true, "unique": true, "unlock": true, "unsigned": true, + "update": true, "usage": true, "use": true, "using": true, "utc_date": true, + "utc_time": true, "utc_timestamp": true, "values": true, "varbinary": true, + "varchar": true, "varcharacter": true, "varying": true, "when": true, + "where": true, "while": true, "with": true, "write": true, "xor": true, + "year_month": true, "zerofill": true, + } +) + +// validateFieldName validates that a field name follows MySQL naming conventions. +func validateFieldName(name string) error { + if len(name) == 0 { + return errors.New("field name cannot be empty") + } + + if len(name) > maxMySQLIdentifierLength { + return fmt.Errorf("field name %q exceeds MySQL maximum identifier length of %d characters", name, maxMySQLIdentifierLength) + } + + if !fieldNameRegexp.MatchString(name) { + return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) + } + + if reservedSQLKeywords[strings.ToLower(name)] { + return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) + } + + return nil +} diff --git a/dialect/postgres/dialect.go b/dialect/postgres/dialect.go new file mode 100644 index 0000000..8aee559 --- /dev/null +++ b/dialect/postgres/dialect.go @@ -0,0 +1,496 @@ +// Package postgres implements the PostgreSQL SQL dialect for cel2sql. +package postgres + +import ( + "encoding/hex" + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// Dialect implements dialect.Dialect for PostgreSQL. +type Dialect struct{} + +// New creates a new PostgreSQL dialect. 
+func New() *Dialect {
+	return &Dialect{}
+}
+
+func init() {
+	dialect.Register(dialect.PostgreSQL, func() dialect.Dialect { return New() })
+}
+
+// Ensure Dialect implements dialect.Dialect at compile time.
+var _ dialect.Dialect = (*Dialect)(nil)
+
+// Name returns the dialect name.
+func (d *Dialect) Name() dialect.Name { return dialect.PostgreSQL }
+
+// --- Literals ---
+
+// WriteStringLiteral writes a PostgreSQL string literal with '' escaping.
+func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) {
+	escaped := strings.ReplaceAll(value, "'", "''")
+	w.WriteString("'")
+	w.WriteString(escaped)
+	w.WriteString("'")
+}
+
+// WriteBytesLiteral writes a PostgreSQL hex-encoded byte literal.
+func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error {
+	w.WriteString("'\\x")
+	w.WriteString(hex.EncodeToString(value))
+	w.WriteString("'")
+	return nil
+}
+
+// WriteParamPlaceholder writes a PostgreSQL positional parameter ($1, $2, ...).
+func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, paramIndex int) {
+	fmt.Fprintf(w, "$%d", paramIndex)
+}
+
+// --- Operators ---
+
+// WriteStringConcat writes a PostgreSQL string concatenation using ||.
+func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error {
+	if err := writeLHS(); err != nil {
+		return err
+	}
+	w.WriteString(" || ")
+	return writeRHS()
+}
+
+// WriteRegexMatch writes a PostgreSQL regex match using ~ or ~* operators.
+func (d *Dialect) WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, caseInsensitive bool) error {
+	if err := writeTarget(); err != nil {
+		return err
+	}
+	if caseInsensitive {
+		w.WriteString(" ~* ")
+	} else {
+		w.WriteString(" ~ ")
+	}
+	escaped := strings.ReplaceAll(pattern, "'", "''")
+	w.WriteString("'")
+	w.WriteString(escaped)
+	w.WriteString("'")
+	return nil
+}
+
+// WriteLikeEscape writes the PostgreSQL LIKE escape clause. 
+func (d *Dialect) WriteLikeEscape(w *strings.Builder) { + w.WriteString(" ESCAPE E'\\\\'") +} + +// WriteArrayMembership writes a PostgreSQL array membership test using = ANY(). +func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + if err := writeElem(); err != nil { + return err + } + w.WriteString(" = ANY(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Type Casting --- + +// WriteCastToNumeric writes a PostgreSQL numeric cast suffix (::numeric). +func (d *Dialect) WriteCastToNumeric(w *strings.Builder) { + w.WriteString("::numeric") +} + +// WriteTypeName writes a PostgreSQL type name for CAST expressions. +func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) { + switch celTypeName { + case "bool": + w.WriteString("BOOLEAN") + case "bytes": + w.WriteString("BYTEA") + case "double": + w.WriteString("DOUBLE PRECISION") + case "int": + w.WriteString("BIGINT") + case "string": + w.WriteString("TEXT") + case "uint": + w.WriteString("BIGINT") + default: + w.WriteString(strings.ToUpper(celTypeName)) + } +} + +// WriteEpochExtract writes EXTRACT(EPOCH FROM expr)::bigint for PostgreSQL. +func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error { + w.WriteString("EXTRACT(EPOCH FROM ") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")::bigint") + return nil +} + +// WriteTimestampCast writes CAST(expr AS TIMESTAMP WITH TIME ZONE) for PostgreSQL. +func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error { + w.WriteString("CAST(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(" AS TIMESTAMP WITH TIME ZONE)") + return nil +} + +// --- Arrays --- + +// WriteArrayLiteralOpen writes the PostgreSQL array literal opening (ARRAY[). 
+func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) { + w.WriteString("ARRAY[") +} + +// WriteArrayLiteralClose writes the PostgreSQL array literal closing (]). +func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) { + w.WriteString("]") +} + +// WriteArrayLength writes COALESCE(ARRAY_LENGTH(expr, dimension), 0) for PostgreSQL. +func (d *Dialect) WriteArrayLength(w *strings.Builder, dimension int, writeExpr func() error) error { + w.WriteString("COALESCE(ARRAY_LENGTH(") + if err := writeExpr(); err != nil { + return err + } + fmt.Fprintf(w, ", %d), 0)", dimension) + return nil +} + +// WriteListIndex writes a PostgreSQL 1-indexed array access (array[index + 1]). +func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error { + if err := writeArray(); err != nil { + return err + } + w.WriteString("[") + if err := writeIndex(); err != nil { + return err + } + w.WriteString(" + 1]") + return nil +} + +// WriteListIndexConst writes a PostgreSQL constant array index (0-indexed to 1-indexed). +func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error { + if err := writeArray(); err != nil { + return err + } + fmt.Fprintf(w, "[%d]", index+1) + return nil +} + +// WriteEmptyTypedArray writes an empty PostgreSQL typed array (ARRAY[]::type[]). +func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, typeName string) { + fmt.Fprintf(w, "ARRAY[]::%s[]", typeName) +} + +// --- JSON --- + +// WriteJSONFieldAccess writes PostgreSQL JSON field access using -> or ->> operators. 
+func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error { + if err := writeBase(); err != nil { + return err + } + escapedField := escapeJSONFieldName(fieldName) + if isFinal { + w.WriteString("->>'") + } else { + w.WriteString("->'") + } + w.WriteString(escapedField) + w.WriteString("'") + return nil +} + +// WriteJSONExistence writes a PostgreSQL JSON key existence check (? or IS NOT NULL). +func (d *Dialect) WriteJSONExistence(w *strings.Builder, isJSONB bool, fieldName string, writeBase func() error) error { + if err := writeBase(); err != nil { + return err + } + escapedField := escapeJSONFieldName(fieldName) + if isJSONB { + w.WriteString(" ? '") + w.WriteString(escapedField) + w.WriteString("'") + } else { + w.WriteString("->'") + w.WriteString(escapedField) + w.WriteString("' IS NOT NULL") + } + return nil +} + +// WriteJSONArrayElements writes a PostgreSQL JSON array expansion function. +func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, isJSONB bool, asText bool, writeExpr func() error) error { + if isJSONB { + if asText { + w.WriteString("jsonb_array_elements_text(") + } else { + w.WriteString("jsonb_array_elements(") + } + } else { + if asText { + w.WriteString("json_array_elements_text(") + } else { + w.WriteString("json_array_elements(") + } + } + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteJSONArrayLength writes COALESCE(jsonb_array_length(expr), 0) for PostgreSQL. +func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error { + w.WriteString("COALESCE(jsonb_array_length(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteJSONExtractPath writes jsonb_extract_path_text() IS NOT NULL for PostgreSQL. 
+func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error {
+    w.WriteString("jsonb_extract_path_text(")
+    if err := writeRoot(); err != nil {
+        return err
+    }
+    // Each path segment becomes one quoted variadic argument to jsonb_extract_path_text.
+    for _, segment := range pathSegments {
+        w.WriteString(", '")
+        w.WriteString(escapeJSONFieldName(segment))
+        w.WriteString("'")
+    }
+    w.WriteString(") IS NOT NULL")
+    return nil
+}
+
+// WriteJSONArrayMembership writes ANY(ARRAY(SELECT json_func(expr))) for PostgreSQL.
+// jsonFunc is the element-expansion function name chosen by the caller
+// (e.g. jsonb_array_elements_text).
+func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, jsonFunc string, writeExpr func() error) error {
+    w.WriteString("ANY(ARRAY(SELECT ")
+    w.WriteString(jsonFunc)
+    w.WriteString("(")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString(")))")
+    return nil
+}
+
+// WriteNestedJSONArrayMembership writes ANY(ARRAY(SELECT jsonb_array_elements_text(expr))) for PostgreSQL.
+func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error {
+    w.WriteString("ANY(ARRAY(SELECT jsonb_array_elements_text(")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString(")))")
+    return nil
+}
+
+// --- Timestamps ---
+
+// WriteDuration writes a PostgreSQL INTERVAL literal (INTERVAL N UNIT).
+// NOTE(review): unit is interpolated verbatim; it is assumed to be a fixed SQL
+// unit keyword supplied by the converter (never user input) — TODO confirm.
+func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) {
+    fmt.Fprintf(w, "INTERVAL %d %s", value, unit)
+}
+
+// WriteInterval writes a PostgreSQL INTERVAL expression (INTERVAL expr UNIT).
+func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error {
+    w.WriteString("INTERVAL ")
+    if err := writeValue(); err != nil {
+        return err
+    }
+    w.WriteString(" ")
+    w.WriteString(unit)
+    return nil
+}
+
+// WriteExtract writes a PostgreSQL EXTRACT expression with DOW conversion.
+func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error {
+    // For getDayOfWeek, we need to wrap the entire EXTRACT for modulo operation
+    isDOW := part == "DOW"
+    if isDOW {
+        w.WriteString("(")
+    }
+    w.WriteString("EXTRACT(")
+    w.WriteString(part)
+    w.WriteString(" FROM ")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    // Optional timezone conversion is applied inside the EXTRACT argument.
+    if writeTZ != nil {
+        w.WriteString(" AT TIME ZONE ")
+        if err := writeTZ(); err != nil {
+            return err
+        }
+    }
+    w.WriteString(")")
+    if isDOW {
+        // PostgreSQL DOW: 0=Sunday, 1=Monday, ..., 6=Saturday
+        // CEL getDayOfWeek: 0=Monday, 1=Tuesday, ..., 6=Sunday (ISO 8601)
+        // Convert: (DOW + 6) % 7
+        // NOTE(review): cel-go's getDayOfWeek is backed by Go time.Weekday
+        // (0=Sunday); verify the 0=Monday convention claimed here against the
+        // project's conformance tests — TODO confirm.
+        w.WriteString(" + 6) % 7")
+    }
+    return nil
+}
+
+// WriteTimestampArithmetic writes PostgreSQL timestamp arithmetic (timestamp +/- interval).
+func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error {
+    if err := writeTS(); err != nil {
+        return err
+    }
+    w.WriteString(" ")
+    w.WriteString(op)
+    w.WriteString(" ")
+    return writeDur()
+}
+
+// --- String Functions ---
+
+// WriteContains writes POSITION(needle IN haystack) > 0 for PostgreSQL.
+// Note the argument order: POSITION takes the needle first.
+func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error {
+    w.WriteString("POSITION(")
+    if err := writeNeedle(); err != nil {
+        return err
+    }
+    w.WriteString(" IN ")
+    if err := writeHaystack(); err != nil {
+        return err
+    }
+    w.WriteString(") > 0")
+    return nil
+}
+
+// WriteSplit writes STRING_TO_ARRAY(string, delimiter) for PostgreSQL.
+func (d *Dialect) WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error {
+    w.WriteString("STRING_TO_ARRAY(")
+    if err := writeStr(); err != nil {
+        return err
+    }
+    w.WriteString(", ")
+    if err := writeDelim(); err != nil {
+        return err
+    }
+    w.WriteString(")")
+    return nil
+}
+
+// WriteSplitWithLimit writes (STRING_TO_ARRAY(string, delimiter))[1:limit] for PostgreSQL.
+func (d *Dialect) WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, limit int64) error {
+    w.WriteString("(STRING_TO_ARRAY(")
+    if err := writeStr(); err != nil {
+        return err
+    }
+    w.WriteString(", ")
+    if err := writeDelim(); err != nil {
+        return err
+    }
+    // The 1-based array slice [1:limit] keeps the first `limit` elements.
+    fmt.Fprintf(w, "))[1:%d]", limit)
+    return nil
+}
+
+// WriteJoin writes ARRAY_TO_STRING(array, delimiter, '') for PostgreSQL.
+// A nil writeDelim joins with the empty string; the third argument replaces
+// NULL elements with ''.
+func (d *Dialect) WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error {
+    w.WriteString("ARRAY_TO_STRING(")
+    if err := writeArray(); err != nil {
+        return err
+    }
+    w.WriteString(", ")
+    if writeDelim != nil {
+        if err := writeDelim(); err != nil {
+            return err
+        }
+    } else {
+        w.WriteString("''")
+    }
+    w.WriteString(", '')")
+    return nil
+}
+
+// --- Comprehensions ---
+
+// WriteUnnest writes UNNEST(source) for PostgreSQL comprehensions.
+func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error {
+    w.WriteString("UNNEST(")
+    if err := writeSource(); err != nil {
+        return err
+    }
+    w.WriteString(")")
+    return nil
+}
+
+// WriteArraySubqueryOpen writes ARRAY(SELECT for PostgreSQL.
+func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) {
+    w.WriteString("ARRAY(SELECT ")
+}
+
+// WriteArraySubqueryExprClose is a no-op for PostgreSQL (no wrapper around the expression).
+func (d *Dialect) WriteArraySubqueryExprClose(_ *strings.Builder) {
+}
+
+// --- Struct ---
+
+// WriteStructOpen writes the PostgreSQL struct/row literal opening (ROW().
+func (d *Dialect) WriteStructOpen(w *strings.Builder) {
+    w.WriteString("ROW(")
+}
+
+// WriteStructClose writes the PostgreSQL struct/row literal closing ()).
+func (d *Dialect) WriteStructClose(w *strings.Builder) {
+    w.WriteString(")")
+}
+
+// --- Validation ---
+
+// MaxIdentifierLength returns the PostgreSQL maximum identifier length (63).
+func (d *Dialect) MaxIdentifierLength() int {
+    return maxPostgreSQLIdentifierLength
+}
+
+// ValidateFieldName validates a field name against PostgreSQL naming rules.
+func (d *Dialect) ValidateFieldName(name string) error {
+    return validateFieldName(name)
+}
+
+// ReservedKeywords returns the set of reserved SQL keywords for PostgreSQL.
+// The returned map is the shared package-level set; callers must treat it as read-only.
+func (d *Dialect) ReservedKeywords() map[string]bool {
+    return reservedSQLKeywords
+}
+
+// --- Regex ---
+
+// ConvertRegex converts an RE2 regex pattern to PostgreSQL POSIX format.
+func (d *Dialect) ConvertRegex(re2Pattern string) (string, bool, error) {
+    return convertRE2ToPOSIX(re2Pattern)
+}
+
+// SupportsRegex returns true as PostgreSQL supports POSIX regex matching.
+func (d *Dialect) SupportsRegex() bool { return true }
+
+// --- Capabilities ---
+
+// SupportsNativeArrays returns true as PostgreSQL has native array types.
+func (d *Dialect) SupportsNativeArrays() bool { return true }
+
+// SupportsJSONB returns true as PostgreSQL has a distinct JSONB type.
+func (d *Dialect) SupportsJSONB() bool { return true }
+
+// SupportsIndexAnalysis returns true as PostgreSQL index analysis is supported.
+func (d *Dialect) SupportsIndexAnalysis() bool { return true }
+
+// --- Internal helpers ---
+
+// escapeJSONFieldName escapes single quotes in JSON field names for safe use in PostgreSQL JSON path operators.
+func escapeJSONFieldName(fieldName string) string {
+    return strings.ReplaceAll(fieldName, "'", "''")
+}
diff --git a/dialect/postgres/index_advisor.go b/dialect/postgres/index_advisor.go
new file mode 100644
index 0000000..286d24b
--- /dev/null
+++ b/dialect/postgres/index_advisor.go
@@ -0,0 +1,110 @@
+package postgres
+
+import (
+    "fmt"
+    "strings"
+
+    "github.com/spandigital/cel2sql/v3/dialect"
+)
+
+// PostgreSQL index type constants.
+const (
+    IndexTypeBTree = "BTREE"
+    IndexTypeGIN   = "GIN"
+    IndexTypeGIST  = "GIST"
+)
+
+// RecommendIndex generates a PostgreSQL-specific index recommendation for the given pattern.
+// Returns nil if no applicable index exists for this pattern.
+// NOTE(review): the column name is interpolated unquoted into the suggested DDL;
+// the Expression is advisory text, not executed SQL — confirm callers never run it verbatim.
+func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation {
+    // Fall back to a placeholder table name when no hint was captured.
+    table := pattern.TableHint
+    if table == "" {
+        table = "table_name"
+    }
+    col := pattern.Column
+    safeName := sanitizeIndexName(col)
+
+    switch pattern.Pattern {
+    case dialect.PatternComparison:
+        return &dialect.IndexRecommendation{
+            Column:    col,
+            IndexType: IndexTypeBTree,
+            Expression: fmt.Sprintf("CREATE INDEX idx_%s_btree ON %s (%s);",
+                safeName, table, col),
+            Reason: fmt.Sprintf("Comparison operations on '%s' benefit from B-tree index for efficient range queries and equality checks", col),
+        }
+
+    case dialect.PatternJSONAccess:
+        return &dialect.IndexRecommendation{
+            Column:    col,
+            IndexType: IndexTypeGIN,
+            Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON %s USING GIN (%s);",
+                safeName, table, col),
+            Reason: fmt.Sprintf("JSON path operations on '%s' benefit from GIN index for efficient nested field access", col),
+        }
+
+    case dialect.PatternRegexMatch:
+        // Requires the pg_trgm extension for the gin_trgm_ops operator class.
+        return &dialect.IndexRecommendation{
+            Column:    col,
+            IndexType: IndexTypeGIN,
+            Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin_trgm ON %s USING GIN (%s gin_trgm_ops);",
+                safeName, table, col),
+            Reason: fmt.Sprintf("Regex matching on '%s' benefits from GIN index with pg_trgm extension for pattern matching", col),
+        }
+
+    case dialect.PatternArrayMembership:
+        return &dialect.IndexRecommendation{
+            Column:    col,
+            IndexType: IndexTypeGIN,
+            Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON %s USING GIN (%s);",
+                safeName, table, col),
+            Reason: fmt.Sprintf("Array membership tests on '%s' benefit from GIN index for efficient element lookups", col),
+        }
+
+    case dialect.PatternArrayComprehension:
+        return &dialect.IndexRecommendation{
+            Column:    col,
+            IndexType: IndexTypeGIN,
+            Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON %s USING GIN (%s);",
+                safeName, table, col),
+            Reason: fmt.Sprintf("Array comprehension on '%s' benefits from GIN index for efficient array operations", col),
+        }
+
+    case dialect.PatternJSONArrayComprehension:
+        return &dialect.IndexRecommendation{
+            Column:    col,
+            IndexType: IndexTypeGIN,
+            Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON %s USING GIN (%s);",
+                safeName, table, col),
+            Reason: fmt.Sprintf("JSONB array comprehension on '%s' benefits from GIN index for efficient array element access", col),
+        }
+    }
+
+    return nil
+}
+
+// SupportedPatterns returns all pattern types supported by PostgreSQL.
+func (d *Dialect) SupportedPatterns() []dialect.PatternType {
+    return []dialect.PatternType{
+        dialect.PatternComparison,
+        dialect.PatternJSONAccess,
+        dialect.PatternRegexMatch,
+        dialect.PatternArrayMembership,
+        dialect.PatternArrayComprehension,
+        dialect.PatternJSONArrayComprehension,
+    }
+}
+
+// sanitizeIndexName creates a safe index name from a column name.
+func sanitizeIndexName(column string) string {
+    sanitized := strings.ReplaceAll(column, ".", "_")
+    sanitized = strings.ReplaceAll(sanitized, " ", "_")
+    sanitized = strings.ReplaceAll(sanitized, "-", "_")
+
+    // PostgreSQL index names are limited to 63 characters
+    // (50 leaves room for the idx_/_gin/_btree affixes added above).
+    if len(sanitized) > 50 {
+        sanitized = sanitized[:50]
+    }
+
+    return sanitized
+}
diff --git a/dialect/postgres/regex.go b/dialect/postgres/regex.go
new file mode 100644
index 0000000..1cc4a4b
--- /dev/null
+++ b/dialect/postgres/regex.go
@@ -0,0 +1,143 @@
+package postgres
+
+import (
+    "errors"
+    "fmt"
+    "regexp"
+    "strings"
+)
+
+// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333).
+const (
+    maxRegexPatternLength = 500
+    maxRegexGroups        = 20
+    maxRegexNestingDepth  = 10
+)
+
+// convertRE2ToPOSIX converts an RE2 regex pattern to POSIX ERE format for PostgreSQL.
+// It performs security validation to prevent ReDoS attacks (CWE-1333).
+// Returns: (posixPattern, caseInsensitive, error) +func convertRE2ToPOSIX(re2Pattern string) (string, bool, error) { + // 1. Check pattern length to prevent processing extremely long patterns + if len(re2Pattern) > maxRegexPatternLength { + return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) + } + + // 2. Extract case-insensitive flag if present + caseInsensitive := false + if strings.HasPrefix(re2Pattern, "(?i)") { + caseInsensitive = true + re2Pattern = strings.TrimPrefix(re2Pattern, "(?i)") + } + + // 3. Detect unsupported RE2 features and return errors + if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { + return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported in PostgreSQL POSIX regex") + } + if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?...) are not supported in PostgreSQL POSIX regex") + } + if strings.Contains(re2Pattern, "(?m") || strings.Contains(re2Pattern, "(?s") || strings.Contains(re2Pattern, "(?-") { + return "", false, errors.New("inline flags other than (?i) are not supported in PostgreSQL POSIX regex") + } + + // 4. Detect catastrophic nested quantifiers + if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + + // Check for groups that contain quantifiers and are themselves quantified + depth := 0 + groupHasQuantifier := make([]bool, 0) + + for i := 0; i < len(re2Pattern); i++ { + char := re2Pattern[i] + + // Skip escaped characters + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + + switch char { + case '(': + depth++ + groupHasQuantifier = append(groupHasQuantifier, false) + case ')': + if depth > 0 { + depth-- + if i+1 < len(re2Pattern) { + nextChar := re2Pattern[i+1] + if nextChar == '*' || nextChar == '+' || nextChar == '?' 
|| nextChar == '{' { + if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + } + } + if len(groupHasQuantifier) > 0 { + if len(groupHasQuantifier) > 1 { + if groupHasQuantifier[len(groupHasQuantifier)-1] { + groupHasQuantifier[len(groupHasQuantifier)-2] = true + } + } + groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] + } + } + case '*', '+', '?': + if len(groupHasQuantifier) > 0 { + groupHasQuantifier[len(groupHasQuantifier)-1] = true + } + case '{': + if len(groupHasQuantifier) > 0 { + groupHasQuantifier[len(groupHasQuantifier)-1] = true + } + } + } + + // 5. Count and limit capture groups + groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, `\(`) + if groupCount > maxRegexGroups { + return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) + } + + // 6. Detect exponential alternation patterns + alternationPattern := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) + if alternationPattern.MatchString(re2Pattern) { + return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") + } + + // 7. 
Check nesting depth + maxDepthVal := 0 + currentDepth := 0 + for _, char := range re2Pattern { + if char == '(' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { + currentDepth++ + if currentDepth > maxDepthVal { + maxDepthVal = currentDepth + } + } else if char == ')' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { + currentDepth-- + } + } + if maxDepthVal > maxRegexNestingDepth { + return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) + } + + // Passed all security checks - proceed with conversion + posixPattern := re2Pattern + + // Convert RE2 patterns to POSIX equivalents + posixPattern = strings.ReplaceAll(posixPattern, `\b`, `\y`) + posixPattern = strings.ReplaceAll(posixPattern, `\B`, `[^[:alnum:]_]`) + posixPattern = strings.ReplaceAll(posixPattern, `\d`, `[[:digit:]]`) + posixPattern = strings.ReplaceAll(posixPattern, `\D`, `[^[:digit:]]`) + posixPattern = strings.ReplaceAll(posixPattern, `\w`, `[[:alnum:]_]`) + posixPattern = strings.ReplaceAll(posixPattern, `\W`, `[^[:alnum:]_]`) + posixPattern = strings.ReplaceAll(posixPattern, `\s`, `[[:space:]]`) + posixPattern = strings.ReplaceAll(posixPattern, `\S`, `[^[:space:]]`) + posixPattern = strings.ReplaceAll(posixPattern, `(?:`, `(`) + + return posixPattern, caseInsensitive, nil +} diff --git a/dialect/postgres/validation.go b/dialect/postgres/validation.go new file mode 100644 index 0000000..162da68 --- /dev/null +++ b/dialect/postgres/validation.go @@ -0,0 +1,66 @@ +package postgres + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +const ( + // maxPostgreSQLIdentifierLength is the maximum length for PostgreSQL identifiers + // PostgreSQL's NAMEDATALEN is 64 bytes (including null terminator), so max usable length is 63 + maxPostgreSQLIdentifierLength = 63 +) + +var ( + // fieldNameRegexp validates PostgreSQL identifier format + fieldNameRegexp = 
regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
+
+    // reservedSQLKeywords contains SQL keywords that should not be used as unquoted identifiers
+    reservedSQLKeywords = map[string]bool{
+        "all": true, "analyse": true, "analyze": true, "and": true, "any": true,
+        "array": true, "as": true, "asc": true, "asymmetric": true, "both": true,
+        "case": true, "cast": true, "check": true, "collate": true, "column": true,
+        "constraint": true, "create": true, "cross": true, "current_catalog": true,
+        "current_date": true, "current_role": true, "current_time": true,
+        "current_timestamp": true, "current_user": true, "default": true,
+        "deferrable": true, "desc": true, "distinct": true, "do": true, "else": true,
+        "end": true, "except": true, "false": true, "fetch": true, "for": true,
+        "foreign": true, "from": true, "grant": true, "group": true, "having": true,
+        "in": true, "initially": true, "inner": true, "intersect": true, "into": true,
+        "is": true, "join": true, "leading": true, "left": true, "like": true,
+        "limit": true, "localtime": true, "localtimestamp": true, "natural": true,
+        "not": true, "null": true, "offset": true, "on": true, "only": true,
+        "or": true, "order": true, "outer": true, "overlaps": true, "placing": true,
+        "primary": true, "references": true, "returning": true, "right": true,
+        "select": true, "session_user": true, "similar": true, "some": true,
+        "symmetric": true, "table": true, "then": true, "to": true, "trailing": true,
+        "true": true, "union": true, "unique": true, "user": true, "using": true,
+        "variadic": true, "when": true, "where": true, "window": true, "with": true,
+        // Additional keywords that commonly cause issues
+        "alter": true, "delete": true, "drop": true, "insert": true, "update": true,
+    }
+)
+
+// validateFieldName validates that a field name follows PostgreSQL naming conventions
+// and is safe to use in SQL queries without quoting.
+func validateFieldName(name string) error {
+    if len(name) == 0 {
+        return errors.New("field name cannot be empty")
+    }
+
+    if len(name) > maxPostgreSQLIdentifierLength {
+        return fmt.Errorf("field name %q exceeds PostgreSQL maximum identifier length of %d characters", name, maxPostgreSQLIdentifierLength)
+    }
+
+    if !fieldNameRegexp.MatchString(name) {
+        return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name)
+    }
+
+    // Keyword check is case-insensitive because SQL folds unquoted identifiers.
+    if reservedSQLKeywords[strings.ToLower(name)] {
+        return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name)
+    }
+
+    return nil
+}
diff --git a/dialect/registry.go b/dialect/registry.go
new file mode 100644
index 0000000..095af2d
--- /dev/null
+++ b/dialect/registry.go
@@ -0,0 +1,42 @@
+package dialect
+
+import (
+    "fmt"
+    "sync"
+)
+
+var (
+    registryMu sync.RWMutex
+    registry   = make(map[Name]func() Dialect)
+)
+
+// Register registers a dialect factory function by name.
+// This is typically called in an init() function by each dialect package.
+// Registering the same name twice silently replaces the earlier factory.
+func Register(name Name, factory func() Dialect) {
+    registryMu.Lock()
+    defer registryMu.Unlock()
+    registry[name] = factory
+}
+
+// Get returns a new dialect instance by name.
+// Returns an error if the dialect is not registered.
+// A fresh instance is produced on every call via the registered factory.
+func Get(name Name) (Dialect, error) {
+    registryMu.RLock()
+    defer registryMu.RUnlock()
+    factory, ok := registry[name]
+    if !ok {
+        return nil, fmt.Errorf("%w: dialect %q is not registered", ErrUnsupportedFeature, name)
+    }
+    return factory(), nil
+}
+
+// Registered returns the names of all registered dialects.
+func Registered() []Name {
+    registryMu.RLock()
+    defer registryMu.RUnlock()
+    // Map iteration order is random; callers should sort if they need determinism.
+    names := make([]Name, 0, len(registry))
+    for name := range registry {
+        names = append(names, name)
+    }
+    return names
+}
diff --git a/dialect/sqlite/dialect.go b/dialect/sqlite/dialect.go
new file mode 100644
index 0000000..5e97d1f
--- /dev/null
+++ b/dialect/sqlite/dialect.go
@@ -0,0 +1,462 @@
+// Package sqlite implements the SQLite SQL dialect for cel2sql.
+package sqlite
+
+import (
+    "fmt"
+    "strings"
+
+    "github.com/spandigital/cel2sql/v3/dialect"
+)
+
+// Dialect implements dialect.Dialect for SQLite.
+type Dialect struct{}
+
+// New creates a new SQLite dialect.
+func New() *Dialect {
+    return &Dialect{}
+}
+
+// init registers the dialect so dialect.Get(dialect.SQLite) can resolve it by name.
+func init() {
+    dialect.Register(dialect.SQLite, func() dialect.Dialect { return New() })
+}
+
+// Ensure Dialect implements dialect.Dialect at compile time.
+var _ dialect.Dialect = (*Dialect)(nil)
+
+// Name returns the dialect name.
+func (d *Dialect) Name() dialect.Name { return dialect.SQLite }
+
+// --- Literals ---
+
+// WriteStringLiteral writes a SQLite string literal with '' escaping.
+func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) {
+    escaped := strings.ReplaceAll(value, "'", "''")
+    w.WriteString("'")
+    w.WriteString(escaped)
+    w.WriteString("'")
+}
+
+// WriteBytesLiteral writes a SQLite hex-encoded byte literal (X'...').
+func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error {
+    w.WriteString("X'")
+    for _, b := range value {
+        fmt.Fprintf(w, "%02x", b)
+    }
+    w.WriteString("'")
+    return nil
+}
+
+// WriteParamPlaceholder writes a SQLite positional parameter (?).
+// The ordinal argument is unused because SQLite placeholders are purely positional.
+func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, _ int) {
+    w.WriteString("?")
+}
+
+// --- Operators ---
+
+// WriteStringConcat writes SQLite string concatenation using ||.
+func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error {
+    if err := writeLHS(); err != nil {
+        return err
+    }
+    w.WriteString(" || ")
+    return writeRHS()
+}
+
+// WriteRegexMatch returns an error as SQLite does not natively support regex.
+func (d *Dialect) WriteRegexMatch(_ *strings.Builder, _ func() error, _ string, _ bool) error {
+    return fmt.Errorf("%w: regex matching", dialect.ErrUnsupportedFeature)
+}
+
+// WriteLikeEscape writes the SQLite LIKE escape clause.
+// SQLite does not use backslash escaping in string literals, so '\' is a single character.
+func (d *Dialect) WriteLikeEscape(w *strings.Builder) {
+    w.WriteString(" ESCAPE '\\'")
+}
+
+// WriteArrayMembership writes a SQLite array membership test using json_each.
+// Arrays are represented as JSON text in SQLite, so membership is a subquery
+// over the expanded elements.
+func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error {
+    if err := writeElem(); err != nil {
+        return err
+    }
+    w.WriteString(" IN (SELECT value FROM json_each(")
+    if err := writeArray(); err != nil {
+        return err
+    }
+    w.WriteString("))")
+    return nil
+}
+
+// --- Type Casting ---
+
+// WriteCastToNumeric coerces the preceding expression to a number by appending
+// " + 0" (SQLite arithmetic applies numeric affinity to its operands).
+func (d *Dialect) WriteCastToNumeric(w *strings.Builder) {
+    w.WriteString(" + 0")
+}
+
+// WriteTypeName writes a SQLite type name for CAST expressions.
+// Unknown CEL type names are passed through upper-cased.
+func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) {
+    switch celTypeName {
+    case "bool":
+        w.WriteString("INTEGER")
+    case "bytes":
+        w.WriteString("BLOB")
+    case "double":
+        w.WriteString("REAL")
+    case "int":
+        w.WriteString("INTEGER")
+    case "string":
+        w.WriteString("TEXT")
+    case "uint":
+        w.WriteString("INTEGER")
+    default:
+        w.WriteString(strings.ToUpper(celTypeName))
+    }
+}
+
+// WriteEpochExtract writes strftime('%s', expr) for SQLite.
+func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error {
+    // strftime returns text; CAST yields an integer epoch-seconds value.
+    w.WriteString("CAST(strftime('%s', ")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString(") AS INTEGER)")
+    return nil
+}
+
+// WriteTimestampCast writes a SQLite datetime cast.
+func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error {
+    w.WriteString("datetime(")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString(")")
+    return nil
+}
+
+// --- Arrays ---
+
+// WriteArrayLiteralOpen writes the SQLite JSON array literal opening.
+// SQLite has no native arrays; lists are modeled as JSON arrays.
+func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) {
+    w.WriteString("json_array(")
+}
+
+// WriteArrayLiteralClose writes the SQLite JSON array literal closing.
+func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) {
+    w.WriteString(")")
+}
+
+// WriteArrayLength writes json_array_length(expr) for SQLite.
+// The dimension argument is ignored: JSON arrays are one-dimensional.
+func (d *Dialect) WriteArrayLength(w *strings.Builder, _ int, writeExpr func() error) error {
+    w.WriteString("COALESCE(json_array_length(")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString("), 0)")
+    return nil
+}
+
+// WriteListIndex writes SQLite JSON array index access.
+// JSON paths are 0-based, matching CEL, so no index shift is needed here.
+func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error {
+    w.WriteString("json_extract(")
+    if err := writeArray(); err != nil {
+        return err
+    }
+    // The path string '$[i]' is assembled at query time via || concatenation.
+    w.WriteString(", '$[' || ")
+    if err := writeIndex(); err != nil {
+        return err
+    }
+    w.WriteString(" || ']')")
+    return nil
+}
+
+// WriteListIndexConst writes SQLite JSON constant array index access.
+func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error {
+    w.WriteString("json_extract(")
+    if err := writeArray(); err != nil {
+        return err
+    }
+    fmt.Fprintf(w, ", '$[%d]')", index)
+    return nil
+}
+
+// WriteEmptyTypedArray writes an empty SQLite JSON array.
+// The element type name is ignored: JSON arrays are untyped.
+func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, _ string) {
+    w.WriteString("json_array()")
+}
+
+// --- JSON ---
+
+// WriteJSONFieldAccess writes SQLite JSON field access using json_extract.
+// The isFinal flag is ignored: json_extract already returns SQL values for scalars.
+func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, _ bool) error {
+    w.WriteString("json_extract(")
+    if err := writeBase(); err != nil {
+        return err
+    }
+    escaped := escapeJSONFieldName(fieldName)
+    w.WriteString(", '$.")
+    w.WriteString(escaped)
+    w.WriteString("')")
+    return nil
+}
+
+// WriteJSONExistence writes a SQLite JSON key existence check.
+// json_type(..., path) IS NOT NULL distinguishes a missing key from one
+// holding JSON null.
+func (d *Dialect) WriteJSONExistence(w *strings.Builder, _ bool, fieldName string, writeBase func() error) error {
+    w.WriteString("json_type(")
+    if err := writeBase(); err != nil {
+        return err
+    }
+    escaped := escapeJSONFieldName(fieldName)
+    w.WriteString(", '$.")
+    w.WriteString(escaped)
+    w.WriteString("') IS NOT NULL")
+    return nil
+}
+
+// WriteJSONArrayElements writes SQLite JSON array expansion using json_each.
+// The isJSONB/asText flags are ignored: SQLite has a single JSON representation.
+func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, _ bool, _ bool, writeExpr func() error) error {
+    w.WriteString("json_each(")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString(")")
+    return nil
+}
+
+// WriteJSONArrayLength writes COALESCE(json_array_length(expr), 0) for SQLite.
+func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error {
+    w.WriteString("COALESCE(json_array_length(")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString("), 0)")
+    return nil
+}
+
+// WriteJSONExtractPath writes SQLite JSON path extraction.
+func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error {
+    w.WriteString("json_type(")
+    if err := writeRoot(); err != nil {
+        return err
+    }
+    // Path segments are folded into a single '$.a.b.c' JSON path expression.
+    w.WriteString(", '$")
+    for _, segment := range pathSegments {
+        w.WriteString(".")
+        w.WriteString(escapeJSONFieldName(segment))
+    }
+    w.WriteString("') IS NOT NULL")
+    return nil
+}
+
+// WriteJSONArrayMembership writes SQLite JSON array membership using json_each.
+// The jsonFunc argument is ignored: SQLite always expands via json_each.
+func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error {
+    w.WriteString("(SELECT value FROM json_each(")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString("))")
+    return nil
+}
+
+// WriteNestedJSONArrayMembership writes SQLite nested JSON array membership.
+func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error {
+    w.WriteString("(SELECT value FROM json_each(")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString("))")
+    return nil
+}
+
+// --- Timestamps ---
+
+// WriteDuration writes a SQLite duration as a string modifier.
+// NOTE(review): the unit is lower-cased and pluralized with a trailing "s"
+// (SECOND -> seconds); assumes singular unit names are passed in — TODO confirm.
+func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) {
+    // SQLite uses datetime modifiers like '+N seconds', '+N minutes', etc.
+    // %+d forces an explicit sign, which the modifier syntax expects.
+    fmt.Fprintf(w, "'%+d %s'", value, strings.ToLower(unit)+"s")
+}
+
+// WriteInterval writes a SQLite interval expression.
+// The modifier string '+<value> <unit>s' is assembled at query time via || concatenation.
+func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error {
+    w.WriteString("'+'||")
+    if err := writeValue(); err != nil {
+        return err
+    }
+    fmt.Fprintf(w, "||' %s'", strings.ToLower(unit)+"s")
+    return nil
+}
+
+// WriteExtract writes a SQLite strftime extraction expression.
+// The timezone writer is ignored: strftime operates on the value as stored.
+func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, _ func() error) error {
+    format := sqliteExtractFormat(part)
+    if part == "DOW" {
+        // CONSISTENCY FIX: strftime('%w') yields 0=Sunday..6=Saturday, but the
+        // PostgreSQL dialect converts EXTRACT(DOW ...) via (DOW + 6) % 7 to the
+        // project's documented getDayOfWeek convention (0=Monday..6=Sunday).
+        // Apply the same conversion here so all dialects agree.
+        w.WriteString("((CAST(strftime('")
+        w.WriteString(format)
+        w.WriteString("', ")
+        if err := writeExpr(); err != nil {
+            return err
+        }
+        w.WriteString(") AS INTEGER) + 6) % 7)")
+        return nil
+    }
+    w.WriteString("CAST(strftime('")
+    w.WriteString(format)
+    w.WriteString("', ")
+    if err := writeExpr(); err != nil {
+        return err
+    }
+    w.WriteString(") AS INTEGER)")
+    return nil
+}
+
+// WriteTimestampArithmetic writes SQLite timestamp arithmetic using datetime().
+func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error {
+    if op == "-" {
+        // For subtraction, negate the duration by prefixing the modifier with '-'.
+        w.WriteString("datetime(")
+        if err := writeTS(); err != nil {
+            return err
+        }
+        w.WriteString(", '-'||")
+        if err := writeDur(); err != nil {
+            return err
+        }
+        w.WriteString(")")
+    } else {
+        w.WriteString("datetime(")
+        if err := writeTS(); err != nil {
+            return err
+        }
+        w.WriteString(", ")
+        if err := writeDur(); err != nil {
+            return err
+        }
+        w.WriteString(")")
+    }
+    return nil
+}
+
+// --- String Functions ---
+
+// WriteContains writes INSTR(haystack, needle) > 0 for SQLite.
+func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error {
+    w.WriteString("INSTR(")
+    if err := writeHaystack(); err != nil {
+        return err
+    }
+    w.WriteString(", ")
+    if err := writeNeedle(); err != nil {
+        return err
+    }
+    w.WriteString(") > 0")
+    return nil
+}
+
+// WriteSplit returns an error as SQLite does not have a native string split.
+func (d *Dialect) WriteSplit(_ *strings.Builder, _, _ func() error) error {
+    return fmt.Errorf("%w: string split", dialect.ErrUnsupportedFeature)
+}
+
+// WriteSplitWithLimit returns an error as SQLite does not have a native string split.
+func (d *Dialect) WriteSplitWithLimit(_ *strings.Builder, _, _ func() error, _ int64) error {
+    return fmt.Errorf("%w: string split with limit", dialect.ErrUnsupportedFeature)
+}
+
+// WriteJoin returns an error as SQLite does not have a native array join.
+func (d *Dialect) WriteJoin(_ *strings.Builder, _, _ func() error) error {
+    return fmt.Errorf("%w: array join", dialect.ErrUnsupportedFeature)
+}
+
+// --- Comprehensions ---
+
+// WriteUnnest writes SQLite json_each for array unnesting.
+func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error {
+    w.WriteString("json_each(")
+    if err := writeSource(); err != nil {
+        return err
+    }
+    w.WriteString(")")
+    return nil
+}
+
+// WriteArraySubqueryOpen writes (SELECT json_group_array( for SQLite array subqueries.
+func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) {
+    w.WriteString("(SELECT json_group_array(")
+}
+
+// WriteArraySubqueryExprClose closes the json_group_array aggregate function for SQLite.
+// Unlike PostgreSQL, a closing paren is required because the open wrote an aggregate call.
+func (d *Dialect) WriteArraySubqueryExprClose(w *strings.Builder) {
+    w.WriteString(")")
+}
+
+// --- Struct ---
+
+// WriteStructOpen writes the SQLite struct literal opening.
+func (d *Dialect) WriteStructOpen(w *strings.Builder) {
+    w.WriteString("json_object(")
+}
+
+// WriteStructClose writes the SQLite struct literal closing.
+func (d *Dialect) WriteStructClose(w *strings.Builder) {
+    w.WriteString(")")
+}
+
+// --- Validation ---
+
+// MaxIdentifierLength returns 0 as SQLite has no hard identifier length limit.
+// NOTE(review): assumes callers treat 0 as "no limit" — TODO confirm the
+// dialect.Dialect contract documents this convention.
+func (d *Dialect) MaxIdentifierLength() int {
+    return 0
+}
+
+// ValidateFieldName validates a field name against SQLite naming rules.
+func (d *Dialect) ValidateFieldName(name string) error {
+    return validateFieldName(name)
+}
+
+// ReservedKeywords returns the set of reserved SQL keywords for SQLite.
+// The returned map is the shared package-level set; callers must treat it as read-only.
+func (d *Dialect) ReservedKeywords() map[string]bool {
+    return reservedSQLKeywords
+}
+
+// --- Regex ---
+
+// ConvertRegex returns an error as SQLite does not natively support regex.
+func (d *Dialect) ConvertRegex(_ string) (string, bool, error) {
+    return "", false, fmt.Errorf("%w: regex matching", dialect.ErrUnsupportedFeature)
+}
+
+// SupportsRegex returns false as SQLite does not natively support regex.
+func (d *Dialect) SupportsRegex() bool { return false }
+
+// --- Capabilities ---
+
+// SupportsNativeArrays returns false as SQLite uses JSON arrays.
+func (d *Dialect) SupportsNativeArrays() bool { return false }
+
+// SupportsJSONB returns false as SQLite has a single JSON type.
+func (d *Dialect) SupportsJSONB() bool { return false }
+
+// SupportsIndexAnalysis returns true as SQLite index analysis is supported.
+func (d *Dialect) SupportsIndexAnalysis() bool { return true }
+
+// --- Internal helpers ---
+
+// escapeJSONFieldName doubles single quotes in JSON field names so they embed safely in SQLite string literals.
+func escapeJSONFieldName(fieldName string) string {
+	return strings.ReplaceAll(fieldName, "'", "''")
+}
+
+// sqliteExtractFormat maps SQL EXTRACT parts to SQLite strftime format strings.
+// NOTE(review): '%f' yields fractional seconds (SS.SSS), so WriteExtract's CAST(... AS INTEGER)
+// returns whole seconds for MILLISECONDS, not milliseconds — confirm intended behavior.
+func sqliteExtractFormat(part string) string {
+	switch part {
+	case "YEAR":
+		return "%Y"
+	case "MONTH":
+		return "%m"
+	case "DAY":
+		return "%d"
+	case "HOUR":
+		return "%H"
+	case "MINUTE":
+		return "%M"
+	case "SECOND":
+		return "%S"
+	case "DOY":
+		return "%j"
+	case "DOW":
+		return "%w"
+	case "MILLISECONDS":
+		return "%f"
+	default: // unknown parts silently fall back to year — consider surfacing an error upstream
+		return "%Y"
+	}
+}
diff --git a/dialect/sqlite/index_advisor.go b/dialect/sqlite/index_advisor.go
new file mode 100644
index 0000000..107ce11
--- /dev/null
+++ b/dialect/sqlite/index_advisor.go
@@ -0,0 +1,73 @@
+package sqlite
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/spandigital/cel2sql/v3/dialect"
+)
+
+// SQLite index type constants.
+const (
+	IndexTypeBTree = "BTREE"
+)
+
+// RecommendIndex generates a SQLite-specific index recommendation for the given pattern.
+// SQLite only supports standard B-tree indexes. Returns nil for unsupported patterns.
+func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation {
+	table := pattern.TableHint
+	if table == "" {
+		table = "table_name"
+	}
+	col := pattern.Column
+	safeName := sanitizeIndexName(col)
+
+	switch pattern.Pattern {
+	case dialect.PatternComparison:
+		return &dialect.IndexRecommendation{
+			Column:    col,
+			IndexType: IndexTypeBTree,
+			Expression: fmt.Sprintf("CREATE INDEX idx_%s ON %s (%s);",
+				safeName, table, col),
+			Reason: fmt.Sprintf("Comparison operations on '%s' benefit from an index for efficient range queries and equality checks", col),
+		}
+
+	case dialect.PatternJSONAccess:
+		// NOTE(review): SQLite does support expression indexes (e.g. ON t (json_extract(col, '$.path')));
+		// no recommendation is made here — confirm this conservatism is intentional
+		return nil
+
+	case dialect.PatternRegexMatch:
+		// SQLite does not support native regex; no index recommendation
+		return nil
+
+	case dialect.PatternArrayMembership, dialect.PatternArrayComprehension:
+		// SQLite does not have native array types
+		return nil
+
+	case dialect.PatternJSONArrayComprehension:
+		// SQLite does not support indexes on JSON array operations
+		return nil
+	}
+
+	return nil
+}
+
+// SupportedPatterns returns the pattern types supported by SQLite.
+func (d *Dialect) SupportedPatterns() []dialect.PatternType {
+	return []dialect.PatternType{
+		dialect.PatternComparison,
+	}
+}
+
+// sanitizeIndexName creates a safe index name from a column name.
+// Names longer than 50 characters are truncated, so distinct long columns may collide.
+func sanitizeIndexName(column string) string {
+	sanitized := strings.ReplaceAll(column, ".", "_")
+	sanitized = strings.ReplaceAll(sanitized, " ", "_")
+	sanitized = strings.ReplaceAll(sanitized, "-", "_")
+
+	if len(sanitized) > 50 {
+		sanitized = sanitized[:50]
+	}
+
+	return sanitized
+}
diff --git a/dialect/sqlite/validation.go b/dialect/sqlite/validation.go
new file mode 100644
index 0000000..805c06c
--- /dev/null
+++ b/dialect/sqlite/validation.go
@@ -0,0 +1,68 @@
+package sqlite
+
+import (
+	"errors"
+	"fmt"
+	"regexp"
+	"strings"
+)
+
+var (
+	// fieldNameRegexp validates SQLite identifier format.
+	fieldNameRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
+
+	// reservedSQLKeywords contains SQLite reserved keywords (keys are lowercase).
+	reservedSQLKeywords = map[string]bool{
+		"abort": true, "action": true, "add": true, "after": true, "all": true,
+		"alter": true, "always": true, "analyze": true, "and": true, "as": true,
+		"asc": true, "attach": true, "autoincrement": true, "before": true,
+		"begin": true, "between": true, "by": true, "cascade": true, "case": true,
+		"cast": true, "check": true, "collate": true, "column": true, "commit": true,
+		"conflict": true, "constraint": true, "create": true, "cross": true,
+		"current": true, "current_date": true, "current_time": true,
+		"current_timestamp": true, "database": true, "default": true,
+		"deferrable": true, "deferred": true, "delete": true, "desc": true,
+		"detach": true, "distinct": true, "do": true, "drop": true, "each": true,
+		"else": true, "end": true, "escape": true, "except": true, "exclude": true,
+		"exclusive": true, "exists": true, "explain": true, "fail": true,
+		"filter": true, "first": true, "following": true, "for": true,
+		"foreign": true, "from": true, "full": true, "glob": true, "group": true,
+		"groups": true, "having": true, "if": true, "ignore": true, "immediate": true,
+		"in": true, "index": true, "indexed": true, "initially": true, "inner": true,
+		"insert": true, "instead": true, "intersect": true, "into": true, "is": true,
+		"isnull": true, "join": true, "key": true, "last": true, "left": true,
+		"like": true, "limit": true, "match": true, "materialized": true,
+		"natural": true, "no": true, "not": true, "nothing": true, "notnull": true,
+		"null": true, "nulls": true, "of": true, "offset": true, "on": true,
+		"or": true, "order": true, "others": true, "outer": true, "over": true,
+		"partition": true, "plan": true, "pragma": true, "preceding": true,
+		"primary": true, "query": true, "raise": true, "range": true,
+		"recursive": true, "references": true, "regexp": true, "reindex": true,
+		"release": true, "rename": true, "replace": true, "restrict": true,
+		"returning": true, "right": true, "rollback": true, "row": true,
+		"rows": true, "savepoint": true, "select": true, "set": true, "table": true,
+		"temp": true, "temporary": true, "then": true, "ties": true, "to": true,
+		"transaction": true, "trigger": true, "unbounded": true, "union": true,
+		"unique": true, "update": true, "using": true, "vacuum": true, "values": true,
+		"view": true, "virtual": true, "when": true, "where": true, "window": true,
+		"with": true, "without": true,
+	}
+)
+
+// validateFieldName validates that a field name follows SQLite naming conventions.
+func validateFieldName(name string) error {
+	if len(name) == 0 {
+		return errors.New("field name cannot be empty")
+	}
+
+	// SQLite imposes no hard identifier length limit; only the format and reserved words are checked here
+	if !fieldNameRegexp.MatchString(name) {
+		return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name)
+	}
+
+	if reservedSQLKeywords[strings.ToLower(name)] {
+		return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name)
+	}
+
+	return nil
+}
diff --git a/duckdb/provider.go b/duckdb/provider.go
new file mode 100644
index 0000000..e1b73b9
--- /dev/null
+++ b/duckdb/provider.go
@@ -0,0 +1,251 @@
+// Package duckdb provides DuckDB type provider for CEL type system integration.
+package duckdb
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/google/cel-go/checker/decls"
+	"github.com/google/cel-go/common/types"
+	"github.com/google/cel-go/common/types/ref"
+	exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
+
+	"github.com/spandigital/cel2sql/v3/schema"
+)
+
+// Sentinel errors for the duckdb package.
+var (
+	// ErrInvalidSchema indicates a problem with the provided schema or database introspection.
+	ErrInvalidSchema = errors.New("invalid schema")
+)
+
+// FieldSchema is an alias for schema.FieldSchema.
+type FieldSchema = schema.FieldSchema
+
+// Schema is an alias for schema.Schema.
+type Schema = schema.Schema
+
+// NewSchema creates a new Schema. This is an alias for schema.NewSchema.
+var NewSchema = schema.NewSchema
+
+// TypeProvider interface for DuckDB type providers.
+type TypeProvider interface {
+	types.Provider
+	LoadTableSchema(ctx context.Context, tableName string) error
+	GetSchemas() map[string]Schema
+	Close()
+}
+
+type typeProvider struct {
+	schemas map[string]Schema
+	db      *sql.DB
+}
+
+// NewTypeProvider creates a new DuckDB type provider with pre-defined schemas.
+func NewTypeProvider(schemas map[string]Schema) TypeProvider {
+	return &typeProvider{schemas: schemas}
+}
+
+// NewTypeProviderWithConnection creates a new DuckDB type provider that can introspect database schemas.
+// The caller owns the *sql.DB and is responsible for closing it.
+// This works with any DuckDB driver that implements database/sql (e.g., github.com/marcboeker/go-duckdb).
+func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, error) {
+	if db == nil {
+		return nil, fmt.Errorf("%w: db connection must not be nil", ErrInvalidSchema)
+	}
+
+	return &typeProvider{
+		schemas: make(map[string]Schema),
+		db:      db,
+	}, nil
+}
+
+// LoadTableSchema loads schema information for a table from the database.
+// Errors wrap ErrInvalidSchema (matchable with errors.Is) and preserve the
+// underlying driver error for diagnosability.
+func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) error {
+	if tp.db == nil {
+		return fmt.Errorf("%w: no database connection available", ErrInvalidSchema)
+	}
+
+	query := `
+		SELECT column_name, data_type, is_nullable
+		FROM information_schema.columns
+		WHERE table_name = ?
+		ORDER BY ordinal_position
+	`
+
+	rows, err := tp.db.QueryContext(ctx, query, tableName)
+	if err != nil {
+		// Keep the driver error in the chain instead of discarding it.
+		return fmt.Errorf("%w: failed to query table schema: %w", ErrInvalidSchema, err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	var fields []FieldSchema
+	for rows.Next() {
+		var columnName, dataType, isNullable string
+
+		if err := rows.Scan(&columnName, &dataType, &isNullable); err != nil {
+			return fmt.Errorf("%w: failed to scan row: %w", ErrInvalidSchema, err)
+		}
+
+		field := duckdbColumnToFieldSchema(columnName, dataType)
+		fields = append(fields, field)
+	}
+
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("%w: error iterating rows: %w", ErrInvalidSchema, err)
+	}
+
+	if len(fields) == 0 {
+		return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName)
+	}
+
+	tp.schemas[tableName] = NewSchema(fields)
+	return nil
+}
+
+// duckdbColumnToFieldSchema converts DuckDB column metadata to a FieldSchema.
+func duckdbColumnToFieldSchema(columnName, dataType string) FieldSchema {
+	// DuckDB array types appear as "INTEGER[]", "VARCHAR[]", etc.
+	isArray, elementType, dimensions := detectDuckDBArray(dataType)
+	isJSON := strings.EqualFold(dataType, "json")
+
+	if isArray {
+		return FieldSchema{
+			Name:        columnName,
+			Type:        strings.ToLower(elementType),
+			Repeated:    true,
+			Dimensions:  dimensions,
+			ElementType: strings.ToLower(elementType),
+		}
+	}
+
+	return FieldSchema{
+		Name:   columnName,
+		Type:   normalizeDuckDBType(dataType),
+		IsJSON: isJSON,
+	}
+}
+
+// detectDuckDBArray detects if a DuckDB data type is an array and returns element type and dimensions.
+func detectDuckDBArray(dataType string) (isArray bool, elementType string, dimensions int) {
+	// Count trailing [] pairs (e.g. "INTEGER[][]" -> element "INTEGER", 2 dimensions)
+	remaining := dataType
+	dims := 0
+	for strings.HasSuffix(remaining, "[]") {
+		dims++
+		remaining = strings.TrimSuffix(remaining, "[]")
+	}
+
+	if dims > 0 {
+		return true, remaining, dims
+	}
+	return false, "", 0
+}
+
+// normalizeDuckDBType normalizes a DuckDB type name to lowercase.
+func normalizeDuckDBType(dataType string) string {
+	return strings.ToLower(dataType)
+}
+
+// Close is a no-op since we don't own the *sql.DB.
+func (tp *typeProvider) Close() {
+	// No-op: caller owns the *sql.DB connection
+}
+
+// GetSchemas returns the schemas known to this type provider.
+func (tp *typeProvider) GetSchemas() map[string]Schema {
+	return tp.schemas
+}
+
+// EnumValue implements types.Provider; enum resolution is unsupported and always yields an error value.
+func (tp *typeProvider) EnumValue(_ string) ref.Val {
+	return types.NewErr("unknown enum value")
+}
+
+// FindIdent implements types.Provider; no identifiers are provided.
+func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) {
+	return nil, false
+}
+
+// FindStructType implements types.Provider. Table names act as struct type names.
+func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) {
+	if _, ok := tp.schemas[structType]; ok {
+		return types.NewObjectType(structType), true
+	}
+	return nil, false
+}
+
+// FindStructFieldNames implements types.Provider.
+func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) {
+	s, ok := tp.schemas[structType]
+	if !ok {
+		return nil, false
+	}
+	fields := s.Fields()
+	names := make([]string, len(fields))
+	for i, f := range fields {
+		names[i] = f.Name
+	}
+	return names, true
+}
+
+// FindStructFieldType implements types.Provider.
+func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) {
+	s, ok := tp.schemas[structType]
+	if !ok {
+		return nil, false
+	}
+	field, found := s.FindField(fieldName)
+	if !found {
+		return nil, false
+	}
+
+	exprType := duckdbTypeToCELExprType(field)
+	celType, err := types.ExprTypeToType(exprType)
+	if err != nil {
+		return nil, false
+	}
+
+	return &types.FieldType{
+		Type: celType,
+	}, true
+}
+
+// NewValue implements types.Provider; constructing values from schemas is unsupported.
+func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val {
+	return types.NewErr("unknown type in schema")
+}
+
+// duckdbTypeToCELExprType converts a DuckDB field schema to a CEL expression type.
+// Repeated fields become list types regardless of dimension count.
+func duckdbTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type {
+	baseType := duckdbBaseTypeToCEL(field.Type)
+	if field.Repeated {
+		return decls.NewListType(baseType)
+	}
+	return baseType
+}
+
+// duckdbBaseTypeToCEL converts a DuckDB type name to a CEL expression type;
+// unrecognized type names map to dyn.
+func duckdbBaseTypeToCEL(typeName string) *exprpb.Type {
+	switch typeName {
+	case "varchar", "text", "char", "bpchar", "name":
+		return decls.String
+	case "bigint", "integer", "int", "int4", "int8", "smallint", "int2", "tinyint", "hugeint":
+		return decls.Int
+	case "double", "float", "real", "float4", "float8", "numeric", "decimal":
+		return decls.Double
+	case "boolean", "bool":
+		return decls.Bool
+	case "blob", "bytea":
+		return decls.Bytes
+	case "json":
+		return decls.Dyn
+	case "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone":
+		return decls.Timestamp
+	default:
+		return decls.Dyn
+	}
+}
diff --git a/duckdb/provider_test.go b/duckdb/provider_test.go
new file mode 100644
index 0000000..98e47f0
--- /dev/null
+++ b/duckdb/provider_test.go
@@ -0,0 +1,157 @@
+package duckdb_test
+
+import (
+	"context"
+	"testing"
+
+	"github.com/google/cel-go/common/types"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
"github.com/spandigital/cel2sql/v3/duckdb" +) + +func TestNewTypeProvider(t *testing.T) { + schemas := map[string]duckdb.Schema{ + "users": duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "varchar"}, + }), + } + + provider := duckdb.NewTypeProvider(schemas) + require.NotNil(t, provider) + + got := provider.GetSchemas() + assert.Len(t, got, 1) + assert.Contains(t, got, "users") +} + +func TestNewTypeProviderWithConnection_NilDB(t *testing.T) { + _, err := duckdb.NewTypeProviderWithConnection(context.Background(), nil) + require.Error(t, err) + assert.ErrorIs(t, err, duckdb.ErrInvalidSchema) +} + +func TestLoadTableSchema_NoDB(t *testing.T) { + provider := duckdb.NewTypeProvider(nil) + err := provider.LoadTableSchema(context.Background(), "test") + require.Error(t, err) + assert.ErrorIs(t, err, duckdb.ErrInvalidSchema) +} + +func TestTypeProvider_FindStructType(t *testing.T) { + schemas := map[string]duckdb.Schema{ + "users": duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "varchar"}, + }), + } + provider := duckdb.NewTypeProvider(schemas) + + typ, found := provider.FindStructType("users") + assert.True(t, found) + assert.NotNil(t, typ) + + _, found = provider.FindStructType("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldNames(t *testing.T) { + schemas := map[string]duckdb.Schema{ + "users": duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "varchar"}, + {Name: "email", Type: "text"}, + }), + } + provider := duckdb.NewTypeProvider(schemas) + + names, found := provider.FindStructFieldNames("users") + assert.True(t, found) + assert.ElementsMatch(t, []string{"id", "name", "email"}, names) + + _, found = provider.FindStructFieldNames("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldType(t *testing.T) { + schemas := map[string]duckdb.Schema{ + "test_table": 
duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "str_field", Type: "varchar"}, + {Name: "text_field", Type: "text"}, + {Name: "int_field", Type: "integer"}, + {Name: "bigint_field", Type: "bigint"}, + {Name: "smallint_field", Type: "smallint"}, + {Name: "tinyint_field", Type: "tinyint"}, + {Name: "hugeint_field", Type: "hugeint"}, + {Name: "double_field", Type: "double"}, + {Name: "float_field", Type: "float"}, + {Name: "bool_field", Type: "boolean"}, + {Name: "blob_field", Type: "blob"}, + {Name: "json_field", Type: "json"}, + {Name: "ts_field", Type: "timestamp"}, + {Name: "array_field", Type: "integer", Repeated: true}, + }), + } + provider := duckdb.NewTypeProvider(schemas) + + tests := []struct { + fieldName string + wantType *types.Type + wantFound bool + }{ + {"str_field", types.StringType, true}, + {"text_field", types.StringType, true}, + {"int_field", types.IntType, true}, + {"bigint_field", types.IntType, true}, + {"smallint_field", types.IntType, true}, + {"tinyint_field", types.IntType, true}, + {"hugeint_field", types.IntType, true}, + {"double_field", types.DoubleType, true}, + {"float_field", types.DoubleType, true}, + {"bool_field", types.BoolType, true}, + {"blob_field", types.BytesType, true}, + {"json_field", types.DynType, true}, + {"ts_field", types.TimestampType, true}, + {"array_field", types.NewListType(types.IntType), true}, + {"nonexistent", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("test_table", tt.fieldName) + assert.Equal(t, tt.wantFound, found) + if found { + assert.Equal(t, tt.wantType, got.Type) + } + }) + } +} + +func TestTypeProvider_Close(_ *testing.T) { + provider := duckdb.NewTypeProvider(nil) + // Close should not panic + provider.Close() +} + +func TestTypeProvider_ArrayDetection(t *testing.T) { + // Test that arrays defined manually with Repeated=true work correctly + schemas := map[string]duckdb.Schema{ + "test_table": 
duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "tags", Type: "varchar", Repeated: true, Dimensions: 1}, + {Name: "matrix", Type: "integer", Repeated: true, Dimensions: 2}, + }), + } + provider := duckdb.NewTypeProvider(schemas) + + // tags should be list of strings + got, found := provider.FindStructFieldType("test_table", "tags") + assert.True(t, found) + assert.Equal(t, types.NewListType(types.StringType), got.Type) + + // matrix should be list of integers (CEL sees all array dims as list) + got, found = provider.FindStructFieldType("test_table", "matrix") + assert.True(t, found) + assert.Equal(t, types.NewListType(types.IntType), got.Type) +} diff --git a/errors.go b/errors.go index 92115f3..4f6b585 100644 --- a/errors.go +++ b/errors.go @@ -58,6 +58,9 @@ var ( // ErrInvalidByteArrayLength indicates byte array exceeds maximum length ErrInvalidByteArrayLength = errors.New("byte array exceeds maximum length") + + // ErrUnsupportedDialectFeature indicates a feature not supported by the selected dialect + ErrUnsupportedDialectFeature = errors.New("unsupported dialect feature") ) // ConversionError represents an error that occurred during CEL to SQL conversion. 
diff --git a/examples/index_analysis/main.go b/examples/index_analysis/main.go index 73f744b..8c4be94 100644 --- a/examples/index_analysis/main.go +++ b/examples/index_analysis/main.go @@ -7,6 +7,12 @@ import ( "github.com/google/cel-go/cel" "github.com/spandigital/cel2sql/v3" + "github.com/spandigital/cel2sql/v3/dialect" + dialectbq "github.com/spandigital/cel2sql/v3/dialect/bigquery" + dialectduckdb "github.com/spandigital/cel2sql/v3/dialect/duckdb" + dialectmysql "github.com/spandigital/cel2sql/v3/dialect/mysql" + dialectpg "github.com/spandigital/cel2sql/v3/dialect/postgres" + dialectsqlite "github.com/spandigital/cel2sql/v3/dialect/sqlite" "github.com/spandigital/cel2sql/v3/pg" ) @@ -77,59 +83,114 @@ func main() { }, } - // Analyze each query and display recommendations + // Analyze each query and display recommendations (PostgreSQL default) + fmt.Println("\n--- PostgreSQL (default) ---") for i, ex := range examples { - fmt.Printf("%d. %s\n", i+1, ex.name) - fmt.Printf(" Description: %s\n", ex.description) - fmt.Printf(" CEL Expression: %s\n\n", ex.expression) - - // Compile the CEL expression - ast, issues := env.Compile(ex.expression) - if issues != nil && issues.Err() != nil { - log.Printf(" ERROR: Failed to compile: %v\n\n", issues.Err()) - continue - } + analyzeExample(env, provider, i, ex.name, ex.description, ex.expression) + } + + // Multi-dialect examples + fmt.Println("\n===================================") + fmt.Println("Multi-Dialect Index Recommendations") + fmt.Println("===================================") + + // Use a simple comparison query to show dialect differences + comparisonExpr := `users.age > 21 && users.metadata.verified == true` + + dialectExamples := []struct { + name string + dialect dialect.Dialect + }{ + {"PostgreSQL", dialectpg.New()}, + {"MySQL", dialectmysql.New()}, + {"SQLite", dialectsqlite.New()}, + {"DuckDB", dialectduckdb.New()}, + {"BigQuery", dialectbq.New()}, + } + + ast, issues := env.Compile(comparisonExpr) + if 
issues != nil && issues.Err() != nil { + log.Fatalf("Failed to compile expression: %v", issues.Err()) + } - // Analyze the query - sql, recommendations, err := cel2sql.AnalyzeQuery(ast, - cel2sql.WithSchemas(provider.GetSchemas())) + for _, de := range dialectExamples { + fmt.Printf("\n--- %s ---\n", de.name) + fmt.Printf(" CEL Expression: %s\n\n", comparisonExpr) + + _, recommendations, err := cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(provider.GetSchemas()), + cel2sql.WithDialect(de.dialect)) if err != nil { log.Printf(" ERROR: Failed to analyze: %v\n\n", err) continue } - // Display the generated SQL - fmt.Printf(" Generated SQL:\n %s\n\n", sql) - - // Display index recommendations if len(recommendations) == 0 { - fmt.Printf(" No index recommendations (query uses constants or simple conditions)\n\n") + fmt.Printf(" No index recommendations\n") } else { - fmt.Printf(" Index Recommendations (%d):\n", len(recommendations)) for j, rec := range recommendations { fmt.Printf(" [%d] Column: %s\n", j+1, rec.Column) fmt.Printf(" Type: %s\n", rec.IndexType) fmt.Printf(" Reason: %s\n", rec.Reason) - fmt.Printf(" SQL: %s\n", rec.Expression) + fmt.Printf(" DDL: %s\n", rec.Expression) fmt.Println() } } - - fmt.Println(" " + string(make([]byte, 60))) - fmt.Println() } // Summary fmt.Println("\nSummary") fmt.Println("=======") - fmt.Println("Index recommendations help optimize query performance by:") - fmt.Println(" • B-tree indexes: Fast equality and range queries on scalar columns") - fmt.Println(" • GIN indexes: Efficient JSON path access and array operations") - fmt.Println(" • GIN with pg_trgm: Fast regex pattern matching on text columns") + fmt.Println("Index recommendations are dialect-aware:") + fmt.Println(" PostgreSQL: B-tree, GIN, GIN with pg_trgm") + fmt.Println(" MySQL: B-tree, FULLTEXT, functional JSON indexes") + fmt.Println(" SQLite: B-tree (limited index types)") + fmt.Println(" DuckDB: ART (Adaptive Radix Tree)") + fmt.Println(" BigQuery: Clustering keys, 
Search indexes") + fmt.Println() + fmt.Println("Use WithDialect() to get dialect-specific recommendations:") + fmt.Println(" sql, recs, err := cel2sql.AnalyzeQuery(ast,") + fmt.Println(" cel2sql.WithDialect(mysql.New()),") + fmt.Println(" cel2sql.WithSchemas(schemas))") +} + +func analyzeExample(env *cel.Env, provider pg.TypeProvider, idx int, name, description, expression string) { + fmt.Printf("%d. %s\n", idx+1, name) + fmt.Printf(" Description: %s\n", description) + fmt.Printf(" CEL Expression: %s\n\n", expression) + + // Compile the CEL expression + ast, issues := env.Compile(expression) + if issues != nil && issues.Err() != nil { + log.Printf(" ERROR: Failed to compile: %v\n\n", issues.Err()) + return + } + + // Analyze the query + sql, recommendations, err := cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(provider.GetSchemas())) + if err != nil { + log.Printf(" ERROR: Failed to analyze: %v\n\n", err) + return + } + + // Display the generated SQL + fmt.Printf(" Generated SQL:\n %s\n\n", sql) + + // Display index recommendations + if len(recommendations) == 0 { + fmt.Printf(" No index recommendations (query uses constants or simple conditions)\n\n") + } else { + fmt.Printf(" Index Recommendations (%d):\n", len(recommendations)) + for j, rec := range recommendations { + fmt.Printf(" [%d] Column: %s\n", j+1, rec.Column) + fmt.Printf(" Type: %s\n", rec.IndexType) + fmt.Printf(" Reason: %s\n", rec.Reason) + fmt.Printf(" SQL: %s\n", rec.Expression) + fmt.Println() + } + } + + fmt.Println(" " + string(make([]byte, 60))) fmt.Println() - fmt.Println("To apply recommendations:") - fmt.Println(" 1. Review each recommendation and its reason") - fmt.Println(" 2. Adjust table_name to your actual table name") - fmt.Println(" 3. Execute the CREATE INDEX statements on your database") - fmt.Println(" 4. 
Monitor query performance improvements") } diff --git a/go.mod b/go.mod index b383cc0..58118e2 100644 --- a/go.mod +++ b/go.mod @@ -3,21 +3,35 @@ module github.com/spandigital/cel2sql/v3 go 1.25.7 require ( + cloud.google.com/go/bigquery v1.73.1 + github.com/go-sql-driver/mysql v1.9.3 github.com/google/cel-go v0.27.0 github.com/jackc/pgx/v5 v5.8.0 github.com/lib/pq v1.11.2 github.com/stretchr/testify v1.11.1 github.com/testcontainers/testcontainers-go v0.40.0 + github.com/testcontainers/testcontainers-go/modules/gcloud v0.40.0 + github.com/testcontainers/testcontainers-go/modules/mysql v0.40.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 - google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 + google.golang.org/api v0.268.0 + google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 + google.golang.org/grpc v1.79.1 + modernc.org/sqlite v1.46.1 ) require ( cel.dev/expr v0.25.1 // indirect + cloud.google.com/go v0.123.0 // indirect + cloud.google.com/go/auth v0.18.1 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + cloud.google.com/go/iam v1.5.3 // indirect dario.cat/mergo v1.0.2 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect + github.com/apache/arrow/go/v15 v15.0.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/errdefs v1.0.0 // indirect @@ -30,19 +44,27 @@ require ( github.com/docker/docker v28.5.2+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.8.4 // indirect github.com/felixge/httpsnoop v1.0.4 // 
indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/google/flatbuffers v23.5.26+incompatible // indirect + github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect + github.com/googleapis/gax-go/v2 v2.17.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.5 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.10 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/go-archive v0.1.0 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -51,28 +73,43 @@ require ( github.com/moby/sys/userns v0.1.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/morikuni/aec v1.0.0 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/shirou/gopsutil/v4 v4.25.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect + 
github.com/zeebo/xxh3 v1.0.2 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect go.opentelemetry.io/otel v1.40.0 // indirect go.opentelemetry.io/otel/metric v1.40.0 // indirect go.opentelemetry.io/otel/trace v1.40.0 // indirect - golang.org/x/crypto v0.46.0 // indirect - golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect + golang.org/x/crypto v0.47.0 // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/mod v0.31.0 // indirect + golang.org/x/net v0.49.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.32.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect - google.golang.org/grpc v1.73.0 // indirect - google.golang.org/protobuf v1.36.10 // indirect + golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc // indirect + golang.org/x/text v0.33.0 // indirect + golang.org/x/time v0.14.0 // indirect + golang.org/x/tools v0.40.0 // indirect + golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect + google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 // indirect + google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index 095dcf6..1d392a9 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,51 @@ cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +cloud.google.com/go v0.123.0 
h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= +cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= +cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs= +cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/bigquery v1.73.1 h1:v//GZwdhtmCbZ87rOnxz7pectOGFS1GNRvrGTvLzka4= +cloud.google.com/go/bigquery v1.73.1/go.mod h1:KSLx1mKP/yGiA8U+ohSrqZM1WknUnjZAxHAQZ51/b1k= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +cloud.google.com/go/datacatalog v1.26.1 h1:bCRKA8uSQN8wGW3Tw0gwko4E9a64GRmbW1nCblhgC2k= +cloud.google.com/go/datacatalog v1.26.1/go.mod h1:2Qcq8vsHNxMDgjgadRFmFG47Y+uuIVsyEGUrlrKEdrg= +cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= +cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= +cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= +cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= +cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= +cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI= +cloud.google.com/go/storage v1.59.0 h1:9p3yDzEN9Vet4JnbN90FECIw6n4FCXcKBK1scxtQnw8= +cloud.google.com/go/storage v1.59.0/go.mod h1:cMWbtM+anpC74gn6qjLh+exqYcfmB9Hqe5z6adx+CLI= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= 
+filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 h1:sBEjpZlNHzK1voKq9695PJSX2o5NEXl7/OL3coiIY0c= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.54.0 h1:lhhYARPUu3LmHysQ/igznQphfzynnqI3D75oUyw1HXk= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.54.0/go.mod h1:l9rva3ApbBpEJxSNYnwT9N4CDLrWgtq3u8736C5hyJw= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 h1:s0WlVbf9qpvkh1c/uDAPElam0WrL7fHRIidgZJ7UqZI= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/apache/arrow/go/v15 v15.0.2 h1:60IliRbiyTWCWjERBCkO1W4Qun9svcYoZrSLcyOsMLE= +github.com/apache/arrow/go/v15 v15.0.2/go.mod h1:DGXsR3ajT524njufqf95822i+KTh+yea1jass9YXgjA= github.com/cenkalti/backoff/v4 v4.3.0 
h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= @@ -37,10 +69,19 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA= +github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= +github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= +github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod 
h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -48,15 +89,35 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/cel-go v0.27.0 h1:e7ih85+4qVrBuqQWTW4FKSqZYokVuc3HnhH5keboFTo= github.com/google/cel-go v0.27.0/go.mod h1:tTJ11FWqnhw5KKpnWpvW9CJC3Y9GK4EIS0WXnBbebzw= +github.com/google/flatbuffers v23.5.26+incompatible h1:M9dgRyhJemaM4Sw8+66GHBu8ioaQmyPLg1b8VwK5WJg= +github.com/google/flatbuffers v23.5.26+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao= +github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8= +github.com/googleapis/gax-go/v2 v2.17.0 h1:RksgfBpxqff0EZkDWYuz9q/uWsTVz+kf43LsZ1J6SMc= +github.com/googleapis/gax-go/v2 v2.17.0/go.mod h1:mzaqghpQp4JDh3HvADwrat+6M3MOIDp5YKHhb9PAgDY= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 h1:X5VWvz21y3gzm9Nw/kaUeku/1+uBhcekkmy4IkffJww= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVTJueD4wSS5hT7zTt4Mrutd90= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -67,6 
+128,8 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= +github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -77,6 +140,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -97,22 +162,32 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/ncruces/go-strftime v1.0.0 
h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/pierrec/lz4/v4 v4.1.18 h1:xaKrnTkyoqfh1YItXl56+6KJNVYWlEEPuAQW9xsplYQ= +github.com/pierrec/lz4/v4 v4.1.18/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/shirou/gopsutil/v4 v4.25.6 
h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= +github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= @@ -122,6 +197,10 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= +github.com/testcontainers/testcontainers-go/modules/gcloud v0.40.0 h1:9Q7AnMCHmLArYtWe0i06hHnmVylJw2FNkJX/Sm0Rpf0= +github.com/testcontainers/testcontainers-go/modules/gcloud v0.40.0/go.mod h1:SyFaMHm4IaOBL8DoNUZ2ov4vlQuU7qBRAcJuUNYw2OA= +github.com/testcontainers/testcontainers-go/modules/mysql v0.40.0 h1:P9Txfy5Jothx2wFdcus0QoSmX/PKSIXZxrTbZPVJswA= +github.com/testcontainers/testcontainers-go/modules/mysql v0.40.0/go.mod h1:oZPHHqJqXG7FD8OB/yWH7gLnDvZUlFHAVJNrGftL+eg= github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk= github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= @@ -130,8 +209,16 @@ github.com/tklauser/numcpus 
v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0 h1:kWRNZMsfBHZ+uHjiH4y7Etn2FK26LAGkNFw7RHv1DhE= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0/go.mod h1:t/OGqzHBa5v6RHZwrDBJ2OirWc+4q/w2fTbLZwAKjTk= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 h1:7iP2uCb7sGddAr30RRS6xjKy7AZ2JtTOPA3oolgVSw8= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0/go.mod h1:c7hN3ddxs/z6q9xwvfLPk+UHlWRQyaeR1LdgfL/66l0= go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= @@ -152,37 +239,55 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/crypto 
v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA= -golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= +golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= -golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc h1:bH6xUXay0AIFMElXG2rQ4uiE+7ncwtiOdPfYK1NK2XA= +golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ= +golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= +golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api 
v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= -google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= -google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= -google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= -google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= -google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/api v0.268.0 h1:hgA3aS4lt9rpF5RCCkX0Q2l7DvHgvlb53y4T4u6iKkA= +google.golang.org/api v0.268.0/go.mod h1:HXMyMH496wz+dAJwD/GkAPLd3ZL33Kh0zEG32eNvy9w= +google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 h1:VQZ/yAbAtjkHgH80teYd2em3xtIkkHd7ZhqfH2N9CsM= +google.golang.org/genproto v0.0.0-20260128011058-8636f8732409/go.mod h1:rxKD3IEILWEu3P44seeNOAwZN4SaoKaQ/2eTg4mM6EM= +google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M= +google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 h1:Jr5R2J6F6qWyzINc+4AM8t5pfUz6beZpHp678GNrMbE= 
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY= +google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= @@ -191,3 +296,31 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= 
+modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/json.go b/json.go index 5867bf8..a1351f3 100644 --- a/json.go +++ b/json.go @@ -131,14 +131,10 @@ func (con *converter) buildJSONPathForArray(expr *exprpb.Expr) error { if operandSelect := operand.GetSelectExpr(); operandSelect != nil { // This is nested access - recursively build the path for the operand if con.hasJSONFieldInChain(operand) { - if err := con.buildJSONPathForArray(operand); err != nil { - return err - } - // Add intermediate JSON path operator (always -> for arrays) - con.str.WriteString("->'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + // Add intermediate JSON path 
operator (always non-final for arrays) + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + return con.buildJSONPathForArray(operand) + }, field, false) } } @@ -153,13 +149,9 @@ func (con *converter) buildJSONPathForArray(expr *exprpb.Expr) error { } // For other cases, visit the operand and add JSON operator - if err := con.visit(operand); err != nil { - return err - } - con.str.WriteString("->'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + return con.visit(operand) + }, field, false) } // isJSONObjectFieldAccess determines if this is a JSON object field access in comprehensions @@ -303,61 +295,34 @@ func (con *converter) buildJSONPathInternal(expr *exprpb.Expr, isFinalField bool // If so, we should NOT apply JSON operators to this level if tableName, columnName, ok := con.getTableAndFieldFromSelectChain(operand); ok { // This is table.column where column is JSON/JSONB - // Render as table.column without JSON operators - con.str.WriteString(tableName) - con.str.WriteString(".") - con.str.WriteString(columnName) - // Now add JSON operator for the current field - if isFinalField { - con.str.WriteString("->>'") // Final field: extract as text - } else { - con.str.WriteString("->'") // Intermediate field: keep as JSON - } - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + // Render as table.column without JSON operators, then add JSON operator for the current field + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + con.str.WriteString(tableName) + con.str.WriteString(".") + con.str.WriteString(columnName) + return nil + }, field, isFinalField) } // This is deeper nesting like table.jsonfield.subfield.finalfield // We need to determine if the operand is JSON-related if con.shouldUseJSONPath(operandSelect.GetOperand(), operandSelect.GetField()) { - // Recursively build the path for 
the operand (not final since we have more fields) - if err := con.buildJSONPathInternal(operand, false); err != nil { - return err - } // Add appropriate JSON path operator based on whether this is the final field - if isFinalField { - con.str.WriteString("->>'") // Final field: extract as text - } else { - con.str.WriteString("->'") // Intermediate field: keep as JSON - } - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + // Recursively build the path for the operand (not final since we have more fields) + return con.buildJSONPathInternal(operand, false) + }, field, isFinalField) } } - // Visit the base operand (like table.jsonfield) - if err := con.visit(operand); err != nil { - return err - } - - // Add the appropriate JSON path operator based on whether this is the final field - operator := "->>" - if !isFinalField { - operator = "->" - } - con.logger.LogAttrs(context.Background(), slog.LevelDebug, "JSON path operator selection", slog.String("field", field), - slog.String("operator", operator), slog.Bool("is_final", isFinalField), ) - con.str.WriteString(operator) - con.str.WriteString("'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + // Visit the base operand (like table.jsonfield) + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + return con.visit(operand) + }, field, isFinalField) } diff --git a/mysql/provider.go b/mysql/provider.go new file mode 100644 index 0000000..781f40b --- /dev/null +++ b/mysql/provider.go @@ -0,0 +1,226 @@ +// Package mysql provides MySQL type provider for CEL type system integration. 
+package mysql + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/schema" + "github.com/spandigital/cel2sql/v3/sqltypes" +) + +// Sentinel errors for the mysql package. +var ( + // ErrInvalidSchema indicates a problem with the provided schema or database introspection. + ErrInvalidSchema = errors.New("invalid schema") +) + +// FieldSchema is an alias for schema.FieldSchema. +type FieldSchema = schema.FieldSchema + +// Schema is an alias for schema.Schema. +type Schema = schema.Schema + +// NewSchema creates a new Schema. This is an alias for schema.NewSchema. +var NewSchema = schema.NewSchema + +// TypeProvider interface for MySQL type providers. +type TypeProvider interface { + types.Provider + LoadTableSchema(ctx context.Context, tableName string) error + GetSchemas() map[string]Schema + Close() +} + +type typeProvider struct { + schemas map[string]Schema + db *sql.DB +} + +// NewTypeProvider creates a new MySQL type provider with pre-defined schemas. +func NewTypeProvider(schemas map[string]Schema) TypeProvider { + return &typeProvider{schemas: schemas} +} + +// NewTypeProviderWithConnection creates a new MySQL type provider that can introspect database schemas. +// The caller owns the *sql.DB and is responsible for closing it. +func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, error) { + if db == nil { + return nil, fmt.Errorf("%w: db connection must not be nil", ErrInvalidSchema) + } + + return &typeProvider{ + schemas: make(map[string]Schema), + db: db, + }, nil +} + +// LoadTableSchema loads schema information for a table from the database. 
+func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) error { + if tp.db == nil { + return fmt.Errorf("%w: no database connection available", ErrInvalidSchema) + } + + query := ` + SELECT column_name, data_type, column_type, is_nullable + FROM information_schema.columns + WHERE table_schema = DATABASE() AND table_name = ? + ORDER BY ordinal_position + ` + + rows, err := tp.db.QueryContext(ctx, query, tableName) + if err != nil { + return fmt.Errorf("%w: failed to query table schema", ErrInvalidSchema) + } + defer func() { _ = rows.Close() }() + + var fields []FieldSchema + for rows.Next() { + var columnName, dataType, columnType, isNullable string + + if err := rows.Scan(&columnName, &dataType, &columnType, &isNullable); err != nil { + return fmt.Errorf("%w: failed to scan row", ErrInvalidSchema) + } + + field := mysqlColumnToFieldSchema(columnName, dataType, columnType) + fields = append(fields, field) + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("%w: error iterating rows", ErrInvalidSchema) + } + + if len(fields) == 0 { + return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName) + } + + tp.schemas[tableName] = NewSchema(fields) + return nil +} + +// mysqlColumnToFieldSchema converts MySQL column metadata to a FieldSchema. +func mysqlColumnToFieldSchema(columnName, dataType, _ string) FieldSchema { + // Normalize data type to lowercase + dataType = strings.ToLower(dataType) + + isJSON := dataType == "json" + + return FieldSchema{ + Name: columnName, + Type: dataType, + IsJSON: isJSON, + } +} + +// Close is a no-op since we don't own the *sql.DB. +func (tp *typeProvider) Close() { + // No-op: caller owns the *sql.DB connection +} + +// GetSchemas returns the schemas known to this type provider. +func (tp *typeProvider) GetSchemas() map[string]Schema { + return tp.schemas +} + +// EnumValue implements types.Provider. 
+func (tp *typeProvider) EnumValue(_ string) ref.Val { + return types.NewErr("unknown enum value") +} + +// FindIdent implements types.Provider. +func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { + return nil, false +} + +// FindStructType implements types.Provider. +func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { + if _, ok := tp.schemas[structType]; ok { + return types.NewObjectType(structType), true + } + return nil, false +} + +// FindStructFieldNames implements types.Provider. +func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + fields := s.Fields() + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + return names, true +} + +// FindStructFieldType implements types.Provider. +func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + field, found := s.FindField(fieldName) + if !found { + return nil, false + } + + exprType := mysqlTypeToCELExprType(field) + + celType, err := types.ExprTypeToType(exprType) + if err != nil { + return nil, false + } + + return &types.FieldType{ + Type: celType, + }, true +} + +// NewValue implements types.Provider. +func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { + return types.NewErr("unknown type in schema") +} + +// mysqlTypeToCELExprType converts a MySQL field schema to a CEL expression type. +func mysqlTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { + baseType := mysqlBaseTypeToCEL(field.Type) + if field.Repeated { + return decls.NewListType(baseType) + } + return baseType +} + +// mysqlBaseTypeToCEL converts a MySQL type name to a CEL expression type. 
+func mysqlBaseTypeToCEL(typeName string) *exprpb.Type { + switch typeName { + case "varchar", "char", "text", "tinytext", "mediumtext", "longtext", "enum", "set": + return decls.String + case "int", "integer", "tinyint", "smallint", "mediumint", "bigint": + return decls.Int + case "float", "double", "decimal", "numeric", "real": + return decls.Double + case "boolean", "bool": + return decls.Bool + case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob": + return decls.Bytes + case "json": + return decls.Dyn + case "datetime", "timestamp": + return decls.Timestamp + case "date": + return sqltypes.Date + case "time": + return sqltypes.Time + default: + return decls.Dyn + } +} diff --git a/mysql/provider_test.go b/mysql/provider_test.go new file mode 100644 index 0000000..5a73099 --- /dev/null +++ b/mysql/provider_test.go @@ -0,0 +1,280 @@ +package mysql_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/google/cel-go/common/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + tcmysql "github.com/testcontainers/testcontainers-go/modules/mysql" + + "github.com/spandigital/cel2sql/v3/mysql" +) + +func TestNewTypeProvider(t *testing.T) { + schemas := map[string]mysql.Schema{ + "users": mysql.NewSchema([]mysql.FieldSchema{ + {Name: "id", Type: "int"}, + {Name: "name", Type: "varchar"}, + {Name: "email", Type: "text"}, + }), + } + + provider := mysql.NewTypeProvider(schemas) + require.NotNil(t, provider) + + got := provider.GetSchemas() + assert.Len(t, got, 1) + assert.Contains(t, got, "users") +} + +func TestNewTypeProviderWithConnection_NilDB(t *testing.T) { + _, err := mysql.NewTypeProviderWithConnection(context.Background(), nil) + require.Error(t, err) + assert.ErrorIs(t, err, mysql.ErrInvalidSchema) +} + +func TestLoadTableSchema_NoDB(t *testing.T) { + provider := mysql.NewTypeProvider(nil) + err := 
provider.LoadTableSchema(context.Background(), "test") + require.Error(t, err) + assert.ErrorIs(t, err, mysql.ErrInvalidSchema) +} + +func TestTypeProvider_FindStructType(t *testing.T) { + schemas := map[string]mysql.Schema{ + "users": mysql.NewSchema([]mysql.FieldSchema{ + {Name: "id", Type: "int"}, + {Name: "name", Type: "varchar"}, + }), + } + provider := mysql.NewTypeProvider(schemas) + + typ, found := provider.FindStructType("users") + assert.True(t, found) + assert.NotNil(t, typ) + + _, found = provider.FindStructType("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldNames(t *testing.T) { + schemas := map[string]mysql.Schema{ + "users": mysql.NewSchema([]mysql.FieldSchema{ + {Name: "id", Type: "int"}, + {Name: "name", Type: "varchar"}, + {Name: "email", Type: "text"}, + }), + } + provider := mysql.NewTypeProvider(schemas) + + names, found := provider.FindStructFieldNames("users") + assert.True(t, found) + assert.ElementsMatch(t, []string{"id", "name", "email"}, names) + + _, found = provider.FindStructFieldNames("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldType(t *testing.T) { + schemas := map[string]mysql.Schema{ + "test_table": mysql.NewSchema([]mysql.FieldSchema{ + {Name: "str_field", Type: "varchar"}, + {Name: "int_field", Type: "int"}, + {Name: "bigint_field", Type: "bigint"}, + {Name: "float_field", Type: "float"}, + {Name: "double_field", Type: "double"}, + {Name: "decimal_field", Type: "decimal"}, + {Name: "bool_field", Type: "boolean"}, + {Name: "blob_field", Type: "blob"}, + {Name: "json_field", Type: "json"}, + {Name: "datetime_field", Type: "datetime"}, + {Name: "timestamp_field", Type: "timestamp"}, + {Name: "text_field", Type: "text"}, + {Name: "enum_field", Type: "enum"}, + }), + } + provider := mysql.NewTypeProvider(schemas) + + tests := []struct { + fieldName string + wantType *types.Type + wantFound bool + }{ + {"str_field", types.StringType, true}, + {"int_field", 
types.IntType, true}, + {"bigint_field", types.IntType, true}, + {"float_field", types.DoubleType, true}, + {"double_field", types.DoubleType, true}, + {"decimal_field", types.DoubleType, true}, + {"bool_field", types.BoolType, true}, + {"blob_field", types.BytesType, true}, + {"json_field", types.DynType, true}, + {"datetime_field", types.TimestampType, true}, + {"timestamp_field", types.TimestampType, true}, + {"text_field", types.StringType, true}, + {"enum_field", types.StringType, true}, + {"nonexistent", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("test_table", tt.fieldName) + assert.Equal(t, tt.wantFound, found) + if found { + assert.Equal(t, tt.wantType, got.Type) + } + }) + } +} + +func TestTypeProvider_Close(_ *testing.T) { + provider := mysql.NewTypeProvider(nil) + // Close should not panic + provider.Close() +} + +// setupMySQLContainer starts a MySQL 8 container and returns a database connection. 
+func setupMySQLContainer(ctx context.Context, t *testing.T) (*tcmysql.MySQLContainer, *sql.DB) { + t.Helper() + + container, err := tcmysql.Run(ctx, "mysql:8.0", + tcmysql.WithDatabase("testdb"), + tcmysql.WithUsername("testuser"), + tcmysql.WithPassword("testpass"), + ) + require.NoError(t, err, "Failed to start MySQL container") + + host, err := container.Host(ctx) + require.NoError(t, err) + port, err := container.MappedPort(ctx, "3306") + require.NoError(t, err) + + connStr := fmt.Sprintf("testuser:testpass@tcp(%s:%s)/testdb?parseTime=true", + host, port.Port()) + db, err := sql.Open("mysql", connStr) + require.NoError(t, err, "Failed to connect to MySQL database") + + err = db.Ping() + require.NoError(t, err, "Failed to ping MySQL database") + + return container, db +} + +func TestLoadTableSchema_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, db := setupMySQLContainer(ctx, t) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + if err := testcontainers.TerminateContainer(container); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + // Create test table with various MySQL types + _, err := db.ExecContext(ctx, ` + CREATE TABLE schema_test ( + id INT PRIMARY KEY AUTO_INCREMENT, + name VARCHAR(255) NOT NULL, + description TEXT, + age INT, + score DOUBLE, + price DECIMAL(10,2), + is_active BOOLEAN, + avatar BLOB, + metadata JSON, + created_at DATETIME, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + `) + require.NoError(t, err) + + provider, err := mysql.NewTypeProviderWithConnection(ctx, db) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "schema_test") + require.NoError(t, err) + + // Verify schema was loaded + schemas := provider.GetSchemas() + assert.Contains(t, schemas, "schema_test") + + // Verify FindStructType + typ, found := 
provider.FindStructType("schema_test") + assert.True(t, found) + assert.NotNil(t, typ) + + // Verify FindStructFieldNames + names, found := provider.FindStructFieldNames("schema_test") + assert.True(t, found) + assert.Contains(t, names, "id") + assert.Contains(t, names, "name") + assert.Contains(t, names, "metadata") + + // Verify type mappings from loaded schema + tests := []struct { + fieldName string + wantType *types.Type + }{ + {"id", types.IntType}, + {"name", types.StringType}, + {"description", types.StringType}, + {"age", types.IntType}, + {"score", types.DoubleType}, + {"price", types.DoubleType}, + {"is_active", types.IntType}, // MySQL BOOLEAN is TINYINT(1), data_type shows "tinyint" + {"metadata", types.DynType}, + {"created_at", types.TimestampType}, + } + + for _, tt := range tests { + t.Run("type_"+tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("schema_test", tt.fieldName) + assert.True(t, found, "field %q should be found", tt.fieldName) + if found { + assert.Equal(t, tt.wantType, got.Type, "field %q type mismatch", tt.fieldName) + } + }) + } + + // Verify JSON detection + schemaObj := schemas["schema_test"] + metadataField, found := schemaObj.FindField("metadata") + assert.True(t, found) + assert.True(t, metadataField.IsJSON, "metadata should be detected as JSON") +} + +func TestLoadTableSchema_NonexistentTable(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, db := setupMySQLContainer(ctx, t) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + if err := testcontainers.TerminateContainer(container); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + provider, err := mysql.NewTypeProviderWithConnection(ctx, db) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "nonexistent_table") + require.Error(t, err) + assert.ErrorIs(t, 
err, mysql.ErrInvalidSchema) +} diff --git a/mysql_integration_test.go b/mysql_integration_test.go new file mode 100644 index 0000000..768dbb9 --- /dev/null +++ b/mysql_integration_test.go @@ -0,0 +1,447 @@ +package cel2sql_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/google/cel-go/cel" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + tcmysql "github.com/testcontainers/testcontainers-go/modules/mysql" + + "github.com/spandigital/cel2sql/v3" + mysqlDialect "github.com/spandigital/cel2sql/v3/dialect/mysql" + "github.com/spandigital/cel2sql/v3/pg" +) + +// setupMySQLContainer starts a MySQL 8 container and returns a database connection. +func setupMySQLContainer(ctx context.Context, t *testing.T) (testcontainers.Container, *sql.DB) { + t.Helper() + + container, err := tcmysql.Run(ctx, "mysql:8.0", + tcmysql.WithDatabase("testdb"), + tcmysql.WithUsername("testuser"), + tcmysql.WithPassword("testpass"), + ) + require.NoError(t, err, "Failed to start MySQL container") + + // Get connection string + host, err := container.Host(ctx) + require.NoError(t, err) + port, err := container.MappedPort(ctx, "3306") + require.NoError(t, err) + + connStr := fmt.Sprintf("testuser:testpass@tcp(%s:%s)/testdb?parseTime=true", + host, port.Port()) + db, err := sql.Open("mysql", connStr) + require.NoError(t, err, "Failed to connect to MySQL database") + + err = db.Ping() + require.NoError(t, err, "Failed to ping MySQL database") + + return container, db +} + +// TestMySQLOperatorsIntegration validates operator conversions against a real MySQL database. 
+func TestMySQLOperatorsIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, db := setupMySQLContainer(ctx, t) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + // Create test table + _, err := db.Exec(` + CREATE TABLE test_data ( + id INTEGER PRIMARY KEY, + text_val TEXT, + int_val INTEGER, + float_val DOUBLE, + bool_val BOOLEAN, + nullable_text TEXT, + nullable_int INTEGER + ) + `) + require.NoError(t, err) + + // Insert test data + _, err = db.Exec(` + INSERT INTO test_data VALUES + (1, 'hello', 10, 10.5, true, 'present', 100), + (2, 'world', 20, 20.5, false, NULL, NULL), + (3, 'test', 30, 30.5, true, 'here', 200), + (4, 'hello world', 5, 5.5, false, 'value', 50), + (5, 'testing', 15, 15.5, true, 'test', 150) + `) + require.NoError(t, err) + + // Set up CEL environment with simple variables + env, err := cel.NewEnv( + cel.Variable("id", cel.IntType), + cel.Variable("text_val", cel.StringType), + cel.Variable("int_val", cel.IntType), + cel.Variable("float_val", cel.DoubleType), + cel.Variable("bool_val", cel.BoolType), + cel.Variable("nullable_text", cel.StringType), + cel.Variable("nullable_int", cel.IntType), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(mysqlDialect.New()) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + // Comparison operators + { + name: "Equality string", + celExpr: `text_val == "hello"`, + expectedRows: 1, + description: "String equality comparison", + }, + { + name: "Equality integer", + celExpr: `int_val == 20`, + expectedRows: 1, + description: "Integer equality comparison", + }, + { + name: "Equality float", + celExpr: `float_val == 10.5`, + expectedRows: 1, + description: "Float equality 
comparison", + }, + { + name: "Equality boolean", + celExpr: `bool_val == true`, + expectedRows: 3, + description: "Boolean equality comparison", + }, + { + name: "Not equal", + celExpr: `text_val != "hello"`, + expectedRows: 4, + description: "Not equal comparison", + }, + { + name: "Less than", + celExpr: `int_val < 15`, + expectedRows: 2, // 10, 5 + description: "Less than comparison", + }, + { + name: "Less than or equal", + celExpr: `int_val <= 15`, + expectedRows: 3, // 10, 5, 15 + description: "Less than or equal comparison", + }, + { + name: "Greater than", + celExpr: `int_val > 15`, + expectedRows: 2, // 20, 30 + description: "Greater than comparison", + }, + { + name: "Greater than or equal", + celExpr: `int_val >= 15`, + expectedRows: 3, // 20, 30, 15 + description: "Greater than or equal comparison", + }, + + // Logical operators + { + name: "Logical AND", + celExpr: `int_val > 10 && bool_val == true`, + expectedRows: 2, // rows 3 (30,true) and 5 (15,true) + description: "Logical AND operator", + }, + { + name: "Logical OR", + celExpr: `int_val < 10 || bool_val == false`, + expectedRows: 2, // rows 2 (20,false) and 4 (5,false) + description: "Logical OR operator", + }, + { + name: "Logical NOT", + celExpr: `!bool_val`, + expectedRows: 2, // rows 2 and 4 + description: "Logical NOT operator", + }, + { + name: "Complex logical expression", + celExpr: `(int_val > 10 && bool_val) || int_val < 10`, + expectedRows: 3, // rows 3, 5, 4 + description: "Complex nested logical operators", + }, + + // Arithmetic operators + { + name: "Addition", + celExpr: `int_val + 10 == 20`, + expectedRows: 1, // 10 + 10 = 20 + description: "Addition operator", + }, + { + name: "Subtraction", + celExpr: `int_val - 5 == 15`, + expectedRows: 1, // 20 - 5 = 15 + description: "Subtraction operator", + }, + { + name: "Multiplication", + celExpr: `int_val * 2 == 20`, + expectedRows: 1, // 10 * 2 = 20 + description: "Multiplication operator", + }, + { + name: "Division", + celExpr: 
`int_val / 2 == 10`, + expectedRows: 1, // 20 / 2 = 10 + description: "Division operator", + }, + { + name: "Modulo", + celExpr: `int_val % 10 == 0`, + expectedRows: 3, // 10, 20, 30 + description: "Modulo operator", + }, + { + name: "Complex arithmetic", + celExpr: `(int_val * 2) + 5 > 30`, + expectedRows: 3, // (20*2)+5=45, (30*2)+5=65, (15*2)+5=35 + description: "Complex arithmetic expression", + }, + + // String operators + { + name: "String concatenation", + celExpr: `text_val + "!" == "hello!"`, + expectedRows: 1, + description: "String concatenation (CONCAT)", + }, + { + name: "String contains", + celExpr: `text_val.contains("world")`, + expectedRows: 2, // "world", "hello world" + description: "String contains function (LOCATE)", + }, + { + name: "String startsWith", + celExpr: `text_val.startsWith("hello")`, + expectedRows: 2, // "hello", "hello world" + description: "String startsWith function (LIKE)", + }, + { + name: "String endsWith", + celExpr: `text_val.endsWith("world")`, + expectedRows: 2, // "world", "hello world" + description: "String endsWith function (LIKE)", + }, + + // Regex (MySQL 8.0+ supports REGEXP) + { + name: "Regex match", + celExpr: `text_val.matches(r"^hello")`, + expectedRows: 2, // "hello", "hello world" + description: "Regex match (REGEXP)", + }, + { + name: "Regex word boundary", + celExpr: `text_val.matches(r"test")`, + expectedRows: 2, // "test", "testing" + description: "Regex simple pattern", + }, + + // Complex combined operators + { + name: "Complex multi-operator expression", + celExpr: `int_val > 10 && bool_val && text_val.contains("test")`, + expectedRows: 2, // rows 3 and 5 + description: "Complex expression with multiple operator types", + }, + { + name: "Nested parenthesized operators", + celExpr: `((int_val + 5) * 2 > 30) && (text_val.contains("test") || bool_val)`, + expectedRows: 2, // rows 3 and 5 + description: "Deeply nested operators with parentheses", + }, + { + name: "Triple negation", + celExpr: 
`!!!bool_val`, + expectedRows: 2, + description: "Multiple NOT operators", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Compile CEL expression + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + // Convert to SQL with MySQL dialect + sqlCondition, err := cel2sql.Convert(ast, dialectOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // Execute query and count results + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := "SELECT COUNT(*) FROM test_data WHERE " + sqlCondition + t.Logf("Full SQL Query: %s", query) + + var actualRows int + err = db.QueryRow(query).Scan(&actualRows) + require.NoError(t, err, "Generated SQL should execute successfully. %s\nSQL: %s", + tt.description, sqlCondition) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s (expected %d rows, got %d rows)", + tt.description, tt.expectedRows, actualRows) + }) + } +} + +// TestMySQLJSONIntegration validates JSON operations against a real MySQL database. 
+func TestMySQLJSONIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, db := setupMySQLContainer(ctx, t) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + // Create test table with JSON column + _, err := db.Exec(` + CREATE TABLE products ( + id INTEGER PRIMARY KEY, + name TEXT, + price DOUBLE, + metadata JSON + ) + `) + require.NoError(t, err) + + // Insert test data with JSON metadata + _, err = db.Exec(` + INSERT INTO products VALUES + (1, 'Widget', 19.99, '{"brand": "Acme", "color": "red", "specs": {"weight": 100}}'), + (2, 'Gadget', 29.99, '{"brand": "Beta", "color": "blue", "specs": {"weight": 200}}'), + (3, 'Doohickey', 39.99, '{"brand": "Acme", "color": "green", "specs": {"weight": 150}}') + `) + require.NoError(t, err) + + // Set up CEL environment with schema for JSON detection + productSchema := pg.NewSchema([]pg.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + {Name: "price", Type: "double precision"}, + {Name: "metadata", Type: "json", IsJSON: true}, + }) + + schemas := map[string]pg.Schema{ + "product": productSchema, + } + + provider := pg.NewTypeProvider(schemas) + + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("product", cel.ObjectType("product")), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(mysqlDialect.New()) + schemaOpt := cel2sql.WithSchemas(schemas) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + { + name: "JSON field access", + celExpr: `product.metadata.brand == "Acme"`, + expectedRows: 2, + description: "JSON field access with ->>", + }, + { + name: "JSON field access different value", + celExpr: `product.metadata.color == 
"blue"`, + expectedRows: 1, + description: "JSON field access with different value", + }, + { + name: "JSON with regular field", + celExpr: `product.metadata.brand == "Acme" && product.price > 30.0`, + expectedRows: 1, // Doohickey (Acme, 39.99) + description: "JSON field combined with regular field comparison", + }, + { + name: "JSON field existence", + celExpr: `has(product.metadata.brand)`, + expectedRows: 3, // All rows have 'brand' + description: "JSON field existence check", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + sqlCondition, err := cel2sql.Convert(ast, dialectOpt, schemaOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := "SELECT COUNT(*) FROM products product WHERE " + sqlCondition + t.Logf("Full SQL Query: %s", query) + + var actualRows int + err = db.QueryRow(query).Scan(&actualRows) + require.NoError(t, err, "Generated SQL should execute successfully. %s\nSQL: %s", + tt.description, sqlCondition) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. 
%s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s", tt.description) + }) + } +} diff --git a/pg/provider.go b/pg/provider.go index f48565b..0565251 100644 --- a/pg/provider.go +++ b/pg/provider.go @@ -13,6 +13,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/spandigital/cel2sql/v3/schema" "github.com/spandigital/cel2sql/v3/sqltypes" ) @@ -36,70 +37,17 @@ const ( errMsgUnknownType = "unknown type in schema" ) -// FieldSchema represents a PostgreSQL field type with name, type, and optional nested schema. -type FieldSchema struct { - Name string - Type string // PostgreSQL type name (text, integer, boolean, etc.) - Repeated bool // true for arrays - Dimensions int // number of array dimensions (1 for integer[], 2 for integer[][], etc.) - Schema []FieldSchema // for composite types - IsJSON bool // true for json/jsonb types - IsJSONB bool // true for jsonb (vs json) - ElementType string // for arrays: element type name -} - -// Schema represents a PostgreSQL table schema with O(1) field lookup. -// It contains a slice of fields for ordered iteration and a map index for fast lookups. -type Schema struct { - fields []FieldSchema - fieldIndex map[string]*FieldSchema -} - -// NewSchema creates a new Schema with field indexing for O(1) lookups. -// This improves performance for tables with many columns. -func NewSchema(fields []FieldSchema) Schema { - index := make(map[string]*FieldSchema, len(fields)) - for i := range fields { - index[fields[i].Name] = &fields[i] - - // Build indices for nested schemas recursively - if len(fields[i].Schema) > 0 { - fields[i].Schema = rebuildSchemaIndex(fields[i].Schema) - } - } - - return Schema{ - fields: fields, - fieldIndex: index, - } -} - -// rebuildSchemaIndex recursively rebuilds indices for nested schemas. -// This is used internally when converting old-style []FieldSchema to new Schema struct. 
-func rebuildSchemaIndex(oldSchema []FieldSchema) []FieldSchema { - // For nested schemas, we need to ensure they're properly indexed too - // But since nested schemas are stored as []FieldSchema in FieldSchema.Schema, - // we keep them as slices but process them when needed - return oldSchema -} +// FieldSchema is an alias for schema.FieldSchema for backward compatibility. +// New code should prefer schema.FieldSchema directly. +type FieldSchema = schema.FieldSchema -// Fields returns the ordered slice of field schemas. -// Use this when you need to iterate over fields in their defined order. -func (s Schema) Fields() []FieldSchema { - return s.fields -} - -// FindField performs an O(1) lookup for a field by name. -// Returns the field schema and true if found, nil and false otherwise. -func (s Schema) FindField(name string) (*FieldSchema, bool) { - field, found := s.fieldIndex[name] - return field, found -} +// Schema is an alias for schema.Schema for backward compatibility. +// New code should prefer schema.Schema directly. +type Schema = schema.Schema -// Len returns the number of fields in the schema. -func (s Schema) Len() int { - return len(s.fields) -} +// NewSchema creates a new Schema. This is an alias for schema.NewSchema. +// New code should prefer schema.NewSchema directly. 
+var NewSchema = schema.NewSchema // TypeProvider interface for PostgreSQL type providers type TypeProvider interface { @@ -294,7 +242,7 @@ func (p *typeProvider) findSchema(typeName string) (Schema, bool) { } // For nested types, traverse the schema hierarchy using O(1) lookups - currentFields := schema.fields + currentFields := schema.Fields() for _, tn := range typeNames[1:] { // Use O(1) indexed lookup instead of linear search var nestedField *FieldSchema diff --git a/schema/schema.go b/schema/schema.go new file mode 100644 index 0000000..5bb3547 --- /dev/null +++ b/schema/schema.go @@ -0,0 +1,70 @@ +// Package schema provides dialect-agnostic database schema types for CEL to SQL conversion. +// These types describe column names, types, array dimensions, and JSON flags without +// coupling to any specific SQL dialect. +package schema + +// FieldSchema represents a database field type with name, type, and optional nested schema. +// This type is dialect-agnostic and used by all SQL dialect providers. +type FieldSchema struct { + Name string + Type string // SQL type name (text, integer, boolean, etc.) + Repeated bool // true for arrays + Dimensions int // number of array dimensions (1 for integer[], 2 for integer[][], etc.) + Schema []FieldSchema // for composite types + IsJSON bool // true for json/jsonb types + IsJSONB bool // true for jsonb (vs json) + ElementType string // for arrays: element type name +} + +// Schema represents a table schema with O(1) field lookup. +// It contains a slice of fields for ordered iteration and a map index for fast lookups. +type Schema struct { + fields []FieldSchema + fieldIndex map[string]*FieldSchema +} + +// NewSchema creates a new Schema with field indexing for O(1) lookups. +// This improves performance for tables with many columns. 
+func NewSchema(fields []FieldSchema) Schema { + index := make(map[string]*FieldSchema, len(fields)) + for i := range fields { + index[fields[i].Name] = &fields[i] + + // Build indices for nested schemas recursively + if len(fields[i].Schema) > 0 { + fields[i].Schema = rebuildSchemaIndex(fields[i].Schema) + } + } + + return Schema{ + fields: fields, + fieldIndex: index, + } +} + +// rebuildSchemaIndex recursively rebuilds indices for nested schemas. +// This is used internally when converting old-style []FieldSchema to new Schema struct. +func rebuildSchemaIndex(oldSchema []FieldSchema) []FieldSchema { + // For nested schemas, we need to ensure they're properly indexed too + // But since nested schemas are stored as []FieldSchema in FieldSchema.Schema, + // we keep them as slices but process them when needed + return oldSchema +} + +// Fields returns the ordered slice of field schemas. +// Use this when you need to iterate over fields in their defined order. +func (s Schema) Fields() []FieldSchema { + return s.fields +} + +// FindField performs an O(1) lookup for a field by name. +// Returns the field schema and true if found, nil and false otherwise. +func (s Schema) FindField(name string) (*FieldSchema, bool) { + field, found := s.fieldIndex[name] + return field, found +} + +// Len returns the number of fields in the schema. +func (s Schema) Len() int { + return len(s.fields) +} diff --git a/sqlite/provider.go b/sqlite/provider.go new file mode 100644 index 0000000..0877074 --- /dev/null +++ b/sqlite/provider.go @@ -0,0 +1,280 @@ +// Package sqlite provides SQLite type provider for CEL type system integration. 
+package sqlite + +import ( + "context" + "database/sql" + "errors" + "fmt" + "regexp" + "strings" + + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/schema" +) + +// Sentinel errors for the sqlite package. +var ( + // ErrInvalidSchema indicates a problem with the provided schema or database introspection. + ErrInvalidSchema = errors.New("invalid schema") +) + +// validTableName matches safe SQLite table names (letters, digits, underscores). +var validTableName = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +// FieldSchema is an alias for schema.FieldSchema. +type FieldSchema = schema.FieldSchema + +// Schema is an alias for schema.Schema. +type Schema = schema.Schema + +// NewSchema creates a new Schema. This is an alias for schema.NewSchema. +var NewSchema = schema.NewSchema + +// TypeProvider interface for SQLite type providers. +type TypeProvider interface { + types.Provider + LoadTableSchema(ctx context.Context, tableName string) error + GetSchemas() map[string]Schema + Close() +} + +type typeProvider struct { + schemas map[string]Schema + db *sql.DB +} + +// NewTypeProvider creates a new SQLite type provider with pre-defined schemas. +func NewTypeProvider(schemas map[string]Schema) TypeProvider { + return &typeProvider{schemas: schemas} +} + +// NewTypeProviderWithConnection creates a new SQLite type provider that can introspect database schemas. +// The caller owns the *sql.DB and is responsible for closing it. 
+func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, error) {
+	if db == nil {
+		return nil, fmt.Errorf("%w: db connection must not be nil", ErrInvalidSchema)
+	}
+
+	return &typeProvider{
+		schemas: make(map[string]Schema),
+		db:      db,
+	}, nil
+}
+
+// LoadTableSchema loads schema information for a table from the database using PRAGMA table_info.
+func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) error {
+	if tp.db == nil {
+		return fmt.Errorf("%w: no database connection available", ErrInvalidSchema)
+	}
+
+	// Validate table name to prevent SQL injection (PRAGMA doesn't support parameterized queries)
+	if !validTableName.MatchString(tableName) {
+		return fmt.Errorf("%w: invalid table name %q", ErrInvalidSchema, tableName)
+	}
+
+	// PRAGMA doesn't support parameterized queries, but we've validated the table name above
+	// #nosec G202 - table name is validated against strict regex pattern
+	query := fmt.Sprintf("PRAGMA table_info(%s)", tableName)
+
+	rows, err := tp.db.QueryContext(ctx, query)
+	if err != nil {
+		return fmt.Errorf("%w: failed to query table schema: %w", ErrInvalidSchema, err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	var fields []FieldSchema
+	for rows.Next() {
+		var cid int
+		var name, colType string
+		var notnull int
+		var dfltValue *string
+		var pk int
+
+		if err := rows.Scan(&cid, &name, &colType, &notnull, &dfltValue, &pk); err != nil {
+			return fmt.Errorf("%w: failed to scan row: %w", ErrInvalidSchema, err)
+		}
+
+		field := sqliteColumnToFieldSchema(name, colType)
+		fields = append(fields, field)
+	}
+
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("%w: error iterating rows: %w", ErrInvalidSchema, err)
+	}
+
+	if len(fields) == 0 {
+		return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName)
+	}
+
+	tp.schemas[tableName] = NewSchema(fields)
+	return nil
+}
+
+// sqliteColumnToFieldSchema converts SQLite column metadata to a FieldSchema.
+func sqliteColumnToFieldSchema(name, colType string) FieldSchema { + normalizedType := normalizeSQLiteType(colType) + isJSON := strings.EqualFold(colType, "json") || strings.EqualFold(colType, "jsonb") + + return FieldSchema{ + Name: name, + Type: normalizedType, + IsJSON: isJSON, + } +} + +// Normalized type constants used by normalizeSQLiteType. +const ( + sqliteTypeText = "text" + sqliteTypeInteger = "integer" + sqliteTypeReal = "real" + sqliteTypeBlob = "blob" + sqliteTypeJSON = "json" + sqliteTypeBool = "boolean" + sqliteTypeDatetime = "datetime" +) + +// normalizeSQLiteType converts a SQLite column type declaration to a normalized type name. +// SQLite uses type affinity, so we map common type names to our internal types. +func normalizeSQLiteType(colType string) string { + upper := strings.ToUpper(strings.TrimSpace(colType)) + + // Check for exact matches first + switch upper { + case "TEXT", "VARCHAR", "CHAR", "CLOB": + return sqliteTypeText + case "INTEGER", "INT", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT": + return sqliteTypeInteger + case "REAL", "DOUBLE", "FLOAT", "NUMERIC", "DECIMAL": + return sqliteTypeReal + case "BOOLEAN", "BOOL": + return sqliteTypeBool + case "BLOB": + return sqliteTypeBlob + case "JSON", "JSONB": + return sqliteTypeJSON + case "DATETIME", "TIMESTAMP": + return sqliteTypeDatetime + } + + // Check for type names that contain known keywords (e.g., "VARCHAR(255)") + if strings.Contains(upper, "INT") { + return sqliteTypeInteger + } + if strings.Contains(upper, "CHAR") || strings.Contains(upper, "CLOB") || strings.Contains(upper, "TEXT") { + return sqliteTypeText + } + if strings.Contains(upper, "BLOB") { + return sqliteTypeBlob + } + if strings.Contains(upper, "REAL") || strings.Contains(upper, "FLOA") || strings.Contains(upper, "DOUBLE") { + return sqliteTypeReal + } + + // Default to text for unknown types (SQLite's flexible typing) + return sqliteTypeText +} + +// Close is a no-op since we don't own the *sql.DB. 
+func (tp *typeProvider) Close() {
+	// Nothing to release: the *sql.DB is owned by the caller.
+}
+
+// GetSchemas returns the schemas known to this type provider.
+func (tp *typeProvider) GetSchemas() map[string]Schema {
+	return tp.schemas
+}
+
+// EnumValue implements types.Provider. SQLite schemas define no enum values.
+func (tp *typeProvider) EnumValue(_ string) ref.Val {
+	return types.NewErr("unknown enum value")
+}
+
+// FindIdent implements types.Provider. No identifiers are resolved here.
+func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) {
+	return nil, false
+}
+
+// FindStructType implements types.Provider. A struct type exists for every
+// table schema registered with this provider.
+func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) {
+	if _, known := tp.schemas[structType]; !known {
+		return nil, false
+	}
+	return types.NewObjectType(structType), true
+}
+
+// FindStructFieldNames implements types.Provider. It reports the column names
+// of the table backing structType.
+func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) {
+	tableSchema, known := tp.schemas[structType]
+	if !known {
+		return nil, false
+	}
+	columns := tableSchema.Fields()
+	names := make([]string, 0, len(columns))
+	for _, column := range columns {
+		names = append(names, column.Name)
+	}
+	return names, true
+}
+
+// FindStructFieldType implements types.Provider. It maps a column's normalized
+// SQLite type to the corresponding CEL field type.
+func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) {
+	tableSchema, known := tp.schemas[structType]
+	if !known {
+		return nil, false
+	}
+	column, found := tableSchema.FindField(fieldName)
+	if !found {
+		return nil, false
+	}
+
+	celType, err := types.ExprTypeToType(sqliteTypeToCELExprType(column))
+	if err != nil {
+		return nil, false
+	}
+	return &types.FieldType{Type: celType}, true
+}
+
+// NewValue implements types.Provider. Value construction is not supported.
+func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val {
+	return types.NewErr("unknown type in schema")
+}
+
+// sqliteTypeToCELExprType converts a SQLite field schema to a CEL expression type.
+func sqliteTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type {
+	// Repeated fields become CEL lists of the base element type.
+	baseType := sqliteBaseTypeToCEL(field.Type)
+	if field.Repeated {
+		return decls.NewListType(baseType)
+	}
+	return baseType
+}
+
+// sqliteBaseTypeToCEL converts a SQLite type name to a CEL expression type.
+// It accepts both the normalized names produced by normalizeSQLiteType and the
+// raw lowercase names that user-supplied schemas may contain.
+func sqliteBaseTypeToCEL(typeName string) *exprpb.Type {
+	switch typeName {
+	case "text", "varchar", "char", "clob":
+		return decls.String
+	case "integer", "int", "tinyint", "smallint", "mediumint", "bigint":
+		return decls.Int
+	case "real", "double", "float", "numeric", "decimal":
+		return decls.Double
+	case "boolean", "bool":
+		return decls.Bool
+	case "blob":
+		return decls.Bytes
+	case "json":
+		return decls.Dyn
+	case "datetime", "timestamp":
+		return decls.Timestamp
+	default:
+		// Unknown declarations fall back to dynamic typing.
+		return decls.Dyn
+	}
+}
diff --git a/sqlite/provider_test.go b/sqlite/provider_test.go
new file mode 100644
index 0000000..e4e47f8
--- /dev/null
+++ b/sqlite/provider_test.go
@@ -0,0 +1,348 @@
+package sqlite_test
+
+import (
+	"context"
+	"database/sql"
+	"testing"
+
+	"github.com/google/cel-go/common/types"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	_ "modernc.org/sqlite"
+
+	"github.com/spandigital/cel2sql/v3/sqlite"
+)
+
+func TestNewTypeProvider(t *testing.T) {
+	schemas := map[string]sqlite.Schema{
+		"users": sqlite.NewSchema([]sqlite.FieldSchema{
+			{Name: "id", Type: "integer"},
+			{Name: "name", Type: "text"},
+		}),
+	}
+
+	provider := sqlite.NewTypeProvider(schemas)
+	require.NotNil(t, provider)
+
+	got := provider.GetSchemas()
+	assert.Len(t, got, 1)
+	assert.Contains(t, got, "users")
+}
+
+func TestNewTypeProviderWithConnection_NilDB(t *testing.T) {
+	_, err := sqlite.NewTypeProviderWithConnection(context.Background(), nil)
+	require.Error(t, err)
+	assert.ErrorIs(t, err, sqlite.ErrInvalidSchema)
+}
+
+func TestLoadTableSchema_NoDB(t *testing.T) {
+	// A provider built without a connection cannot introspect tables.
+	provider := sqlite.NewTypeProvider(nil)
+	err := provider.LoadTableSchema(context.Background(), "test")
+	require.Error(t, err)
+	assert.ErrorIs(t, err, sqlite.ErrInvalidSchema)
+}
+
+func TestLoadTableSchema_InvalidTableName(t *testing.T) {
+	db, err := sql.Open("sqlite", ":memory:")
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	ctx := context.Background()
+	provider, err := sqlite.NewTypeProviderWithConnection(ctx, db)
+	require.NoError(t, err)
+
+	// SQL injection attempts should be rejected
+	invalidNames := []string{
+		"table; DROP TABLE users",
+		"table'name",
+		"table-name",
+		"table.name",
+		"123table",
+		"",
+		"table name",
+	}
+
+	for _, name := range invalidNames {
+		t.Run(name, func(t *testing.T) {
+			err := provider.LoadTableSchema(ctx, name)
+			require.Error(t, err)
+			assert.ErrorIs(t, err, sqlite.ErrInvalidSchema)
+		})
+	}
+}
+
+func TestTypeProvider_FindStructType(t *testing.T) {
+	schemas := map[string]sqlite.Schema{
+		"users": sqlite.NewSchema([]sqlite.FieldSchema{
+			{Name: "id", Type: "integer"},
+			{Name: "name", Type: "text"},
+		}),
+	}
+	provider := sqlite.NewTypeProvider(schemas)
+
+	typ, found := provider.FindStructType("users")
+	assert.True(t, found)
+	assert.NotNil(t, typ)
+
+	_, found = provider.FindStructType("nonexistent")
+	assert.False(t, found)
+}
+
+func TestTypeProvider_FindStructFieldNames(t *testing.T) {
+	schemas := map[string]sqlite.Schema{
+		"users": sqlite.NewSchema([]sqlite.FieldSchema{
+			{Name: "id", Type: "integer"},
+			{Name: "name", Type: "text"},
+			{Name: "email", Type: "text"},
+		}),
+	}
+	provider := sqlite.NewTypeProvider(schemas)
+
+	names, found := provider.FindStructFieldNames("users")
+	assert.True(t, found)
+	assert.ElementsMatch(t, []string{"id", "name", "email"}, names)
+
+	_, found = provider.FindStructFieldNames("nonexistent")
+	assert.False(t, found)
+}
+
+func TestTypeProvider_FindStructFieldType(t *testing.T) {
+	schemas := map[string]sqlite.Schema{
+		"test_table": sqlite.NewSchema([]sqlite.FieldSchema{
+			{Name: "text_field", Type: "text"},
+			{Name: "int_field", Type: "integer"},
+			{Name: "real_field", Type: "real"},
+			{Name: "bool_field", Type: "boolean"},
+			{Name: "blob_field", Type: "blob"},
+			{Name: "json_field", Type: "json"},
+			{Name: "datetime_field", Type: "datetime"},
+		}),
+	}
+	provider := sqlite.NewTypeProvider(schemas)
+
+	tests := []struct {
+		fieldName string
+		wantType  *types.Type
+		wantFound bool
+	}{
+		{"text_field", types.StringType, true},
+		{"int_field", types.IntType, true},
+		{"real_field", types.DoubleType, true},
+		{"bool_field", types.BoolType, true},
+		{"blob_field", types.BytesType, true},
+		{"json_field", types.DynType, true},
+		{"datetime_field", types.TimestampType, true},
+		{"nonexistent", nil, false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.fieldName, func(t *testing.T) {
+			got, found := provider.FindStructFieldType("test_table", tt.fieldName)
+			assert.Equal(t, tt.wantFound, found)
+			if found {
+				assert.Equal(t, tt.wantType, got.Type)
+			}
+		})
+	}
+}
+
+func TestTypeProvider_Close(_ *testing.T) {
+	provider := sqlite.NewTypeProvider(nil)
+	// Close should not panic
+	provider.Close()
+}
+
+func TestLoadTableSchema_Integration(t *testing.T) {
+	db, err := sql.Open("sqlite", ":memory:")
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	ctx := context.Background()
+
+	// Create test table with various SQLite types
+	_, err = db.ExecContext(ctx, `
+		CREATE TABLE schema_test (
+			id INTEGER PRIMARY KEY,
+			name TEXT NOT NULL,
+			description VARCHAR(255),
+			age INTEGER,
+			score REAL,
+			is_active BOOLEAN,
+			avatar BLOB,
+			metadata JSON,
+			created_at DATETIME
+		)
+	`)
+	require.NoError(t, err)
+
+	provider, err := sqlite.NewTypeProviderWithConnection(ctx, db)
+	require.NoError(t, err)
+
+	err = provider.LoadTableSchema(ctx, "schema_test")
+	require.NoError(t, err)
+
+	// Verify schema was loaded
+	schemas := provider.GetSchemas()
+	assert.Contains(t, schemas, "schema_test")
+
+	// Verify FindStructType
+	typ, found := provider.FindStructType("schema_test")
+	assert.True(t, found)
+	assert.NotNil(t, typ)
+
+	// Verify FindStructFieldNames
+	names, found := provider.FindStructFieldNames("schema_test")
+	assert.True(t, found)
+	assert.Contains(t, names, "id")
+	assert.Contains(t, names, "name")
+	assert.Contains(t, names, "metadata")
+
+	// Verify type mappings from loaded schema
+	tests := []struct {
+		fieldName string
+		wantType  *types.Type
+	}{
+		{"id", types.IntType},
+		{"name", types.StringType},
+		{"description", types.StringType},
+		{"age", types.IntType},
+		{"score", types.DoubleType},
+		{"is_active", types.BoolType},
+		{"avatar", types.BytesType},
+		{"metadata", types.DynType},
+		{"created_at", types.TimestampType},
+	}
+
+	for _, tt := range tests {
+		t.Run("type_"+tt.fieldName, func(t *testing.T) {
+			got, found := provider.FindStructFieldType("schema_test", tt.fieldName)
+			assert.True(t, found, "field %q should be found", tt.fieldName)
+			if found {
+				assert.Equal(t, tt.wantType, got.Type, "field %q type mismatch", tt.fieldName)
+			}
+		})
+	}
+
+	// Verify JSON detection
+	schemaObj := schemas["schema_test"]
+	metadataField, found := schemaObj.FindField("metadata")
+	assert.True(t, found)
+	assert.True(t, metadataField.IsJSON, "metadata should be detected as JSON")
+}
+
+func TestLoadTableSchema_NonexistentTable(t *testing.T) {
+	db, err := sql.Open("sqlite", ":memory:")
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	ctx := context.Background()
+	provider, err := sqlite.NewTypeProviderWithConnection(ctx, db)
+	require.NoError(t, err)
+
+	err = provider.LoadTableSchema(ctx, "nonexistent_table")
+	require.Error(t, err)
+	assert.ErrorIs(t, err, sqlite.ErrInvalidSchema)
+}
+
+func TestLoadTableSchema_MultipleTablesIntegration(t *testing.T) {
+	db, err := sql.Open("sqlite", ":memory:")
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	ctx := context.Background()
+
+	// Create two tables
+	_, err = db.ExecContext(ctx, `CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)`)
+	require.NoError(t, err)
+	_, err = db.ExecContext(ctx, `CREATE TABLE products (id INTEGER PRIMARY KEY, title TEXT, price REAL)`)
+	require.NoError(t, err)
+
+	provider, err := sqlite.NewTypeProviderWithConnection(ctx, db)
+	require.NoError(t, err)
+
+	err = provider.LoadTableSchema(ctx, "users")
+	require.NoError(t, err)
+	err = provider.LoadTableSchema(ctx, "products")
+	require.NoError(t, err)
+
+	schemas := provider.GetSchemas()
+	assert.Len(t, schemas, 2)
+	assert.Contains(t, schemas, "users")
+	assert.Contains(t, schemas, "products")
+
+	// Verify both schemas are independent
+	userNames, found := provider.FindStructFieldNames("users")
+	assert.True(t, found)
+	assert.ElementsMatch(t, []string{"id", "name", "email"}, userNames)
+
+	productNames, found := provider.FindStructFieldNames("products")
+	assert.True(t, found)
+	assert.ElementsMatch(t, []string{"id", "title", "price"}, productNames)
+}
+
+func TestNormalizeSQLiteType(t *testing.T) {
+	// Test via LoadTableSchema that various SQLite type declarations are normalized correctly
+	db, err := sql.Open("sqlite", ":memory:")
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	ctx := context.Background()
+
+	_, err = db.ExecContext(ctx, `
+		CREATE TABLE type_test (
+			col_int INTEGER,
+			col_varchar VARCHAR(255),
+			col_char CHAR(10),
+			col_text TEXT,
+			col_real REAL,
+			col_float FLOAT,
+			col_double DOUBLE,
+			col_numeric NUMERIC,
+			col_blob BLOB,
+			col_bool BOOLEAN,
+			col_datetime DATETIME,
+			col_timestamp TIMESTAMP,
+			col_bigint BIGINT,
+			col_smallint SMALLINT,
+			col_tinyint TINYINT,
+			col_json JSON
+		)
+	`)
+	require.NoError(t, err)
+
+	provider, err := sqlite.NewTypeProviderWithConnection(ctx, db)
+	require.NoError(t, err)
+
+	err = provider.LoadTableSchema(ctx, "type_test")
+	require.NoError(t, err)
+
+	tests := []struct {
+		fieldName string
+		wantType  *types.Type
+	}{
+		{"col_int", types.IntType},
+		{"col_varchar", types.StringType},
+		{"col_char", types.StringType},
+		{"col_text", types.StringType},
+		{"col_real", types.DoubleType},
+		{"col_float", types.DoubleType},
+		{"col_double", types.DoubleType},
+		{"col_numeric", types.DoubleType},
+		{"col_blob", types.BytesType},
+		{"col_bool", types.BoolType},
+		{"col_datetime", types.TimestampType},
+		{"col_timestamp", types.TimestampType},
+		{"col_bigint", types.IntType},
+		{"col_smallint", types.IntType},
+		{"col_tinyint", types.IntType},
+		{"col_json", types.DynType},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.fieldName, func(t *testing.T) {
+			got, found := provider.FindStructFieldType("type_test", tt.fieldName)
+			require.True(t, found, "field %q should be found", tt.fieldName)
+			assert.Equal(t, tt.wantType, got.Type, "field %q type mismatch", tt.fieldName)
+		})
+	}
+}
diff --git a/sqlite_integration_test.go b/sqlite_integration_test.go
new file mode 100644
index 0000000..f70b5d7
--- /dev/null
+++ b/sqlite_integration_test.go
@@ -0,0 +1,518 @@
+package cel2sql_test
+
+import (
+	"database/sql"
+	"testing"
+
+	"github.com/google/cel-go/cel"
+	"github.com/stretchr/testify/require"
+	_ "modernc.org/sqlite"
+
+	"github.com/spandigital/cel2sql/v3"
+	sqliteDialect "github.com/spandigital/cel2sql/v3/dialect/sqlite"
+	"github.com/spandigital/cel2sql/v3/pg"
+)
+
+// TestSQLiteOperatorsIntegration validates operator conversions against a real SQLite database.
+// This uses an in-memory SQLite database (no Docker required).
+func TestSQLiteOperatorsIntegration(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	// Open in-memory SQLite database
+	db, err := sql.Open("sqlite", ":memory:")
+	require.NoError(t, err)
+	defer func() {
+		if closeErr := db.Close(); closeErr != nil {
+			t.Logf("failed to close db: %v", closeErr)
+		}
+	}()
+
+	// Create test table (SQLite uses INTEGER for booleans, REAL for floats)
+	_, err = db.Exec(`
+		CREATE TABLE test_data (
+			id INTEGER PRIMARY KEY,
+			text_val TEXT,
+			int_val INTEGER,
+			float_val REAL,
+			bool_val INTEGER,
+			nullable_text TEXT,
+			nullable_int INTEGER
+		)
+	`)
+	require.NoError(t, err)
+
+	// Insert test data (using 1/0 for boolean values)
+	_, err = db.Exec(`
+		INSERT INTO test_data VALUES
+			(1, 'hello', 10, 10.5, 1, 'present', 100),
+			(2, 'world', 20, 20.5, 0, NULL, NULL),
+			(3, 'test', 30, 30.5, 1, 'here', 200),
+			(4, 'hello world', 5, 5.5, 0, 'value', 50),
+			(5, 'testing', 15, 15.5, 1, 'test', 150)
+	`)
+	require.NoError(t, err)
+
+	// Set up CEL environment with simple variables
+	env, err := cel.NewEnv(
+		cel.Variable("id", cel.IntType),
+		cel.Variable("text_val", cel.StringType),
+		cel.Variable("int_val", cel.IntType),
+		cel.Variable("float_val", cel.DoubleType),
+		cel.Variable("bool_val", cel.BoolType),
+		cel.Variable("nullable_text", cel.StringType),
+		cel.Variable("nullable_int", cel.IntType),
+	)
+	require.NoError(t, err)
+
+	dialectOpt := cel2sql.WithDialect(sqliteDialect.New())
+
+	// Each case records the expected COUNT(*) against the five rows above.
+	tests := []struct {
+		name         string
+		celExpr      string
+		expectedRows int
+		description  string
+	}{
+		// Comparison operators
+		{
+			name:         "Equality string",
+			celExpr:      `text_val == "hello"`,
+			expectedRows: 1,
+			description:  "String equality comparison",
+		},
+		{
+			name:         "Equality integer",
+			celExpr:      `int_val == 20`,
+			expectedRows: 1,
+			description:  "Integer equality comparison",
+		},
+		{
+			name:         "Equality float",
+			celExpr:      `float_val == 10.5`,
+			expectedRows: 1,
+			description:  "Float equality comparison",
+		},
+		{
+			name:         "Equality boolean",
+			celExpr:      `bool_val == true`,
+			expectedRows: 3,
+			description:  "Boolean equality comparison",
+		},
+		{
+			name:         "Not equal",
+			celExpr:      `text_val != "hello"`,
+			expectedRows: 4,
+			description:  "Not equal comparison",
+		},
+		{
+			name:         "Less than",
+			celExpr:      `int_val < 15`,
+			expectedRows: 2, // 10, 5
+			description:  "Less than comparison",
+		},
+		{
+			name:         "Less than or equal",
+			celExpr:      `int_val <= 15`,
+			expectedRows: 3, // 10, 5, 15
+			description:  "Less than or equal comparison",
+		},
+		{
+			name:         "Greater than",
+			celExpr:      `int_val > 15`,
+			expectedRows: 2, // 20, 30
+			description:  "Greater than comparison",
+		},
+		{
+			name:         "Greater than or equal",
+			celExpr:      `int_val >= 15`,
+			expectedRows: 3, // 20, 30, 15
+			description:  "Greater than or equal comparison",
+		},
+
+		// Logical operators
+		{
+			name:         "Logical AND",
+			celExpr:      `int_val > 10 && bool_val == true`,
+			expectedRows: 2, // rows 3 (30,true) and 5 (15,true)
+			description:  "Logical AND operator",
+		},
+		{
+			name:         "Logical OR",
+			celExpr:      `int_val < 10 || bool_val == false`,
+			expectedRows: 2, // rows 2 (20,false) and 4 (5,false)
+			description:  "Logical OR operator",
+		},
+		{
+			name:         "Logical NOT",
+			celExpr:      `!bool_val`,
+			expectedRows: 2, // rows 2 and 4 (bool_val == false)
+			description:  "Logical NOT operator",
+		},
+		{
+			name:         "Complex logical expression",
+			celExpr:      `(int_val > 10 && bool_val) || int_val < 10`,
+			expectedRows: 3, // rows 3, 5 (>10 && true), row 4 (<10)
+			description:  "Complex nested logical operators",
+		},
+
+		// Arithmetic operators
+		{
+			name:         "Addition",
+			celExpr:      `int_val + 10 == 20`,
+			expectedRows: 1, // 10 + 10 = 20
+			description:  "Addition operator",
+		},
+		{
+			name:         "Subtraction",
+			celExpr:      `int_val - 5 == 15`,
+			expectedRows: 1, // 20 - 5 = 15
+			description:  "Subtraction operator",
+		},
+		{
+			name:         "Multiplication",
+			celExpr:      `int_val * 2 == 20`,
+			expectedRows: 1, // 10 * 2 = 20
+			description:  "Multiplication operator",
+		},
+		{
+			name:         "Division",
+			celExpr:      `int_val / 2 == 10`,
+			expectedRows: 1, // 20 / 2 = 10
+			description:  "Division operator",
+		},
+		{
+			name:         "Modulo",
+			celExpr:      `int_val % 10 == 0`,
+			expectedRows: 3, // 10, 20, 30
+			description:  "Modulo operator",
+		},
+		{
+			name:         "Complex arithmetic",
+			celExpr:      `(int_val * 2) + 5 > 30`,
+			expectedRows: 3, // (20*2)+5=45, (30*2)+5=65, (15*2)+5=35
+			description:  "Complex arithmetic expression",
+		},
+
+		// String operators
+		{
+			name:         "String concatenation",
+			celExpr:      `text_val + "!" == "hello!"`,
+			expectedRows: 1,
+			description:  "String concatenation operator (||)",
+		},
+		{
+			name:         "String contains",
+			celExpr:      `text_val.contains("world")`,
+			expectedRows: 2, // "world", "hello world"
+			description:  "String contains function (INSTR)",
+		},
+		{
+			name:         "String startsWith",
+			celExpr:      `text_val.startsWith("hello")`,
+			expectedRows: 2, // "hello", "hello world"
+			description:  "String startsWith function (LIKE)",
+		},
+		{
+			name:         "String endsWith",
+			celExpr:      `text_val.endsWith("world")`,
+			expectedRows: 2, // "world", "hello world"
+			description:  "String endsWith function (LIKE)",
+		},
+
+		// Complex combined operators
+		{
+			name:         "Complex multi-operator expression",
+			celExpr:      `int_val > 10 && bool_val && text_val.contains("test")`,
+			expectedRows: 2, // rows 3 (30, true, "test") and 5 (15, true, "testing")
+			description:  "Complex expression with multiple operator types",
+		},
+		{
+			name:         "Nested parenthesized operators",
+			celExpr:      `((int_val + 5) * 2 > 30) && (text_val.contains("test") || bool_val)`,
+			expectedRows: 2, // rows 3 and 5
+			description:  "Deeply nested operators with parentheses",
+		},
+		{
+			name:         "Triple negation",
+			celExpr:      `!!!bool_val`,
+			expectedRows: 2, // rows 2 and 4 (bool_val == false)
+			description:  "Multiple NOT operators",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Compile CEL expression
+			ast, issues := env.Compile(tt.celExpr)
+			if issues != nil && issues.Err() != nil {
+				t.Fatalf("CEL compilation failed: %v", issues.Err())
+			}
+
+			// Convert to SQL with SQLite dialect
+			sqlCondition, err := cel2sql.Convert(ast, dialectOpt)
+			require.NoError(t, err, "Conversion should succeed for: %s", tt.description)
+
+			t.Logf("CEL Expression: %s", tt.celExpr)
+			t.Logf("Generated SQL WHERE clause: %s", sqlCondition)
+
+			// Execute query and count results
+			// #nosec G202 - This is a test validating SQL generation, not a security risk
+			query := "SELECT COUNT(*) FROM test_data WHERE " + sqlCondition
+			t.Logf("Full SQL Query: %s", query)
+
+			var actualRows int
+			err = db.QueryRow(query).Scan(&actualRows)
+			require.NoError(t, err, "Generated SQL should execute successfully. %s\nSQL: %s",
+				tt.description, sqlCondition)
+
+			require.Equal(t, tt.expectedRows, actualRows,
+				"Query should return expected number of rows. %s\nCEL: %s\nSQL: %s",
+				tt.description, tt.celExpr, sqlCondition)
+
+			t.Logf("OK: %s (expected %d rows, got %d rows)",
+				tt.description, tt.expectedRows, actualRows)
+		})
+	}
+}
+
+// TestSQLiteEdgeCasesIntegration validates edge cases against a real SQLite database.
+func TestSQLiteEdgeCasesIntegration(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	db, err := sql.Open("sqlite", ":memory:")
+	require.NoError(t, err)
+	defer func() {
+		if closeErr := db.Close(); closeErr != nil {
+			t.Logf("failed to close db: %v", closeErr)
+		}
+	}()
+
+	// Create test table with edge case values
+	_, err = db.Exec(`
+		CREATE TABLE edge_cases (
+			id INTEGER PRIMARY KEY,
+			empty_string TEXT,
+			zero_int INTEGER,
+			zero_float REAL,
+			negative_int INTEGER,
+			negative_float REAL,
+			large_int INTEGER
+		)
+	`)
+	require.NoError(t, err)
+
+	// Row 1 includes the int64 maximum to exercise large-integer handling.
+	_, err = db.Exec(`
+		INSERT INTO edge_cases VALUES
+			(1, '', 0, 0.0, -10, -5.5, 9223372036854775807),
+			(2, 'value', 1, 1.0, -1, -0.1, 123456789),
+			(3, 'another', 0, 0.0, 0, 0.0, 0)
+	`)
+	require.NoError(t, err)
+
+	env, err := cel.NewEnv(
+		cel.Variable("id", cel.IntType),
+		cel.Variable("empty_string", cel.StringType),
+		cel.Variable("zero_int", cel.IntType),
+		cel.Variable("zero_float", cel.DoubleType),
+		cel.Variable("negative_int", cel.IntType),
+		cel.Variable("negative_float", cel.DoubleType),
+		cel.Variable("large_int", cel.IntType),
+	)
+	require.NoError(t, err)
+
+	dialectOpt := cel2sql.WithDialect(sqliteDialect.New())
+
+	tests := []struct {
+		name         string
+		celExpr      string
+		expectedRows int
+		description  string
+	}{
+		{
+			name:         "Empty string equality",
+			celExpr:      `empty_string == ""`,
+			expectedRows: 1,
+			description:  "Empty string should be handled correctly",
+		},
+		{
+			name:         "Zero integer equality",
+			celExpr:      `zero_int == 0`,
+			expectedRows: 2,
+			description:  "Zero should be handled correctly",
+		},
+		{
+			name:         "Negative integer comparison",
+			celExpr:      `negative_int < 0`,
+			expectedRows: 2, // -10 and -1
+			description:  "Negative numbers should work correctly",
+		},
+		{
+			name:         "Large integer comparison",
+			celExpr:      `large_int > 1000000`,
+			expectedRows: 2, // 9223372036854775807 and 123456789
+			description:  "Large integers should be handled correctly",
+		},
+		{
+			name:         "Zero float equality",
+			celExpr:      `zero_float == 0.0`,
+			expectedRows: 2,
+			description:  "Zero float should be handled correctly",
+		},
+		{
+			name:         "Negative float comparison",
+			celExpr:      `negative_float < 0.0`,
+			expectedRows: 2, // -5.5 and -0.1
+			description:  "Negative floats should work correctly",
+		},
+		{
+			name:         "Arithmetic with zero",
+			celExpr:      `zero_int + 10 == 10`,
+			expectedRows: 2, // 0 + 10 = 10
+			description:  "Arithmetic with zero should work",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ast, issues := env.Compile(tt.celExpr)
+			if issues != nil && issues.Err() != nil {
+				t.Fatalf("CEL compilation failed: %v", issues.Err())
+			}
+
+			sqlCondition, err := cel2sql.Convert(ast, dialectOpt)
+			require.NoError(t, err, "Conversion should succeed for: %s", tt.description)
+
+			t.Logf("CEL Expression: %s", tt.celExpr)
+			t.Logf("Generated SQL WHERE clause: %s", sqlCondition)
+
+			// #nosec G202 - This is a test validating SQL generation, not a security risk
+			query := "SELECT COUNT(*) FROM edge_cases WHERE " + sqlCondition
+
+			var actualRows int
+			err = db.QueryRow(query).Scan(&actualRows)
+			require.NoError(t, err, "Generated SQL should execute successfully. %s",
+				tt.description)
+
+			require.Equal(t, tt.expectedRows, actualRows,
+				"Query should return expected number of rows. %s", tt.description)
+
+			t.Logf("OK: %s", tt.description)
+		})
+	}
+}
+
+// TestSQLiteJSONIntegration validates JSON operations against a real SQLite database.
+func TestSQLiteJSONIntegration(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	db, err := sql.Open("sqlite", ":memory:")
+	require.NoError(t, err)
+	defer func() {
+		if closeErr := db.Close(); closeErr != nil {
+			t.Logf("failed to close db: %v", closeErr)
+		}
+	}()
+
+	// Create test table with JSON column (stored as TEXT in SQLite)
+	_, err = db.Exec(`
+		CREATE TABLE products (
+			id INTEGER PRIMARY KEY,
+			name TEXT,
+			price REAL,
+			metadata TEXT
+		)
+	`)
+	require.NoError(t, err)
+
+	_, err = db.Exec(`
+		INSERT INTO products VALUES
+			(1, 'Widget', 19.99, '{"brand": "Acme", "color": "red", "specs": {"weight": 100}}'),
+			(2, 'Gadget', 29.99, '{"brand": "Beta", "color": "blue", "specs": {"weight": 200}}'),
+			(3, 'Doohickey', 39.99, '{"brand": "Acme", "color": "green", "specs": {"weight": 150}}')
+	`)
+	require.NoError(t, err)
+
+	// Set up CEL environment with schema for JSON detection.
+	// The pg schema types are used here purely to flag metadata as JSON.
+	productSchema := pg.NewSchema([]pg.FieldSchema{
+		{Name: "id", Type: "integer"},
+		{Name: "name", Type: "text"},
+		{Name: "price", Type: "double precision"},
+		{Name: "metadata", Type: "json", IsJSON: true},
+	})
+
+	schemas := map[string]pg.Schema{
+		"product": productSchema,
+	}
+
+	provider := pg.NewTypeProvider(schemas)
+
+	env, err := cel.NewEnv(
+		cel.CustomTypeProvider(provider),
+		cel.Variable("product", cel.ObjectType("product")),
+	)
+	require.NoError(t, err)
+
+	dialectOpt := cel2sql.WithDialect(sqliteDialect.New())
+	schemaOpt := cel2sql.WithSchemas(schemas)
+
+	tests := []struct {
+		name         string
+		celExpr      string
+		expectedRows int
+		description  string
+	}{
+		{
+			name:         "JSON field access",
+			celExpr:      `product.metadata.brand == "Acme"`,
+			expectedRows: 2,
+			description:  "JSON field access with json_extract",
+		},
+		{
+			name:         "JSON field access different value",
+			celExpr:      `product.metadata.color == "blue"`,
+			expectedRows: 1,
+			description:  "JSON field access with different value",
+		},
+		{
+			name:         "JSON nested field access",
+			celExpr:      `product.metadata.brand == "Acme" && product.price > 30.0`,
+			expectedRows: 1, // Doohickey (Acme, 39.99)
+			description:  "JSON field combined with regular field comparison",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ast, issues := env.Compile(tt.celExpr)
+			if issues != nil && issues.Err() != nil {
+				t.Fatalf("CEL compilation failed: %v", issues.Err())
+			}
+
+			sqlCondition, err := cel2sql.Convert(ast, dialectOpt, schemaOpt)
+			require.NoError(t, err, "Conversion should succeed for: %s", tt.description)
+
+			t.Logf("CEL Expression: %s", tt.celExpr)
+			t.Logf("Generated SQL WHERE clause: %s", sqlCondition)
+
+			// The table alias "product" matches the CEL variable name so the
+			// generated column references resolve.
+			// #nosec G202 - This is a test validating SQL generation, not a security risk
+			query := "SELECT COUNT(*) FROM products product WHERE " + sqlCondition
+			t.Logf("Full SQL Query: %s", query)
+
+			var actualRows int
+			err = db.QueryRow(query).Scan(&actualRows)
+			require.NoError(t, err, "Generated SQL should execute successfully. %s\nSQL: %s",
+				tt.description, sqlCondition)
+
+			require.Equal(t, tt.expectedRows, actualRows,
+				"Query should return expected number of rows. %s\nCEL: %s\nSQL: %s",
+				tt.description, tt.celExpr, sqlCondition)
+
+			t.Logf("OK: %s", tt.description)
+		})
+	}
+}
diff --git a/testcases/array_tests.go b/testcases/array_tests.go
new file mode 100644
index 0000000..3a92d35
--- /dev/null
+++ b/testcases/array_tests.go
@@ -0,0 +1,69 @@
+package testcases
+
+import "github.com/spandigital/cel2sql/v3/dialect"
+
+// ArrayTests returns test cases for array operations.
+func ArrayTests() []ConvertTestCase {
+	return []ConvertTestCase{
+		{
+			// CEL is 0-indexed; PostgreSQL/DuckDB arrays are 1-indexed, hence [1].
+			Name:     "list_index_literal",
+			CELExpr:  `[1, 2, 3][0] == 1`,
+			Category: CategoryArray,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "ARRAY[1, 2, 3][1] = 1",
+				dialect.DuckDB:     "[1, 2, 3][1] = 1",
+				dialect.BigQuery:   "[1, 2, 3][OFFSET(0)] = 1",
+			},
+		},
+		{
+			Name:     "list_var_index",
+			CELExpr:  `string_list[0] == "a"`,
+			Category: CategoryArray,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "string_list[1] = 'a'",
+				dialect.DuckDB:     "string_list[1] = 'a'",
+				dialect.BigQuery:   "string_list[OFFSET(0)] = 'a'",
+			},
+		},
+		{
+			Name:     "size_list",
+			CELExpr:  `size(string_list)`,
+			Category: CategoryArray,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "COALESCE(ARRAY_LENGTH(string_list, 1), 0)",
+				dialect.DuckDB:     "COALESCE(array_length(string_list), 0)",
+				dialect.BigQuery:   "ARRAY_LENGTH(string_list)",
+			},
+		},
+		{
+			Name:     "size_list_comparison",
+			CELExpr:  `size(string_list) > 0`,
+			Category: CategoryArray,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "COALESCE(ARRAY_LENGTH(string_list, 1), 0) > 0",
+				dialect.DuckDB:     "COALESCE(array_length(string_list), 0) > 0",
+				dialect.BigQuery:   "ARRAY_LENGTH(string_list) > 0",
+			},
+		},
+		{
+			// Out-of-range and negative indexes must be rejected at conversion time.
+			Name:     "array_index_overflow",
+			CELExpr:  `string_list[9223372036854775807]`,
+			Category: CategoryArray,
+			WantErr: map[dialect.Name]bool{
+				dialect.PostgreSQL: true,
+				dialect.DuckDB:     true,
+				dialect.BigQuery:   true,
+			},
+		},
+		{
+			Name:     "array_index_negative",
+			CELExpr:  `string_list[-1]`,
+			Category: CategoryArray,
+			WantErr: map[dialect.Name]bool{
+				dialect.PostgreSQL: true,
+				dialect.DuckDB:     true,
+				dialect.BigQuery:   true,
+			},
+		},
+	}
+}
diff --git a/testcases/basic_tests.go b/testcases/basic_tests.go
new file mode 100644
index 0000000..a04ed70
--- /dev/null
+++ b/testcases/basic_tests.go
@@ -0,0 +1,141 @@
+package testcases
+
+import "github.com/spandigital/cel2sql/v3/dialect"
+
+// BasicTests returns test cases for basic comparisons and expressions.
+func BasicTests() []ConvertTestCase {
+	return []ConvertTestCase{
+		{
+			Name:     "equality_string",
+			CELExpr:  `name == "a"`,
+			Category: CategoryBasic,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "name = 'a'",
+				dialect.MySQL:      "name = 'a'",
+				dialect.SQLite:     "name = 'a'",
+				dialect.DuckDB:     "name = 'a'",
+				dialect.BigQuery:   "name = 'a'",
+			},
+		},
+		{
+			Name:     "inequality_int",
+			CELExpr:  `age != 20`,
+			Category: CategoryBasic,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "age != 20",
+				dialect.MySQL:      "age != 20",
+				dialect.SQLite:     "age != 20",
+				dialect.DuckDB:     "age != 20",
+				dialect.BigQuery:   "age != 20",
+			},
+		},
+		{
+			Name:     "less_than",
+			CELExpr:  `age < 20`,
+			Category: CategoryBasic,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "age < 20",
+				dialect.MySQL:      "age < 20",
+				dialect.SQLite:     "age < 20",
+				dialect.DuckDB:     "age < 20",
+				dialect.BigQuery:   "age < 20",
+			},
+		},
+		{
+			Name:     "greater_equal_float",
+			CELExpr:  `height >= 1.6180339887`,
+			Category: CategoryBasic,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "height >= 1.6180339887",
+				dialect.MySQL:      "height >= 1.6180339887",
+				dialect.SQLite:     "height >= 1.6180339887",
+				dialect.DuckDB:     "height >= 1.6180339887",
+				dialect.BigQuery:   "height >= 1.6180339887",
+			},
+		},
+		{
+			Name:     "is_null",
+			CELExpr:  `null_var == null`,
+			Category: CategoryBasic,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "null_var IS NULL",
+				dialect.MySQL:      "null_var IS NULL",
+				dialect.SQLite:     "null_var IS NULL",
+				dialect.DuckDB:     "null_var IS NULL",
+				dialect.BigQuery:   "null_var IS NULL",
+			},
+		},
+		{
+			Name:     "is_not_true",
+			CELExpr:  `adult != true`,
+			Category: CategoryBasic,
+			WantSQL: map[dialect.Name]string{
+				dialect.PostgreSQL: "adult IS NOT TRUE",
+				dialect.MySQL:      "adult IS NOT TRUE",
+				dialect.SQLite:     "adult IS NOT TRUE",
+				dialect.DuckDB:     "adult IS NOT TRUE",
+				dialect.BigQuery:   "adult IS NOT TRUE",
+			},
+		},
+		{
+			Name: "not",
+ CELExpr: `!adult`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "NOT adult", + dialect.MySQL: "NOT adult", + dialect.SQLite: "NOT adult", + dialect.DuckDB: "NOT adult", + dialect.BigQuery: "NOT adult", + }, + }, + { + Name: "negative_int", + CELExpr: `-1`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "-1", + dialect.MySQL: "-1", + dialect.SQLite: "-1", + dialect.DuckDB: "-1", + dialect.BigQuery: "-1", + }, + }, + { + Name: "ternary", + CELExpr: `name == "a" ? "a" : "b"`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + dialect.MySQL: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + dialect.SQLite: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + dialect.DuckDB: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + dialect.BigQuery: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + }, + }, + { + Name: "field_select", + CELExpr: `page.title == "test"`, + Category: CategoryFieldAccess, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "page.title = 'test'", + dialect.MySQL: "page.title = 'test'", + dialect.SQLite: "page.title = 'test'", + dialect.DuckDB: "page.title = 'test'", + dialect.BigQuery: "page.title = 'test'", + }, + }, + { + Name: "in_list", + CELExpr: `name in ["a", "b", "c"]`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = ANY(ARRAY['a', 'b', 'c'])", + dialect.MySQL: "JSON_CONTAINS(JSON_ARRAY('a', 'b', 'c'), CAST(name AS JSON))", + dialect.SQLite: "name IN (SELECT value FROM json_each(json_array('a', 'b', 'c')))", + dialect.DuckDB: "name = ANY(['a', 'b', 'c'])", + dialect.BigQuery: "name IN UNNEST(['a', 'b', 'c'])", + }, + }, + } +} diff --git a/testcases/cast_tests.go b/testcases/cast_tests.go new file mode 100644 index 0000000..bcf85a5 --- /dev/null +++ b/testcases/cast_tests.go @@ -0,0 +1,81 @@ +package testcases + +import 
"github.com/spandigital/cel2sql/v3/dialect" + +// CastTests returns test cases for type casting operations. +func CastTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "cast_bool", + CELExpr: `bool(0) == false`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST(0 AS BOOLEAN) IS FALSE", + dialect.MySQL: "CAST(0 AS UNSIGNED) IS FALSE", + dialect.SQLite: "CAST(0 AS INTEGER) IS FALSE", + dialect.DuckDB: "CAST(0 AS BOOLEAN) IS FALSE", + dialect.BigQuery: "CAST(0 AS BOOL) IS FALSE", + }, + }, + { + Name: "cast_bytes", + CELExpr: `bytes("test")`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST('test' AS BYTEA)", + dialect.MySQL: "CAST('test' AS BINARY)", + dialect.SQLite: "CAST('test' AS BLOB)", + dialect.DuckDB: "CAST('test' AS BLOB)", + dialect.BigQuery: "CAST('test' AS BYTES)", + }, + }, + { + Name: "cast_int", + CELExpr: `int(true) == 1`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST(TRUE AS BIGINT) = 1", + dialect.MySQL: "CAST(TRUE AS SIGNED) = 1", + dialect.SQLite: "CAST(TRUE AS INTEGER) = 1", + dialect.DuckDB: "CAST(TRUE AS BIGINT) = 1", + dialect.BigQuery: "CAST(TRUE AS INT64) = 1", + }, + }, + { + Name: "cast_string", + CELExpr: `string(true) == "true"`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST(TRUE AS TEXT) = 'true'", + dialect.MySQL: "CAST(TRUE AS CHAR) = 'true'", + dialect.SQLite: "CAST(TRUE AS TEXT) = 'true'", + dialect.DuckDB: "CAST(TRUE AS VARCHAR) = 'true'", + dialect.BigQuery: "CAST(TRUE AS STRING) = 'true'", + }, + }, + { + Name: "cast_string_from_timestamp", + CELExpr: `string(created_at)`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST(created_at AS TEXT)", + dialect.MySQL: "CAST(created_at AS CHAR)", + dialect.SQLite: "CAST(created_at AS TEXT)", + dialect.DuckDB: "CAST(created_at AS VARCHAR)", + dialect.BigQuery: 
"CAST(created_at AS STRING)", + }, + }, + { + Name: "cast_int_epoch", + CELExpr: `int(created_at)`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(EPOCH FROM created_at)::bigint", + dialect.MySQL: "UNIX_TIMESTAMP(created_at)", + dialect.SQLite: "CAST(strftime('%s', created_at) AS INTEGER)", + dialect.DuckDB: "EXTRACT(EPOCH FROM created_at)::BIGINT", + dialect.BigQuery: "UNIX_SECONDS(created_at)", + }, + }, + } +} diff --git a/testcases/comprehension_tests.go b/testcases/comprehension_tests.go new file mode 100644 index 0000000..9eebd0d --- /dev/null +++ b/testcases/comprehension_tests.go @@ -0,0 +1,64 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// ComprehensionTests returns test cases for CEL comprehension operations. +func ComprehensionTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "all", + CELExpr: `string_list.all(x, x != "bad")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "NOT EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE NOT (x != 'bad'))", + dialect.SQLite: "NOT EXISTS (SELECT 1 FROM json_each(string_list) AS x WHERE NOT (x != 'bad'))", + dialect.DuckDB: "NOT EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE NOT (x != 'bad'))", + dialect.BigQuery: "NOT EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE NOT (x != 'bad'))", + }, + }, + { + Name: "exists", + CELExpr: `string_list.exists(x, x == "good")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE x = 'good')", + dialect.SQLite: "EXISTS (SELECT 1 FROM json_each(string_list) AS x WHERE x = 'good')", + dialect.DuckDB: "EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE x = 'good')", + dialect.BigQuery: "EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE x = 'good')", + }, + }, + { + Name: "exists_one", + CELExpr: `string_list.exists_one(x, x == 
"unique")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "(SELECT COUNT(*) FROM UNNEST(string_list) AS x WHERE x = 'unique') = 1", + dialect.SQLite: "(SELECT COUNT(*) FROM json_each(string_list) AS x WHERE x = 'unique') = 1", + dialect.DuckDB: "(SELECT COUNT(*) FROM UNNEST(string_list) AS x WHERE x = 'unique') = 1", + dialect.BigQuery: "(SELECT COUNT(*) FROM UNNEST(string_list) AS x WHERE x = 'unique') = 1", + }, + }, + { + Name: "filter", + CELExpr: `string_list.filter(x, x != "bad")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "ARRAY(SELECT x FROM UNNEST(string_list) AS x WHERE x != 'bad')", + dialect.SQLite: "(SELECT json_group_array(x) FROM json_each(string_list) AS x WHERE x != 'bad')", + dialect.DuckDB: "ARRAY(SELECT x FROM UNNEST(string_list) AS x WHERE x != 'bad')", + dialect.BigQuery: "ARRAY(SELECT x FROM UNNEST(string_list) AS x WHERE x != 'bad')", + }, + }, + { + Name: "map_transform", + CELExpr: `string_list.map(x, x + "_suffix")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "ARRAY(SELECT x || '_suffix' FROM UNNEST(string_list) AS x)", + dialect.SQLite: "(SELECT json_group_array(x || '_suffix') FROM json_each(string_list) AS x)", + dialect.DuckDB: "ARRAY(SELECT x || '_suffix' FROM UNNEST(string_list) AS x)", + dialect.BigQuery: "ARRAY(SELECT x || '_suffix' FROM UNNEST(string_list) AS x)", + }, + }, + } +} diff --git a/testcases/fixtures.go b/testcases/fixtures.go new file mode 100644 index 0000000..e33e904 --- /dev/null +++ b/testcases/fixtures.go @@ -0,0 +1,93 @@ +package testcases + +import ( + "github.com/spandigital/cel2sql/v3/pg" + "github.com/spandigital/cel2sql/v3/schema" +) + +// EnvDefault is the default environment setup name (basic types, no schema). +const EnvDefault = "" + +// EnvWithSchema is an environment with a schema-based type provider. 
+const EnvWithSchema = "schema" + +// EnvWithJSON is an environment with JSON/JSONB schema fields. +const EnvWithJSON = "json_schema" + +// EnvWithTimestamp is an environment for timestamp operations. +const EnvWithTimestamp = "timestamp" + +// NewPersonSchema returns a dialect-agnostic schema for the "person" table, +// suitable for basic, operator, and string tests. +func NewPersonSchema() schema.Schema { + return schema.NewSchema([]schema.FieldSchema{ + {Name: "name", Type: "text"}, + {Name: "age", Type: "integer"}, + {Name: "adult", Type: "boolean"}, + {Name: "height", Type: "double precision"}, + {Name: "email", Type: "text"}, + {Name: "tags", Type: "text", Repeated: true}, + {Name: "scores", Type: "integer", Repeated: true}, + }) +} + +// NewPersonPGSchema returns a PostgreSQL-specific schema for the "person" table. +func NewPersonPGSchema() pg.Schema { + return pg.NewSchema([]pg.FieldSchema{ + {Name: "name", Type: "text"}, + {Name: "age", Type: "integer"}, + {Name: "adult", Type: "boolean"}, + {Name: "height", Type: "double precision"}, + {Name: "email", Type: "text"}, + {Name: "tags", Type: "text", Repeated: true}, + {Name: "scores", Type: "integer", Repeated: true}, + }) +} + +// NewProductSchema returns a dialect-agnostic schema for the "product" table, +// with JSON/JSONB fields for JSON-related tests. +func NewProductSchema() schema.Schema { + return schema.NewSchema([]schema.FieldSchema{ + {Name: "name", Type: "text"}, + {Name: "price", Type: "double precision"}, + {Name: "metadata", Type: "jsonb", IsJSON: true, IsJSONB: true}, + {Name: "attributes", Type: "json", IsJSON: true}, + {Name: "tags", Type: "jsonb", IsJSON: true, IsJSONB: true, Repeated: true, ElementType: "text"}, + }) +} + +// NewProductPGSchema returns a PostgreSQL-specific schema for the "product" table. 
+func NewProductPGSchema() pg.Schema { + return pg.NewSchema([]pg.FieldSchema{ + {Name: "name", Type: "text"}, + {Name: "price", Type: "double precision"}, + {Name: "metadata", Type: "jsonb", IsJSON: true, IsJSONB: true}, + {Name: "attributes", Type: "json", IsJSON: true}, + {Name: "tags", Type: "jsonb", IsJSON: true, IsJSONB: true, Repeated: true, ElementType: "text"}, + }) +} + +// NewOrderSchema returns a dialect-agnostic schema for the "orders" table, +// with array and timestamp fields. +func NewOrderSchema() schema.Schema { + return schema.NewSchema([]schema.FieldSchema{ + {Name: "order_id", Type: "bigint"}, + {Name: "customer_name", Type: "text"}, + {Name: "total", Type: "double precision"}, + {Name: "items", Type: "text", Repeated: true}, + {Name: "created_at", Type: "timestamp with time zone"}, + {Name: "status", Type: "text"}, + }) +} + +// NewOrderPGSchema returns a PostgreSQL-specific schema for the "orders" table. +func NewOrderPGSchema() pg.Schema { + return pg.NewSchema([]pg.FieldSchema{ + {Name: "order_id", Type: "bigint"}, + {Name: "customer_name", Type: "text"}, + {Name: "total", Type: "double precision"}, + {Name: "items", Type: "text", Repeated: true}, + {Name: "created_at", Type: "timestamp with time zone"}, + {Name: "status", Type: "text"}, + }) +} diff --git a/testcases/json_tests.go b/testcases/json_tests.go new file mode 100644 index 0000000..7495a02 --- /dev/null +++ b/testcases/json_tests.go @@ -0,0 +1,49 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// JSONTests returns test cases for JSON/JSONB field access and operations. +// These tests require the "json_schema" environment setup. 
+func JSONTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "json_field_access", + CELExpr: `product.metadata.brand == "Acme"`, + Category: CategoryJSON, + EnvSetup: EnvWithJSON, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "product.metadata->>'brand' = 'Acme'", + dialect.MySQL: "product.metadata->>'$.brand' = 'Acme'", + dialect.SQLite: "json_extract(product.metadata, '$.brand') = 'Acme'", + dialect.DuckDB: "product.metadata->>'brand' = 'Acme'", + dialect.BigQuery: "JSON_VALUE(product.metadata, '$.brand') = 'Acme'", + }, + }, + { + Name: "json_nested_access", + CELExpr: `product.metadata.specs.color == "red"`, + Category: CategoryJSON, + EnvSetup: EnvWithJSON, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "product.metadata->'specs'->>'color' = 'red'", + dialect.MySQL: "product.metadata->'$.specs'->>'$.color' = 'red'", + dialect.SQLite: "json_extract(json_extract(product.metadata, '$.specs'), '$.color') = 'red'", + dialect.DuckDB: "product.metadata->'specs'->>'color' = 'red'", + dialect.BigQuery: "JSON_VALUE(JSON_QUERY(product.metadata, '$.specs'), '$.color') = 'red'", + }, + }, + { + Name: "json_has_field", + CELExpr: `has(product.metadata.brand)`, + Category: CategoryJSON, + EnvSetup: EnvWithJSON, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "product.metadata ? 'brand'", + dialect.MySQL: "JSON_CONTAINS_PATH(product.metadata, 'one', '$.brand')", + dialect.SQLite: "json_type(product.metadata, '$.brand') IS NOT NULL", + dialect.DuckDB: "json_exists(product.metadata, '$.brand')", + dialect.BigQuery: "JSON_VALUE(product.metadata, '$.brand') IS NOT NULL", + }, + }, + } +} diff --git a/testcases/operator_tests.go b/testcases/operator_tests.go new file mode 100644 index 0000000..c80b779 --- /dev/null +++ b/testcases/operator_tests.go @@ -0,0 +1,91 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// OperatorTests returns test cases for logical and arithmetic operators. 
+func OperatorTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "logical_and", + CELExpr: `name == "a" && age > 20`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = 'a' AND age > 20", + dialect.MySQL: "name = 'a' AND age > 20", + dialect.SQLite: "name = 'a' AND age > 20", + dialect.DuckDB: "name = 'a' AND age > 20", + dialect.BigQuery: "name = 'a' AND age > 20", + }, + }, + { + Name: "logical_or", + CELExpr: `name == "a" || age > 20`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = 'a' OR age > 20", + dialect.MySQL: "name = 'a' OR age > 20", + dialect.SQLite: "name = 'a' OR age > 20", + dialect.DuckDB: "name = 'a' OR age > 20", + dialect.BigQuery: "name = 'a' OR age > 20", + }, + }, + { + Name: "parenthesized", + CELExpr: `age >= 10 && (name == "a" || name == "b")`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "age >= 10 AND (name = 'a' OR name = 'b')", + dialect.MySQL: "age >= 10 AND (name = 'a' OR name = 'b')", + dialect.SQLite: "age >= 10 AND (name = 'a' OR name = 'b')", + dialect.DuckDB: "age >= 10 AND (name = 'a' OR name = 'b')", + dialect.BigQuery: "age >= 10 AND (name = 'a' OR name = 'b')", + }, + }, + { + Name: "addition", + CELExpr: `1 + 2 == 3`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "1 + 2 = 3", + dialect.MySQL: "1 + 2 = 3", + dialect.SQLite: "1 + 2 = 3", + dialect.DuckDB: "1 + 2 = 3", + dialect.BigQuery: "1 + 2 = 3", + }, + }, + { + Name: "modulo", + CELExpr: `5 % 3 == 2`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "MOD(5, 3) = 2", + dialect.MySQL: "MOD(5, 3) = 2", + dialect.SQLite: "MOD(5, 3) = 2", + dialect.DuckDB: "MOD(5, 3) = 2", + dialect.BigQuery: "MOD(5, 3) = 2", + }, + }, + { + Name: "string_concat", + CELExpr: `"a" + "b" == "ab"`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + 
dialect.PostgreSQL: "'a' || 'b' = 'ab'", + dialect.MySQL: "CONCAT('a', 'b') = 'ab'", + dialect.SQLite: "'a' || 'b' = 'ab'", + dialect.DuckDB: "'a' || 'b' = 'ab'", + dialect.BigQuery: "'a' || 'b' = 'ab'", + }, + }, + { + Name: "list_concat_in", + CELExpr: `1 in [1] + [2, 3]`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "1 = ANY(ARRAY[1] || ARRAY[2, 3])", + dialect.DuckDB: "1 = ANY([1] || [2, 3])", + dialect.BigQuery: "1 IN UNNEST([1] || [2, 3])", + }, + }, + } +} diff --git a/testcases/parameterized_tests.go b/testcases/parameterized_tests.go new file mode 100644 index 0000000..88441ba --- /dev/null +++ b/testcases/parameterized_tests.go @@ -0,0 +1,111 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// ParameterizedTests returns test cases for parameterized SQL conversion. +func ParameterizedTests() []ParameterizedTestCase { + return []ParameterizedTestCase{ + { + Name: "simple_string_equality", + CELExpr: `name == "John"`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = $1", + dialect.SQLite: "name = ?", + dialect.DuckDB: "name = $1", + dialect.BigQuery: "name = @p1", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {"John"}, + dialect.SQLite: {"John"}, + dialect.DuckDB: {"John"}, + dialect.BigQuery: {"John"}, + }, + }, + { + Name: "multiple_string_params", + CELExpr: `name == "John" && name != "Jane"`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = $1 AND name != $2", + dialect.SQLite: "name = ? 
AND name != ?", + dialect.DuckDB: "name = $1 AND name != $2", + dialect.BigQuery: "name = @p1 AND name != @p2", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {"John", "Jane"}, + dialect.SQLite: {"John", "Jane"}, + dialect.DuckDB: {"John", "Jane"}, + dialect.BigQuery: {"John", "Jane"}, + }, + }, + { + Name: "integer_equality", + CELExpr: `age == 18`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "age = $1", + dialect.SQLite: "age = ?", + dialect.DuckDB: "age = $1", + dialect.BigQuery: "age = @p1", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {int64(18)}, + dialect.SQLite: {int64(18)}, + dialect.DuckDB: {int64(18)}, + dialect.BigQuery: {int64(18)}, + }, + }, + { + Name: "integer_range", + CELExpr: `age > 21 && age < 65`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "age > $1 AND age < $2", + dialect.SQLite: "age > ? AND age < ?", + dialect.DuckDB: "age > $1 AND age < $2", + dialect.BigQuery: "age > @p1 AND age < @p2", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {int64(21), int64(65)}, + dialect.SQLite: {int64(21), int64(65)}, + dialect.DuckDB: {int64(21), int64(65)}, + dialect.BigQuery: {int64(21), int64(65)}, + }, + }, + { + Name: "double_equality", + CELExpr: `salary == 50000.50`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "salary = $1", + dialect.SQLite: "salary = ?", + dialect.DuckDB: "salary = $1", + dialect.BigQuery: "salary = @p1", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {50000.50}, + dialect.SQLite: {50000.50}, + dialect.DuckDB: {50000.50}, + dialect.BigQuery: {50000.50}, + }, + }, + { + Name: "boolean_true_inline", + CELExpr: `active == true`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "active IS TRUE", + dialect.SQLite: "active IS TRUE", + dialect.DuckDB: "active IS TRUE", + 
dialect.BigQuery: "active IS TRUE", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {}, + dialect.SQLite: {}, + dialect.DuckDB: {}, + dialect.BigQuery: {}, + }, + }, + } +} diff --git a/testcases/regex_tests.go b/testcases/regex_tests.go new file mode 100644 index 0000000..aeded0e --- /dev/null +++ b/testcases/regex_tests.go @@ -0,0 +1,93 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// RegexTests returns test cases for regex pattern matching. +func RegexTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "simple_match", + CELExpr: `name.matches("a+")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ 'a+'", + dialect.MySQL: "name REGEXP 'a+'", + dialect.DuckDB: "name ~ 'a+'", + dialect.BigQuery: "REGEXP_CONTAINS(name, 'a+')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "function_style", + CELExpr: `matches(name, "^[0-9]+$")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ '^[0-9]+$'", + dialect.MySQL: "name REGEXP '^[0-9]+$'", + dialect.DuckDB: "name ~ '^[0-9]+$'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '^[0-9]+$')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "word_boundary", + CELExpr: `name.matches("\\btest\\b")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ '\\ytest\\y'", + dialect.MySQL: "name REGEXP '\\btest\\b'", + dialect.DuckDB: "name ~ '\\btest\\b'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '\\btest\\b')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "digit_class", + CELExpr: `name.matches("\\d{3}-\\d{4}")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ 
'[[:digit:]]{3}-[[:digit:]]{4}'", + dialect.MySQL: "name REGEXP '\\d{3}-\\d{4}'", + dialect.DuckDB: "name ~ '\\d{3}-\\d{4}'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '\\d{3}-\\d{4}')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "word_class", + CELExpr: `name.matches("\\w+@\\w+\\.\\w+")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ '[[:alnum:]_]+@[[:alnum:]_]+\\.[[:alnum:]_]+'", + dialect.MySQL: "name REGEXP '\\w+@\\w+\\.\\w+'", + dialect.DuckDB: "name ~ '\\w+@\\w+\\.\\w+'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '\\w+@\\w+\\.\\w+')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "complex_pattern", + CELExpr: `name.matches(".*pattern.*")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ '.*pattern.*'", + dialect.MySQL: "name REGEXP '.*pattern.*'", + dialect.DuckDB: "name ~ '.*pattern.*'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '.*pattern.*')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + } +} diff --git a/testcases/string_tests.go b/testcases/string_tests.go new file mode 100644 index 0000000..b7ec52d --- /dev/null +++ b/testcases/string_tests.go @@ -0,0 +1,69 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// StringTests returns test cases for string functions. 
+func StringTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "starts_with", + CELExpr: `name.startsWith("a")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name LIKE 'a%' ESCAPE E'\\\\'", + dialect.MySQL: "name LIKE 'a%' ESCAPE '\\\\'", + dialect.SQLite: "name LIKE 'a%' ESCAPE '\\'", + dialect.DuckDB: "name LIKE 'a%' ESCAPE '\\\\'", + dialect.BigQuery: "name LIKE 'a%'", + }, + }, + { + Name: "ends_with", + CELExpr: `name.endsWith("z")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name LIKE '%z' ESCAPE E'\\\\'", + dialect.MySQL: "name LIKE '%z' ESCAPE '\\\\'", + dialect.SQLite: "name LIKE '%z' ESCAPE '\\'", + dialect.DuckDB: "name LIKE '%z' ESCAPE '\\\\'", + dialect.BigQuery: "name LIKE '%z'", + }, + }, + { + Name: "contains", + CELExpr: `name.contains("abc")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "POSITION('abc' IN name) > 0", + dialect.MySQL: "LOCATE('abc', name) > 0", + dialect.SQLite: "INSTR(name, 'abc') > 0", + dialect.DuckDB: "CONTAINS(name, 'abc')", + dialect.BigQuery: "STRPOS(name, 'abc') > 0", + }, + }, + { + Name: "size_string", + CELExpr: `size("test")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "LENGTH('test')", + dialect.MySQL: "LENGTH('test')", + dialect.SQLite: "LENGTH('test')", + dialect.DuckDB: "LENGTH('test')", + dialect.BigQuery: "LENGTH('test')", + }, + }, + { + Name: "starts_with_and_ends_with", + CELExpr: `name.startsWith("a") && name.endsWith("z")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name LIKE 'a%' ESCAPE E'\\\\' AND name LIKE '%z' ESCAPE E'\\\\'", + dialect.MySQL: "name LIKE 'a%' ESCAPE '\\\\' AND name LIKE '%z' ESCAPE '\\\\'", + dialect.SQLite: "name LIKE 'a%' ESCAPE '\\' AND name LIKE '%z' ESCAPE '\\'", + dialect.DuckDB: "name LIKE 'a%' ESCAPE '\\\\' AND name LIKE '%z' ESCAPE '\\\\'", + 
dialect.BigQuery: "name LIKE 'a%' AND name LIKE '%z'", + }, + }, + } +} diff --git a/testcases/testcases.go b/testcases/testcases.go new file mode 100644 index 0000000..1de6a41 --- /dev/null +++ b/testcases/testcases.go @@ -0,0 +1,98 @@ +// Package testcases defines shared test case types and helpers for multi-dialect testing. +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// Category classifies a test case for organization and selective running. +type Category string + +// Test case categories. +const ( + CategoryBasic Category = "basic" + CategoryOperator Category = "operator" + CategoryString Category = "string" + CategoryRegex Category = "regex" + CategoryJSON Category = "json" + CategoryArray Category = "array" + CategoryComprehension Category = "comprehension" + CategoryTimestamp Category = "timestamp" + CategoryParameterized Category = "parameterized" + CategoryCast Category = "cast" + CategoryFieldAccess Category = "field_access" +) + +// ConvertTestCase defines a single CEL-to-SQL conversion test case +// with expected output per dialect. +type ConvertTestCase struct { + // Name is the test case name (used for t.Run). + Name string + + // CELExpr is the CEL expression source to compile and convert. + CELExpr string + + // Category classifies the test case. + Category Category + + // EnvSetup identifies which CEL environment setup to use. + // Empty string means "default" (basic types, no schema). + EnvSetup string + + // WantSQL maps dialect name to expected SQL output. + // If a dialect is absent, the test is skipped for that dialect. + WantSQL map[dialect.Name]string + + // WantErr maps dialect name to whether an error is expected. + // If a dialect is absent, no error is expected. + WantErr map[dialect.Name]bool + + // SkipDialect maps dialect name to a skip reason. + // If a dialect is present, the test is skipped with the given message. 
+ SkipDialect map[dialect.Name]string +} + +// ForDialect returns the expected SQL for a given dialect, and whether the +// test case has an expectation for that dialect. +func (tc *ConvertTestCase) ForDialect(d dialect.Name) (sql string, hasExpected bool) { + sql, hasExpected = tc.WantSQL[d] + return +} + +// ShouldError returns whether an error is expected for the given dialect. +func (tc *ConvertTestCase) ShouldError(d dialect.Name) bool { + return tc.WantErr[d] +} + +// ShouldSkip returns the skip reason for a dialect, or empty string if not skipped. +func (tc *ConvertTestCase) ShouldSkip(d dialect.Name) string { + if tc.SkipDialect == nil { + return "" + } + return tc.SkipDialect[d] +} + +// ParameterizedTestCase defines a test case for parameterized SQL conversion. +type ParameterizedTestCase struct { + // Name is the test case name. + Name string + + // CELExpr is the CEL expression source. + CELExpr string + + // Category classifies the test case. + Category Category + + // EnvSetup identifies which CEL environment setup to use. + EnvSetup string + + // WantSQL maps dialect name to expected parameterized SQL output. + WantSQL map[dialect.Name]string + + // WantParams maps dialect name to expected parameter values. + WantParams map[dialect.Name][]any + + // WantErr maps dialect name to whether an error is expected. + WantErr map[dialect.Name]bool + + // SkipDialect maps dialect name to a skip reason. + SkipDialect map[dialect.Name]string +} diff --git a/testcases/timestamp_tests.go b/testcases/timestamp_tests.go new file mode 100644 index 0000000..81fe7d3 --- /dev/null +++ b/testcases/timestamp_tests.go @@ -0,0 +1,120 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// TimestampTests returns test cases for timestamp and duration operations. 
+func TimestampTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "duration_second", + CELExpr: `duration("10s")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "INTERVAL 10 SECOND", + dialect.MySQL: "INTERVAL 10 SECOND", + dialect.SQLite: "'+10 seconds'", + dialect.DuckDB: "INTERVAL 10 SECOND", + dialect.BigQuery: "INTERVAL 10 SECOND", + }, + }, + { + Name: "duration_minute", + CELExpr: `duration("1h1m")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "INTERVAL 61 MINUTE", + dialect.MySQL: "INTERVAL 61 MINUTE", + dialect.SQLite: "'+61 minutes'", + dialect.DuckDB: "INTERVAL 61 MINUTE", + dialect.BigQuery: "INTERVAL 61 MINUTE", + }, + }, + { + Name: "duration_hour", + CELExpr: `duration("60m")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "INTERVAL 1 HOUR", + dialect.MySQL: "INTERVAL 1 HOUR", + dialect.SQLite: "'+1 hours'", + dialect.DuckDB: "INTERVAL 1 HOUR", + dialect.BigQuery: "INTERVAL 1 HOUR", + }, + }, + { + Name: "timestamp_getSeconds", + CELExpr: `created_at.getSeconds()`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(SECOND FROM created_at)", + dialect.MySQL: "EXTRACT(SECOND FROM created_at)", + dialect.SQLite: "CAST(strftime('%S', created_at) AS INTEGER)", + dialect.DuckDB: "EXTRACT(SECOND FROM created_at)", + dialect.BigQuery: "EXTRACT(SECOND FROM created_at)", + }, + }, + { + Name: "timestamp_getHours_withTimezone", + CELExpr: `created_at.getHours("Asia/Tokyo")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(HOUR FROM created_at AT TIME ZONE 'Asia/Tokyo')", + dialect.MySQL: "EXTRACT(HOUR FROM created_at AT TIME ZONE 'Asia/Tokyo')", + dialect.SQLite: "CAST(strftime('%H', created_at) AS INTEGER)", + dialect.DuckDB: "EXTRACT(HOUR FROM created_at AT TIME ZONE 'Asia/Tokyo')", + dialect.BigQuery: "EXTRACT(HOUR 
FROM created_at AT TIME ZONE 'Asia/Tokyo')", + }, + }, + { + Name: "timestamp_sub_duration", + CELExpr: `created_at - duration("60m") <= timestamp("2021-09-01T18:00:00Z")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "created_at - INTERVAL 1 HOUR <= CAST('2021-09-01T18:00:00Z' AS TIMESTAMP WITH TIME ZONE)", + dialect.MySQL: "created_at - INTERVAL 1 HOUR <= CAST('2021-09-01T18:00:00Z' AS DATETIME)", + dialect.SQLite: "datetime(created_at, '-'||'+1 hours') <= datetime('2021-09-01T18:00:00Z')", + dialect.DuckDB: "created_at - INTERVAL 1 HOUR <= CAST('2021-09-01T18:00:00Z' AS TIMESTAMPTZ)", + dialect.BigQuery: "TIMESTAMP_SUB(created_at, INTERVAL 1 HOUR) <= CAST('2021-09-01T18:00:00Z' AS TIMESTAMP)", + }, + }, + { + Name: "interval_month", + CELExpr: `interval(1, MONTH)`, + Category: CategoryTimestamp, + EnvSetup: EnvWithTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "INTERVAL 1 MONTH", + dialect.MySQL: "INTERVAL 1 MONTH", + dialect.SQLite: "'+'||1||' months'", + dialect.DuckDB: "INTERVAL 1 MONTH", + dialect.BigQuery: "INTERVAL 1 MONTH", + }, + }, + { + Name: "date_getFullYear", + CELExpr: `birthday.getFullYear()`, + Category: CategoryTimestamp, + EnvSetup: EnvWithTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(YEAR FROM birthday)", + dialect.MySQL: "EXTRACT(YEAR FROM birthday)", + dialect.SQLite: "CAST(strftime('%Y', birthday) AS INTEGER)", + dialect.DuckDB: "EXTRACT(YEAR FROM birthday)", + dialect.BigQuery: "EXTRACT(YEAR FROM birthday)", + }, + }, + { + Name: "datetime_getMonth", + CELExpr: `scheduled_at.getMonth()`, + Category: CategoryTimestamp, + EnvSetup: EnvWithTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(MONTH FROM scheduled_at) - 1", + dialect.MySQL: "EXTRACT(MONTH FROM scheduled_at) - 1", + dialect.SQLite: "CAST(strftime('%m', scheduled_at) AS INTEGER) - 1", + dialect.DuckDB: "EXTRACT(MONTH FROM scheduled_at) - 1", + 
dialect.BigQuery: "EXTRACT(MONTH FROM scheduled_at) - 1", + }, + }, + } +} diff --git a/testdata/bigquery_seed.yaml b/testdata/bigquery_seed.yaml new file mode 100644 index 0000000..4375cbf --- /dev/null +++ b/testdata/bigquery_seed.yaml @@ -0,0 +1,78 @@ +projects: + - id: test-project + datasets: + - id: testdataset + tables: + - id: test_data + columns: + - name: id + type: INT64 + - name: text_val + type: STRING + - name: int_val + type: INT64 + - name: float_val + type: FLOAT64 + - name: bool_val + type: BOOL + - name: nullable_text + type: STRING + - name: nullable_int + type: INT64 + data: + - id: 1 + text_val: "hello" + int_val: 10 + float_val: 10.5 + bool_val: true + nullable_text: "present" + nullable_int: 100 + - id: 2 + text_val: "world" + int_val: 20 + float_val: 20.5 + bool_val: false + - id: 3 + text_val: "test" + int_val: 30 + float_val: 30.5 + bool_val: true + nullable_text: "here" + nullable_int: 200 + - id: 4 + text_val: "hello world" + int_val: 5 + float_val: 5.5 + bool_val: false + nullable_text: "value" + nullable_int: 50 + - id: 5 + text_val: "testing" + int_val: 15 + float_val: 15.5 + bool_val: true + nullable_text: "test" + nullable_int: 150 + - id: products + columns: + - name: id + type: INT64 + - name: name + type: STRING + - name: price + type: FLOAT64 + - name: metadata + type: STRING + data: + - id: 1 + name: "Widget" + price: 19.99 + metadata: '{"brand": "Acme", "color": "red", "specs": {"weight": 100}}' + - id: 2 + name: "Gadget" + price: 29.99 + metadata: '{"brand": "Beta", "color": "blue", "specs": {"weight": 200}}' + - id: 3 + name: "Doohickey" + price: 39.99 + metadata: '{"brand": "Acme", "color": "green", "specs": {"weight": 150}}' diff --git a/testutil/env.go b/testutil/env.go new file mode 100644 index 0000000..1239a86 --- /dev/null +++ b/testutil/env.go @@ -0,0 +1,310 @@ +package testutil + +import ( + "fmt" + + "github.com/google/cel-go/cel" + + "github.com/spandigital/cel2sql/v3" + dialectpkg 
"github.com/spandigital/cel2sql/v3/dialect" + bigqueryDialect "github.com/spandigital/cel2sql/v3/dialect/bigquery" + duckdbDialect "github.com/spandigital/cel2sql/v3/dialect/duckdb" + mysqlDialect "github.com/spandigital/cel2sql/v3/dialect/mysql" + sqliteDialect "github.com/spandigital/cel2sql/v3/dialect/sqlite" + "github.com/spandigital/cel2sql/v3/pg" + "github.com/spandigital/cel2sql/v3/sqltypes" + "github.com/spandigital/cel2sql/v3/testcases" +) + +// EnvResult holds both the CEL environment and convert options needed for testing. +type EnvResult struct { + Env *cel.Env + Opts []cel2sql.ConvertOption +} + +// NewDefaultEnv creates a basic CEL environment with standard variable types. +func NewDefaultEnv() (*EnvResult, error) { + env, err := cel.NewEnv( + cel.Types( + sqltypes.Date, sqltypes.Time, sqltypes.DateTime, sqltypes.Interval, sqltypes.DatePart, + ), + cel.Variable("name", cel.StringType), + cel.Variable("age", cel.IntType), + cel.Variable("adult", cel.BoolType), + cel.Variable("height", cel.DoubleType), + cel.Variable("string_list", cel.ListType(cel.StringType)), + cel.Variable("string_int_map", cel.MapType(cel.StringType, cel.IntType)), + cel.Variable("null_var", cel.NullType), + cel.Variable("created_at", cel.TimestampType), + cel.Variable("page", cel.MapType(cel.StringType, cel.StringType)), + cel.Variable("salary", cel.DoubleType), + cel.Variable("active", cel.BoolType), + cel.Variable("data", cel.BytesType), + cel.Variable("tags", cel.ListType(cel.StringType)), + cel.Variable("scores", cel.ListType(cel.IntType)), + // Cast functions + cel.Function("bool", cel.Overload("bool_from_int", []*cel.Type{cel.IntType}, cel.BoolType)), + cel.Function("int", cel.Overload("int_from_bool", []*cel.Type{cel.BoolType}, cel.IntType)), + ) + if err != nil { + return nil, err + } + return &EnvResult{Env: env}, nil +} + +// NewTimestampEnv creates a CEL environment with timestamp-related types and functions. 
+func NewTimestampEnv() (*EnvResult, error) { + env, err := cel.NewEnv( + cel.Types( + sqltypes.Date, sqltypes.Time, sqltypes.DateTime, sqltypes.Interval, sqltypes.DatePart, + ), + cel.Variable("name", cel.StringType), + cel.Variable("age", cel.IntType), + cel.Variable("adult", cel.BoolType), + cel.Variable("height", cel.DoubleType), + cel.Variable("string_list", cel.ListType(cel.StringType)), + cel.Variable("string_int_map", cel.MapType(cel.StringType, cel.IntType)), + cel.Variable("null_var", cel.NullType), + cel.Variable("birthday", cel.ObjectType("DATE")), + cel.Variable("fixed_time", cel.ObjectType("TIME")), + cel.Variable("scheduled_at", cel.ObjectType("DATETIME")), + cel.Variable("created_at", cel.TimestampType), + cel.Variable("page", cel.MapType(cel.StringType, cel.StringType)), + // Date part constants + cel.Variable("YEAR", cel.ObjectType("date_part")), + cel.Variable("MONTH", cel.ObjectType("date_part")), + cel.Variable("DAY", cel.ObjectType("date_part")), + cel.Variable("HOUR", cel.ObjectType("date_part")), + cel.Variable("MINUTE", cel.ObjectType("date_part")), + cel.Variable("SECOND", cel.ObjectType("date_part")), + // SQL functions + cel.Function("date", + cel.Overload("date_string", []*cel.Type{cel.StringType}, cel.ObjectType("DATE")), + cel.Overload("date_int_int_int", []*cel.Type{cel.IntType, cel.IntType, cel.IntType}, cel.ObjectType("DATE"))), + cel.Function("time", cel.Overload("time_string", []*cel.Type{cel.StringType}, cel.ObjectType("TIME"))), + cel.Function("datetime", + cel.Overload("datetime_string", []*cel.Type{cel.StringType}, cel.ObjectType("DATETIME")), + cel.Overload("datetime_date_time", []*cel.Type{cel.ObjectType("DATE"), cel.ObjectType("TIME")}, cel.ObjectType("DATETIME"))), + cel.Function("timestamp", + cel.Overload("timestamp_datetime_string", []*cel.Type{cel.ObjectType("DATETIME"), cel.StringType}, cel.TimestampType)), + cel.Function("interval", cel.Overload("interval_int_datepart", []*cel.Type{cel.IntType, 
cel.ObjectType("date_part")}, cel.ObjectType("INTERVAL"))), + cel.Function("current_date", cel.Overload("current_date", []*cel.Type{}, cel.ObjectType("DATE"))), + cel.Function("current_datetime", cel.Overload("current_datetime_string", []*cel.Type{cel.StringType}, cel.ObjectType("DATETIME"))), + // Date/Time arithmetic operators + cel.Function("_+_", + cel.Overload("date_add_interval", []*cel.Type{cel.ObjectType("DATE"), cel.ObjectType("INTERVAL")}, cel.ObjectType("DATE")), + cel.Overload("date_add_int", []*cel.Type{cel.ObjectType("DATE"), cel.IntType}, cel.ObjectType("DATE")), + cel.Overload("time_add_interval", []*cel.Type{cel.ObjectType("TIME"), cel.ObjectType("INTERVAL")}, cel.ObjectType("TIME")), + cel.Overload("datetime_add_interval", []*cel.Type{cel.ObjectType("DATETIME"), cel.ObjectType("INTERVAL")}, cel.ObjectType("DATETIME")), + cel.Overload("timestamp_add_interval", []*cel.Type{cel.TimestampType, cel.ObjectType("INTERVAL")}, cel.TimestampType)), + cel.Function("_-_", + cel.Overload("date_sub_interval", []*cel.Type{cel.ObjectType("DATE"), cel.ObjectType("INTERVAL")}, cel.ObjectType("DATE")), + cel.Overload("time_sub_interval", []*cel.Type{cel.ObjectType("TIME"), cel.ObjectType("INTERVAL")}, cel.ObjectType("TIME")), + cel.Overload("datetime_sub_interval", []*cel.Type{cel.ObjectType("DATETIME"), cel.ObjectType("INTERVAL")}, cel.ObjectType("DATETIME")), + cel.Overload("timestamp_sub_interval", []*cel.Type{cel.TimestampType, cel.ObjectType("INTERVAL")}, cel.TimestampType)), + // Date/Time comparison operators + cel.Function("_>_", + cel.Overload("date_gt_date", []*cel.Type{cel.ObjectType("DATE"), cel.ObjectType("DATE")}, cel.BoolType)), + // Date/Time methods + cel.Function("getFullYear", cel.MemberOverload("date_getFullYear", []*cel.Type{cel.ObjectType("DATE")}, cel.IntType)), + cel.Function("getMonth", cel.MemberOverload("datetime_getMonth", []*cel.Type{cel.ObjectType("DATETIME")}, cel.IntType)), + cel.Function("getDayOfMonth", 
cel.MemberOverload("datetime_getDayOfMonth", []*cel.Type{cel.ObjectType("DATETIME")}, cel.IntType)), + cel.Function("getMinutes", cel.MemberOverload("time_getMinutes", []*cel.Type{cel.ObjectType("TIME")}, cel.IntType)), + // Cast functions + cel.Function("bool", cel.Overload("bool_from_int", []*cel.Type{cel.IntType}, cel.BoolType)), + cel.Function("int", cel.Overload("int_from_bool", []*cel.Type{cel.BoolType}, cel.IntType)), + ) + if err != nil { + return nil, err + } + return &EnvResult{Env: env}, nil +} + +// NewJSONSchemaEnv creates a CEL environment with a JSON-enabled schema type provider. +func NewJSONSchemaEnv() (*EnvResult, error) { + productSchema := testcases.NewProductPGSchema() + schemas := map[string]pg.Schema{ + "product": productSchema, + } + provider := pg.NewTypeProvider(schemas) + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("product", cel.ObjectType("product")), + ) + if err != nil { + return nil, err + } + return &EnvResult{ + Env: env, + Opts: []cel2sql.ConvertOption{cel2sql.WithSchemas(schemas)}, + }, nil +} + +// PostgreSQLEnvFactory returns an environment factory for PostgreSQL tests. +func PostgreSQLEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + return NewDefaultEnv() + case testcases.EnvWithTimestamp: + return NewTimestampEnv() + case testcases.EnvWithJSON: + return NewJSONSchemaEnv() + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// MySQLEnvFactory returns an environment factory for MySQL tests. +// It uses the same CEL environments as PostgreSQL (CEL compilation is dialect-independent) +// but sets the MySQL dialect for SQL generation. 
+func MySQLEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + result, err := NewDefaultEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(mysqlDialect.New())) + return result, nil + case testcases.EnvWithTimestamp: + result, err := NewTimestampEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(mysqlDialect.New())) + return result, nil + case testcases.EnvWithJSON: + result, err := NewJSONSchemaEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(mysqlDialect.New())) + return result, nil + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// SQLiteEnvFactory returns an environment factory for SQLite tests. +// It uses the same CEL environments as PostgreSQL (CEL compilation is dialect-independent) +// but sets the SQLite dialect for SQL generation. +func SQLiteEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + result, err := NewDefaultEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(sqliteDialect.New())) + return result, nil + case testcases.EnvWithTimestamp: + result, err := NewTimestampEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(sqliteDialect.New())) + return result, nil + case testcases.EnvWithJSON: + result, err := NewJSONSchemaEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(sqliteDialect.New())) + return result, nil + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// DuckDBEnvFactory returns an environment factory for DuckDB tests. 
+// It uses the same CEL environments as PostgreSQL (CEL compilation is dialect-independent) +// but sets the DuckDB dialect for SQL generation. +func DuckDBEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + result, err := NewDefaultEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(duckdbDialect.New())) + return result, nil + case testcases.EnvWithTimestamp: + result, err := NewTimestampEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(duckdbDialect.New())) + return result, nil + case testcases.EnvWithJSON: + result, err := NewJSONSchemaEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(duckdbDialect.New())) + return result, nil + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// BigQueryEnvFactory returns an environment factory for BigQuery tests. +// It uses the same CEL environments as PostgreSQL (CEL compilation is dialect-independent) +// but sets the BigQuery dialect for SQL generation. 
+func BigQueryEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + result, err := NewDefaultEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(bigqueryDialect.New())) + return result, nil + case testcases.EnvWithTimestamp: + result, err := NewTimestampEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(bigqueryDialect.New())) + return result, nil + case testcases.EnvWithJSON: + result, err := NewJSONSchemaEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(bigqueryDialect.New())) + return result, nil + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// DialectEnvFactory returns an environment factory for the given dialect. +// This is a convenience function that maps dialect names to their env factories. +func DialectEnvFactory(d dialectpkg.Name) func(envSetup string) (*EnvResult, error) { + switch d { + case dialectpkg.PostgreSQL: + return PostgreSQLEnvFactory() + case dialectpkg.MySQL: + return MySQLEnvFactory() + case dialectpkg.SQLite: + return SQLiteEnvFactory() + case dialectpkg.DuckDB: + return DuckDBEnvFactory() + case dialectpkg.BigQuery: + return BigQueryEnvFactory() + default: + return func(_ string) (*EnvResult, error) { + return nil, fmt.Errorf("no environment factory for dialect %s", d) + } + } +} diff --git a/testutil/runner.go b/testutil/runner.go new file mode 100644 index 0000000..74fb8a2 --- /dev/null +++ b/testutil/runner.go @@ -0,0 +1,180 @@ +// Package testutil provides multi-dialect test runners and helpers. 
+package testutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/spandigital/cel2sql/v3" + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testcases" +) + +// RunConvertTests runs a set of ConvertTestCase entries for a given dialect. +// envFactory returns an EnvResult (CEL env + convert options) for the given EnvSetup key. +// Additional opts are appended after any env-specific options. +func RunConvertTests( + t *testing.T, + dialectName dialect.Name, + cases []testcases.ConvertTestCase, + envFactory func(envSetup string) (*EnvResult, error), + opts ...cel2sql.ConvertOption, +) { + t.Helper() + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + // Check skip + if reason := tc.ShouldSkip(dialectName); reason != "" { + t.Skip(reason) + } + + // Check if we have an expectation for this dialect + wantSQL, hasExpected := tc.ForDialect(dialectName) + wantErr := tc.ShouldError(dialectName) + + if !hasExpected && !wantErr { + t.Skipf("no expected SQL for dialect %s", dialectName) + } + + // Build CEL environment + envResult, err := envFactory(tc.EnvSetup) + require.NoError(t, err, "failed to create CEL environment") + + // Compile CEL expression + ast, issues := envResult.Env.Compile(tc.CELExpr) + if issues != nil && issues.Err() != nil { + if wantErr { + return // expected compile error + } + t.Fatalf("CEL compile failed: %v", issues.Err()) + } + + // Merge options: env-specific first, then caller-provided + allOpts := make([]cel2sql.ConvertOption, 0, len(envResult.Opts)+len(opts)) + allOpts = append(allOpts, envResult.Opts...) + allOpts = append(allOpts, opts...) + + // Convert + got, err := cel2sql.Convert(ast, allOpts...) 
+ if wantErr { + assert.Error(t, err, "expected error for dialect %s", dialectName) + return + } + + if assert.NoError(t, err) { + assert.Equal(t, wantSQL, got, "SQL mismatch for dialect %s", dialectName) + } + }) + } +} + +// RunParameterizedTests runs a set of ParameterizedTestCase entries for a given dialect. +func RunParameterizedTests( + t *testing.T, + dialectName dialect.Name, + cases []testcases.ParameterizedTestCase, + envFactory func(envSetup string) (*EnvResult, error), + opts ...cel2sql.ConvertOption, +) { + t.Helper() + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + // Check skip + if tc.SkipDialect != nil { + if reason, ok := tc.SkipDialect[dialectName]; ok && reason != "" { + t.Skip(reason) + } + } + + wantSQL, hasExpected := tc.WantSQL[dialectName] + wantErr := tc.WantErr[dialectName] + + if !hasExpected && !wantErr { + t.Skipf("no expected SQL for dialect %s", dialectName) + } + + // Build CEL environment + envResult, err := envFactory(tc.EnvSetup) + require.NoError(t, err, "failed to create CEL environment") + + // Compile CEL expression + ast, issues := envResult.Env.Compile(tc.CELExpr) + if issues != nil && issues.Err() != nil { + if wantErr { + return + } + t.Fatalf("CEL compile failed: %v", issues.Err()) + } + + // Merge options + allOpts := make([]cel2sql.ConvertOption, 0, len(envResult.Opts)+len(opts)) + allOpts = append(allOpts, envResult.Opts...) + allOpts = append(allOpts, opts...) + + // Convert + result, err := cel2sql.ConvertParameterized(ast, allOpts...) + if wantErr { + assert.Error(t, err) + return + } + + if assert.NoError(t, err) { + assert.Equal(t, wantSQL, result.SQL, "SQL mismatch for dialect %s", dialectName) + + if wantParams, ok := tc.WantParams[dialectName]; ok && len(wantParams) > 0 { + assert.Equal(t, wantParams, result.Parameters, "params mismatch for dialect %s", dialectName) + } + } + }) + } +} + +// RunAllConvertTests runs all standard test suites for a given dialect. 
+func RunAllConvertTests( + t *testing.T, + dialectName dialect.Name, + envFactory func(envSetup string) (*EnvResult, error), + opts ...cel2sql.ConvertOption, +) { + t.Run("basic", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.BasicTests(), envFactory, opts...) + }) + t.Run("operators", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.OperatorTests(), envFactory, opts...) + }) + t.Run("strings", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.StringTests(), envFactory, opts...) + }) + t.Run("regex", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.RegexTests(), envFactory, opts...) + }) + t.Run("casts", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.CastTests(), envFactory, opts...) + }) + t.Run("arrays", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.ArrayTests(), envFactory, opts...) + }) + t.Run("timestamps", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.TimestampTests(), envFactory, opts...) + }) + t.Run("comprehensions", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.ComprehensionTests(), envFactory, opts...) + }) + t.Run("json", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.JSONTests(), envFactory, opts...) + }) +} + +// RunAllParameterizedTests runs all parameterized test suites for a given dialect. +func RunAllParameterizedTests( + t *testing.T, + dialectName dialect.Name, + envFactory func(envSetup string) (*EnvResult, error), + opts ...cel2sql.ConvertOption, +) { + t.Run("parameterized", func(t *testing.T) { + RunParameterizedTests(t, dialectName, testcases.ParameterizedTests(), envFactory, opts...) 
+ }) +} diff --git a/testutil/runner_bigquery_test.go b/testutil/runner_bigquery_test.go new file mode 100644 index 0000000..4381120 --- /dev/null +++ b/testutil/runner_bigquery_test.go @@ -0,0 +1,16 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testutil" +) + +func TestBigQuerySharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.BigQuery, testutil.BigQueryEnvFactory()) +} + +func TestBigQueryParameterizedSharedCases(t *testing.T) { + testutil.RunAllParameterizedTests(t, dialect.BigQuery, testutil.BigQueryEnvFactory()) +} diff --git a/testutil/runner_duckdb_test.go b/testutil/runner_duckdb_test.go new file mode 100644 index 0000000..c509510 --- /dev/null +++ b/testutil/runner_duckdb_test.go @@ -0,0 +1,16 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testutil" +) + +func TestDuckDBSharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.DuckDB, testutil.DuckDBEnvFactory()) +} + +func TestDuckDBParameterizedSharedCases(t *testing.T) { + testutil.RunAllParameterizedTests(t, dialect.DuckDB, testutil.DuckDBEnvFactory()) +} diff --git a/testutil/runner_mysql_test.go b/testutil/runner_mysql_test.go new file mode 100644 index 0000000..5087059 --- /dev/null +++ b/testutil/runner_mysql_test.go @@ -0,0 +1,12 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testutil" +) + +func TestMySQLSharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.MySQL, testutil.MySQLEnvFactory()) +} diff --git a/testutil/runner_pg_test.go b/testutil/runner_pg_test.go new file mode 100644 index 0000000..bfc3e15 --- /dev/null +++ b/testutil/runner_pg_test.go @@ -0,0 +1,16 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + 
"github.com/spandigital/cel2sql/v3/testutil" +) + +func TestPostgreSQLSharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.PostgreSQL, testutil.PostgreSQLEnvFactory()) +} + +func TestPostgreSQLParameterizedSharedCases(t *testing.T) { + testutil.RunAllParameterizedTests(t, dialect.PostgreSQL, testutil.PostgreSQLEnvFactory()) +} diff --git a/testutil/runner_sqlite_test.go b/testutil/runner_sqlite_test.go new file mode 100644 index 0000000..0e6cfb7 --- /dev/null +++ b/testutil/runner_sqlite_test.go @@ -0,0 +1,16 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testutil" +) + +func TestSQLiteSharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.SQLite, testutil.SQLiteEnvFactory()) +} + +func TestSQLiteParameterizedSharedCases(t *testing.T) { + testutil.RunAllParameterizedTests(t, dialect.SQLite, testutil.SQLiteEnvFactory()) +} diff --git a/timestamps.go b/timestamps.go index d0defc6..13d62e9 100644 --- a/timestamps.go +++ b/timestamps.go @@ -2,7 +2,6 @@ package cel2sql import ( "fmt" - "strconv" "time" "github.com/google/cel-go/common/operators" @@ -55,7 +54,6 @@ func (con *converter) callTimestampOperation(fun string, lhs *exprpb.Expr, rhs * return newConversionError(errMsgInvalidTimestampOp, "timestamp operation requires at least one timestamp operand") } - // PostgreSQL uses simple + and - operators for date arithmetic var sqlOp string switch fun { case operators.Add: @@ -66,16 +64,10 @@ func (con *converter) callTimestampOperation(fun string, lhs *exprpb.Expr, rhs * return newConversionError(errMsgInvalidTimestampOp, "unsupported timestamp operation") } - if err := con.visitMaybeNested(timestamp, timestampParen); err != nil { - return err - } - con.str.WriteString(" ") - con.str.WriteString(sqlOp) - con.str.WriteString(" ") - if err := con.visitMaybeNested(duration, durationParen); err != nil { - return err - } - return nil + return 
con.dialect.WriteTimestampArithmetic(&con.str, sqlOp, + func() error { return con.visitMaybeNested(timestamp, timestampParen) }, + func() error { return con.visitMaybeNested(duration, durationParen) }, + ) } // callDuration converts CEL duration expressions to PostgreSQL INTERVAL @@ -100,105 +92,96 @@ func (con *converter) callDuration(_ *exprpb.Expr, args []*exprpb.Expr) error { if err != nil { return err } - con.str.WriteString("INTERVAL ") + var value int64 + var unit string switch d { case d.Round(time.Hour): - con.str.WriteString(strconv.FormatFloat(d.Hours(), 'f', 0, 64)) - con.str.WriteString(" HOUR") + value = int64(d.Hours()) + unit = "HOUR" case d.Round(time.Minute): - con.str.WriteString(strconv.FormatFloat(d.Minutes(), 'f', 0, 64)) - con.str.WriteString(" MINUTE") + value = int64(d.Minutes()) + unit = "MINUTE" case d.Round(time.Second): - con.str.WriteString(strconv.FormatFloat(d.Seconds(), 'f', 0, 64)) - con.str.WriteString(" SECOND") + value = int64(d.Seconds()) + unit = "SECOND" case d.Round(time.Millisecond): - con.str.WriteString(strconv.FormatInt(d.Milliseconds(), 10)) - con.str.WriteString(" MILLISECOND") + value = d.Milliseconds() + unit = "MILLISECOND" default: - con.str.WriteString(strconv.FormatInt(d.Truncate(time.Microsecond).Microseconds(), 10)) - con.str.WriteString(" MICROSECOND") + value = d.Truncate(time.Microsecond).Microseconds() + unit = "MICROSECOND" } + con.dialect.WriteDuration(&con.str, value, unit) return nil } -// callInterval creates PostgreSQL INTERVAL expressions +// callInterval creates INTERVAL expressions using the dialect func (con *converter) callInterval(_ *exprpb.Expr, args []*exprpb.Expr) error { - con.str.WriteString("INTERVAL ") - if err := con.visit(args[0]); err != nil { - return err - } - con.str.WriteString(" ") datePart := args[1] - con.str.WriteString(datePart.GetIdentExpr().GetName()) - return nil + unit := datePart.GetIdentExpr().GetName() + return con.dialect.WriteInterval(&con.str, func() error { + return 
con.visit(args[0]) + }, unit) } // callExtractFromTimestamp handles timestamp field extraction (YEAR, MONTH, DAY, etc.) func (con *converter) callExtractFromTimestamp(function string, target *exprpb.Expr, args []*exprpb.Expr) error { - // For getDayOfWeek, we need to wrap the entire EXTRACT in parentheses for modulo operation - if function == overloads.TimeGetDayOfWeek { - con.str.WriteString("(") - } - con.str.WriteString("EXTRACT(") + var part string switch function { case overloads.TimeGetFullYear: - con.str.WriteString("YEAR") + part = "YEAR" case overloads.TimeGetMonth: - con.str.WriteString("MONTH") + part = "MONTH" case overloads.TimeGetDate: - con.str.WriteString("DAY") + part = "DAY" case overloads.TimeGetHours: - con.str.WriteString("HOUR") + part = "HOUR" case overloads.TimeGetMinutes: - con.str.WriteString("MINUTE") + part = "MINUTE" case overloads.TimeGetSeconds: - con.str.WriteString("SECOND") + part = "SECOND" case overloads.TimeGetMilliseconds: - con.str.WriteString("MILLISECONDS") + part = "MILLISECONDS" case overloads.TimeGetDayOfYear: - con.str.WriteString("DOY") + part = "DOY" case overloads.TimeGetDayOfMonth: - con.str.WriteString("DAY") + part = "DAY" case overloads.TimeGetDayOfWeek: - con.str.WriteString("DOW") + part = "DOW" } - con.str.WriteString(" FROM ") - if err := con.visit(target); err != nil { - return err + + writeExpr := func() error { + return con.visit(target) } + + var writeTZ func() error if isTimestampType(con.getType(target)) && len(args) == 1 { - con.str.WriteString(" AT TIME ZONE ") - if err := con.visit(args[0]); err != nil { - return err + writeTZ = func() error { + return con.visit(args[0]) } } - con.str.WriteString(")") + + if err := con.dialect.WriteExtract(&con.str, part, writeExpr, writeTZ); err != nil { + return err + } + + // Apply CEL-specific adjustments (these are universal, not dialect-specific) switch function { case overloads.TimeGetMonth, overloads.TimeGetDayOfYear, overloads.TimeGetDayOfMonth: 
con.str.WriteString(" - 1") - case overloads.TimeGetDayOfWeek: - // PostgreSQL DOW: 0=Sunday, 1=Monday, ..., 6=Saturday - // CEL getDayOfWeek: 0=Monday, 1=Tuesday, ..., 6=Sunday (ISO 8601) - // Convert: (DOW + 6) % 7 - con.str.WriteString(" + 6) % 7") } return nil } -// callTimestampFromString converts string literals to PostgreSQL timestamps +// callTimestampFromString converts string literals to timestamps using the dialect func (con *converter) callTimestampFromString(_ *exprpb.Expr, args []*exprpb.Expr) error { if len(args) == 1 { - // For PostgreSQL, we need to cast the string to a timestamp - con.str.WriteString("CAST(") - err := con.visit(args[0]) - if err != nil { - return err - } - con.str.WriteString(" AS TIMESTAMP WITH TIME ZONE)") - return nil + return con.dialect.WriteTimestampCast(&con.str, func() error { + return con.visit(args[0]) + }) } else if len(args) == 2 { // Handle timestamp(datetime, timezone) format - // In PostgreSQL, use: datetime AT TIME ZONE timezone + // For most dialects: datetime AT TIME ZONE timezone err := con.visit(args[0]) if err != nil { return err