diff --git a/db_changes/db/dialect_postgres.go b/db_changes/db/dialect_postgres.go index 2143745..84867e9 100644 --- a/db_changes/db/dialect_postgres.go +++ b/db_changes/db/dialect_postgres.go @@ -307,11 +307,11 @@ func (d PostgresDialect) historyTable(schema string) string { return fmt.Sprintf("%s.%s", EscapeIdentifier(schema), EscapeIdentifier(d.historyTableName)) } -func (d PostgresDialect) saveInsert(schema string, table string, primaryKey map[string]string, blockNum uint64) string { +func (d PostgresDialect) saveInsert(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,block_num) values (%s,%s,%s,%d);`, d.historyTable(schema), escapeStringValue("I"), - escapeStringValue(table), + escapeStringValue(table.identifier), escapeStringValue(primaryKeyToJSON(primaryKey)), blockNum, ) @@ -321,8 +321,12 @@ func (d PostgresDialect) saveInsert(schema string, table string, primaryKey map[ with t as (select 'default' id) select CASE WHEN block_meta.id is null THEN 'I' ELSE 'U' END AS op, '"public"."block_meta"', 'allo', row_to_json(block_meta),10 from t left join block_meta on block_meta.id='default'; */ -func (d PostgresDialect) saveUpsert(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { - schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName) +func (d PostgresDialect) saveUpsert(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { + schemaAndTable := table.schemaEscaped + "." + table.nameEscaped + onClause, err := d.getPrimaryKeyWhereClauseTyped(table, primaryKey, table.nameEscaped) + if err != nil { + onClause = getPrimaryKeyWhereClause(primaryKey, table.nameEscaped) + } return fmt.Sprintf(` WITH t as (select %s) @@ -332,30 +336,34 @@ func (d PostgresDialect) saveUpsert(schema string, escapedTableName string, prim getPrimaryKeyFakeEmptyValues(primaryKey), d.historyTable(schema), - getPrimaryKeyFakeEmptyValuesAssertion(primaryKey, escapedTableName), + getPrimaryKeyFakeEmptyValuesAssertion(primaryKey, table.nameEscaped), - escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum, - EscapeIdentifier(schema), escapedTableName, - getPrimaryKeyWhereClause(primaryKey, escapedTableName), + escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), table.nameEscaped, blockNum, + table.schemaEscaped, table.nameEscaped, + onClause, ) } -func (d PostgresDialect) saveUpdate(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { - return d.saveRow("U", schema, escapedTableName, primaryKey, blockNum) +func (d PostgresDialect) saveUpdate(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { + return d.saveRow("U", schema, table, primaryKey, blockNum) } -func (d PostgresDialect) saveDelete(schema string, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { - return d.saveRow("D", schema, escapedTableName, primaryKey, blockNum) +func (d PostgresDialect) saveDelete(schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { + return d.saveRow("D", schema, table, primaryKey, blockNum) } -func (d PostgresDialect) saveRow(op, schema, escapedTableName string, primaryKey map[string]string, blockNum uint64) string { - schemaAndTable := fmt.Sprintf("%s.%s", EscapeIdentifier(schema), escapedTableName) +func (d PostgresDialect) saveRow(op, schema string, table *TableInfo, primaryKey map[string]string, blockNum uint64) string { + schemaAndTable := table.schemaEscaped + "." + table.nameEscaped + whereClause, err := d.getPrimaryKeyWhereClauseTyped(table, primaryKey, "") + if err != nil { + whereClause = getPrimaryKeyWhereClause(primaryKey, "") + } return fmt.Sprintf(`INSERT INTO %s (op,table_name,pk,prev_value,block_num) SELECT %s,%s,%s,row_to_json(%s),%d FROM %s.%s WHERE %s;`, d.historyTable(schema), - escapeStringValue(op), escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), escapedTableName, blockNum, - EscapeIdentifier(schema), escapedTableName, - getPrimaryKeyWhereClause(primaryKey, ""), + escapeStringValue(op), escapeStringValue(schemaAndTable), escapeStringValue(primaryKeyToJSON(primaryKey)), table.nameEscaped, blockNum, + table.schemaEscaped, table.nameEscaped, + whereClause, ) } @@ -386,7 +394,7 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, ) if o.reversibleBlockNum != nil { - return d.saveInsert(schema, o.table.identifier, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil + return d.saveInsert(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil } return insertQuery, nil @@ -405,7 +413,7 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, ) if o.reversibleBlockNum != nil { - return d.saveUpsert(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil + return d.saveUpsert(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + insertQuery, nil } return insertQuery, nil @@ -415,7 +423,10 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, updates[i] = fmt.Sprintf("%s=%s", columns[i], values[i]) } - primaryKeySelector := getPrimaryKeyWhereClause(o.primaryKey, "") + primaryKeySelector, err := d.getPrimaryKeyWhereClauseTyped(o.table, o.primaryKey, "") + if err != nil { + return "", err + } updateQuery := fmt.Sprintf("UPDATE %s SET %s WHERE %s", o.table.identifier, @@ -424,18 +435,21 @@ func (d *PostgresDialect) prepareStatement(schema string, o *Operation) (string, ) if o.reversibleBlockNum != nil { - return d.saveUpdate(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + updateQuery, nil + return d.saveUpdate(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + updateQuery, nil } return updateQuery, nil case OperationTypeDelete: - primaryKeyWhereClause := getPrimaryKeyWhereClause(o.primaryKey, "") + primaryKeyWhereClause, err := d.getPrimaryKeyWhereClauseTyped(o.table, o.primaryKey, "") + if err != nil { + return "", err + } deleteQuery := fmt.Sprintf("DELETE FROM %s WHERE %s", o.table.identifier, primaryKeyWhereClause, ) if o.reversibleBlockNum != nil { - return d.saveDelete(schema, o.table.nameEscaped, o.primaryKey, *o.reversibleBlockNum) + deleteQuery, nil + return d.saveDelete(schema, o.table, o.primaryKey, *o.reversibleBlockNum) + deleteQuery, nil } return deleteQuery, nil @@ -466,7 +480,7 @@ func (d *PostgresDialect) prepareColValues(table *TableInfo, colValues map[strin return nil, nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", ")) } - normalizedValue, err := d.normalizeValueType(value, columnInfo.scanType) + normalizedValue, err := d.normalizeLiteral(table, columnName, value, "insert") if err != nil { return nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", table.identifier, columnName, value, err) } @@ -535,6 +549,107 @@ func getPrimaryKeyWhereClause(primaryKey map[string]string, escapedTableName str return strings.Join(reg[:], " AND ") } +// Build a typed WHERE clause using column types for proper literal normalization +func (d *PostgresDialect) getPrimaryKeyWhereClauseTyped(table *TableInfo, primaryKey map[string]string, escapedTableName string) (string, error) { + // Avoid any allocation if there is a single primary key + if len(primaryKey) == 1 { + for key, value := range primaryKey { + rhs, err := d.normalizeLiteral(table, key, value, "where") + if err != nil { + return "", err + } + if escapedTableName == "" { + return EscapeIdentifier(key) + " = " + rhs, nil + } + return escapedTableName + "." + EscapeIdentifier(key) + " = " + rhs, nil + } + } + + reg := make([]string, 0, len(primaryKey)) + for key, value := range primaryKey { + rhs, err := d.normalizeLiteral(table, key, value, "where") + if err != nil { + return "", err + } + if escapedTableName == "" { + reg = append(reg, EscapeIdentifier(key)+" = "+rhs) + } else { + reg = append(reg, escapedTableName+"."+EscapeIdentifier(key)+" = "+rhs) + } + } + sort.Strings(reg) + return strings.Join(reg, " AND "), nil +} + +// normalizeLiteral centralizes literal formatting based on column database type name and context (insert/where) +// Returns a SQL RHS expression ready to be embedded in a statement +func (d *PostgresDialect) normalizeLiteral(table *TableInfo, columnName string, rawValue string, context string) (string, error) { + colInfo, found := table.columnsByName[columnName] + if !found { + return "", fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", ")) + } + + dt := strings.ToLower(colInfo.databaseTypeName) + trimmed := strings.TrimSpace(rawValue) + + // Array handling using databaseTypeName + if strings.HasSuffix(dt, "[]") { + baseType := strings.TrimSuffix(dt, "[]") + return canonicalizeArrayLiteral(baseType, trimmed), nil + } + + // Scalar handling as today using scanType + return d.normalizeValueType(trimmed, colInfo.scanType) +} + +// canonicalizeArrayLiteral emits a curly literal with explicit cast to the column's base type array +// Examples: +// - empty -> '{ }'::base[] +// - 'ARRAY[1,2]' -> '{1,2}'::base[] +// - '{1,2}' or "'{1,2}'::text[]" -> '{1,2}'::base[] +func canonicalizeArrayLiteral(baseType string, raw string) string { + // Empty array detection (accept {}, { }, [], ARRAY[]) + switch raw { + case "{}", "{ }", "[]", "ARRAY[]", "array[]", "ARRAY []", "array []", "": + return "'{ }'::" + baseType + "[]" + } + + upper := strings.ToUpper(raw) + if strings.HasPrefix(upper, "ARRAY[") { + // Extract elements inside ARRAY[...] + end := strings.LastIndex(raw, "]") + inner := "" + if end > 6 { // len("ARRAY[") == 6 + inner = strings.TrimSpace(raw[6:end]) + } + return "'{" + inner + "}'::" + baseType + "[]" + } + + // Strip surrounding single quotes if present + if strings.HasPrefix(raw, "'") && strings.HasSuffix(raw, "'") && len(raw) >= 2 { + raw = strings.TrimSuffix(strings.TrimPrefix(raw, "'"), "'") + } + // Remove any existing cast suffix like }::type[] + if idx := strings.LastIndex(raw, "}::"); idx != -1 { + raw = raw[:idx+1] + } + + // If contains curly braces, extract content and rebuild + if l := strings.Index(raw, "{"); l != -1 { + if r := strings.LastIndex(raw, "}"); r != -1 && r > l { + inner := strings.TrimSpace(raw[l+1 : r]) + return "'{" + inner + "}'::" + baseType + "[]" + } + } + + // Fallback: treat raw as a comma-separated list of elements + inner := strings.TrimSpace(raw) + if strings.HasPrefix(inner, "{") && strings.HasSuffix(inner, "}") && len(inner) >= 2 { + inner = strings.TrimSuffix(strings.TrimPrefix(inner, "{"), "}") + } + return "'{" + inner + "}'::" + baseType + "[]" +} + // Format based on type, value returned unescaped func (d *PostgresDialect) normalizeValueType(value string, valueType reflect.Type) (string, error) { switch valueType.Kind() {