From 3db704fa5babd4b141cd85759574d046b58beb73 Mon Sep 17 00:00:00 2001 From: hamdan Date: Tue, 26 May 2026 10:36:15 +0800 Subject: [PATCH 1/6] feat: add PostgreSQL dialect support - add SqlDialect abstraction and database-specific implementations for MySQL, MariaDB, SQLite, and PostgreSQL - add PostgreSQL connection/configuration support and JDBC dependency wiring - route SQL rendering through dialects for identifiers, DDL, joins, where clauses, upserts, and migration metadata checks - fix PostgreSQL-specific regressions (where-clause quoting, batch request quoting, case-preserving migration column detection) - add foreign key/order-by value objects and dialect-aware schema rendering - add PostgreSQL-focused regression and behavior tests across dialect, connection, conditions, create, upsert, migration, and request SQL generation - document PostgreSQL usage in README and bump library version to 1.24 --- README.md | 11 + build.gradle.kts | 6 +- .../sarah/DatabaseConfiguration.java | 12 + .../maxlego08/sarah/DatabaseConnection.java | 7 +- .../sarah/HikariDatabaseConnection.java | 14 +- .../fr/maxlego08/sarah/MariaDbConnection.java | 4 +- .../fr/maxlego08/sarah/MigrationManager.java | 41 +-- .../fr/maxlego08/sarah/MySqlConnection.java | 4 +- .../maxlego08/sarah/PostgreSqlConnection.java | 24 ++ .../fr/maxlego08/sarah/SchemaBuilder.java | 86 ++++-- .../sarah/conditions/ColumnDefinition.java | 64 ++--- .../conditions/ForeignKeyDefinition.java | 40 +++ .../sarah/conditions/JoinCondition.java | 19 ++ .../sarah/conditions/OrderByCondition.java | 43 +++ .../sarah/conditions/SelectCondition.java | 16 +- .../sarah/conditions/WhereCondition.java | 38 ++- .../sarah/database/DatabaseType.java | 1 + .../fr/maxlego08/sarah/database/Schema.java | 22 +- .../sarah/dialect/AbstractSqlDialect.java | 272 ++++++++++++++++++ .../sarah/dialect/MariaDbDialect.java | 16 ++ .../maxlego08/sarah/dialect/MySqlDialect.java | 70 +++++ .../sarah/dialect/PostgreSqlDialect.java | 72 +++++ .../maxlego08/sarah/dialect/SqlDialect.java | 38 +++ .../maxlego08/sarah/dialect/SqlDialects.java | 33 +++ .../sarah/dialect/SqliteDialect.java | 68 +++++ .../sarah/requests/AlterRequest.java | 33 ++- .../sarah/requests/CreateIndexRequest.java | 9 +- .../sarah/requests/CreateRequest.java | 34 ++- .../sarah/requests/DeleteRequest.java | 7 +- .../sarah/requests/DropTableRequest.java | 5 +- .../sarah/requests/InsertAllRequest.java | 29 +- .../sarah/requests/InsertBatchRequest.java | 7 +- .../sarah/requests/InsertRequest.java | 7 +- .../sarah/requests/RenameExecutor.java | 7 +- .../sarah/requests/UpdateBatchRequest.java | 11 +- .../sarah/requests/UpdateRequest.java | 11 +- .../sarah/requests/UpsertBatchRequest.java | 41 +-- .../sarah/requests/UpsertRequest.java | 73 +---- .../java/fr/maxlego08/sarah/DialectTest.java | 92 ++++++ .../PostgreSqlConditionRenderingTest.java | 63 ++++ .../sarah/PostgreSqlConnectionTest.java | 39 +++ .../sarah/PostgreSqlCreateRequestTest.java | 137 +++++++++ ...stgreSqlMigrationCasePreservationTest.java | 54 ++++ .../sarah/PostgreSqlMigrationDialectTest.java | 20 ++ ...ostgreSqlRequestDialectRegressionTest.java | 157 ++++++++++ .../sarah/PostgreSqlUpsertRequestTest.java | 48 ++++ 46 files changed, 1635 insertions(+), 270 deletions(-) create mode 100644 src/main/java/fr/maxlego08/sarah/PostgreSqlConnection.java create mode 100644 src/main/java/fr/maxlego08/sarah/conditions/ForeignKeyDefinition.java create mode 100644 src/main/java/fr/maxlego08/sarah/conditions/OrderByCondition.java create mode 100644 src/main/java/fr/maxlego08/sarah/dialect/AbstractSqlDialect.java create mode 100644 src/main/java/fr/maxlego08/sarah/dialect/MariaDbDialect.java create mode 100644 src/main/java/fr/maxlego08/sarah/dialect/MySqlDialect.java create mode 100644 src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java create mode 100644 src/main/java/fr/maxlego08/sarah/dialect/SqlDialect.java create mode 100644 src/main/java/fr/maxlego08/sarah/dialect/SqlDialects.java create mode 100644 src/main/java/fr/maxlego08/sarah/dialect/SqliteDialect.java create mode 100644 src/test/java/fr/maxlego08/sarah/DialectTest.java create mode 100644 src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java create mode 100644 src/test/java/fr/maxlego08/sarah/PostgreSqlConnectionTest.java create mode 100644 src/test/java/fr/maxlego08/sarah/PostgreSqlCreateRequestTest.java create mode 100644 src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationCasePreservationTest.java create mode 100644 src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationDialectTest.java create mode 100644 src/test/java/fr/maxlego08/sarah/PostgreSqlRequestDialectRegressionTest.java create mode 100644 src/test/java/fr/maxlego08/sarah/PostgreSqlUpsertRequestTest.java diff --git a/README.md b/README.md index 318c94c..8657ce4 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,17 @@ public void connect() { } ```` +### With PostgreSQL + +````java +public void connect() { + DatabaseConfiguration configuration=DatabaseConfiguration.createPostgreSql(,,,,); + DatabaseConnection connection=new PostgreSqlConnection(configuration,); +} +```` + +PostgreSQL uses the `org.postgresql:postgresql` JDBC driver. Consumers must provide this driver at runtime. + ### With SQLITE ````java diff --git a/build.gradle.kts b/build.gradle.kts index 1237c54..6680d1b 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -12,7 +12,7 @@ extra.set("classifier", System.getProperty("archive.classifier")) extra.set("sha", System.getProperty("github.sha")) group = "fr.maxlego08.sarah" -version = "1.23" +version = "1.24" rootProject.extra.properties["sha"]?.let { sha -> version = sha @@ -37,6 +37,7 @@ dependencies { compileOnly("org.xerial:sqlite-jdbc:3.42.0.0") compileOnly("org.mariadb.jdbc:mariadb-java-client:3.1.4") compileOnly("com.mysql:mysql-connector-j:8.2.0") + compileOnly("org.postgresql:postgresql:42.7.3") // Test dependencies testImplementation("org.junit.jupiter:junit-jupiter-api:5.9.3") @@ -47,6 +48,7 @@ dependencies { testImplementation("org.xerial:sqlite-jdbc:3.42.0.0") testImplementation("org.mariadb.jdbc:mariadb-java-client:3.1.4") testImplementation("com.mysql:mysql-connector-j:8.2.0") + testImplementation("org.postgresql:postgresql:42.7.3") } tasks.withType { @@ -78,4 +80,4 @@ tasks.test { publishConfig { githubOwner = "GroupeZ-dev" useRootProjectName = true -} \ No newline at end of file +} diff --git a/src/main/java/fr/maxlego08/sarah/DatabaseConfiguration.java b/src/main/java/fr/maxlego08/sarah/DatabaseConfiguration.java index c71b5e8..19eb876 100644 --- a/src/main/java/fr/maxlego08/sarah/DatabaseConfiguration.java +++ b/src/main/java/fr/maxlego08/sarah/DatabaseConfiguration.java @@ -61,6 +61,14 @@ public static DatabaseConfiguration createMariaDb(String user, String password, return new DatabaseConfiguration("", user, password, port, host, database, debug, DatabaseType.MARIADB); } + public static DatabaseConfiguration createPostgreSql(String user, String password, int port, String host, String database) { + return new DatabaseConfiguration("", user, password, port, host, database, false, DatabaseType.POSTGRESQL); + } + + public static DatabaseConfiguration createPostgreSql(String user, String password, int port, String host, String database, boolean debug) { + return new DatabaseConfiguration("", user, password, port, host, database, debug, DatabaseType.POSTGRESQL); + } + public static DatabaseConfiguration create(String user, String password, String host, String database, DatabaseType databaseType) { return new DatabaseConfiguration("", user, password, 3306, host, database, false, databaseType); } @@ -73,6 +81,10 @@ public static DatabaseConfiguration createMariaDb(String user, String password, return new DatabaseConfiguration("", user, password, 3306, host, database, false, DatabaseType.MARIADB); } + public static DatabaseConfiguration createPostgreSql(String user, String password, String host, String database) { + return new DatabaseConfiguration("", user, password, 5432, host, database, false, DatabaseType.POSTGRESQL); + } + public static DatabaseConfiguration sqlite(boolean debug) { return new DatabaseConfiguration("", null, null, 0, null, null, debug, DatabaseType.SQLITE); } diff --git a/src/main/java/fr/maxlego08/sarah/DatabaseConnection.java b/src/main/java/fr/maxlego08/sarah/DatabaseConnection.java index af94a7e..b51f9e1 100644 --- a/src/main/java/fr/maxlego08/sarah/DatabaseConnection.java +++ b/src/main/java/fr/maxlego08/sarah/DatabaseConnection.java @@ -1,6 +1,7 @@ package fr.maxlego08.sarah; import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; import fr.maxlego08.sarah.transaction.Transaction; @@ -42,11 +43,7 @@ public boolean isValid() { DatabaseType databaseType = this.databaseConfiguration.getDatabaseType(); try { - if (databaseType == DatabaseType.MARIADB) { - Class.forName("org.mariadb.jdbc.Driver"); - } else { - Class.forName("com.mysql.cj.jdbc.Driver"); - } + Class.forName(SqlDialects.from(databaseType).driverClassName()); } catch (Exception ignored) { } diff --git a/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java b/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java index 34d6ed4..12c5b3d 100644 --- a/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java +++ b/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java @@ -3,6 +3,8 @@ import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariDataSource; import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -38,17 +40,11 @@ private void initializeDataSource() { config.setPoolName("sarah-" + POOL_COUNTER.getAndIncrement()); DatabaseType databaseType = databaseConfiguration.getDatabaseType(); + SqlDialect dialect = SqlDialects.from(databaseType); // URL + Driver - final String jdbcUrl; - if (databaseType == DatabaseType.MARIADB) { - jdbcUrl = "jdbc:mariadb://" + databaseConfiguration.getHost() + ":" + databaseConfiguration.getPort() + "/" + databaseConfiguration.getDatabase() + "?allowMultiQueries=true"; - config.setDriverClassName("org.mariadb.jdbc.Driver"); - } else { - jdbcUrl = "jdbc:mysql://" + databaseConfiguration.getHost() + ":" + databaseConfiguration.getPort() + "/" + databaseConfiguration.getDatabase() + "?allowMultiQueries=true"; - config.setDriverClassName("com.mysql.cj.jdbc.Driver"); - } - config.setJdbcUrl(jdbcUrl); + config.setJdbcUrl(dialect.jdbcUrl(databaseConfiguration)); + config.setDriverClassName(dialect.driverClassName()); // Auth config.setUsername(databaseConfiguration.getUser()); diff --git a/src/main/java/fr/maxlego08/sarah/MariaDbConnection.java b/src/main/java/fr/maxlego08/sarah/MariaDbConnection.java index 0205b30..1afcffe 100644 --- a/src/main/java/fr/maxlego08/sarah/MariaDbConnection.java +++ b/src/main/java/fr/maxlego08/sarah/MariaDbConnection.java @@ -1,5 +1,6 @@ package fr.maxlego08.sarah; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.logger.Logger; import java.sql.Connection; @@ -21,6 +22,7 @@ public Connection connectToDatabase() throws Exception { properties.setProperty("useSSL", "false"); properties.setProperty("user", databaseConfiguration.getUser()); properties.setProperty("password", databaseConfiguration.getPassword()); - return DriverManager.getConnection("jdbc:mariadb://" + databaseConfiguration.getHost() + ":" + databaseConfiguration.getPort() + "/" + databaseConfiguration.getDatabase() + "?allowMultiQueries=true", properties); + String url = SqlDialects.from(databaseConfiguration.getDatabaseType()).jdbcUrl(databaseConfiguration); + return DriverManager.getConnection(url, properties); } } diff --git a/src/main/java/fr/maxlego08/sarah/MigrationManager.java b/src/main/java/fr/maxlego08/sarah/MigrationManager.java index 16f4746..78fc57f 100644 --- a/src/main/java/fr/maxlego08/sarah/MigrationManager.java +++ b/src/main/java/fr/maxlego08/sarah/MigrationManager.java @@ -4,12 +4,11 @@ import fr.maxlego08.sarah.database.DatabaseType; import fr.maxlego08.sarah.database.Migration; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; @@ -108,40 +107,8 @@ public static void execute(DatabaseConnection databaseConnection, Logger logger) String tableName = schema.getTableName(); tableName = tableName.replace("%prefix%", databaseConnection.getDatabaseConfiguration().getTablePrefix()); - if (databaseConnection.getDatabaseConfiguration().getDatabaseType() == DatabaseType.SQLITE) { - try (Connection connection = databaseConnection.getConnection(); - PreparedStatement preparedStatement = connection.prepareStatement(String.format("PRAGMA table_info(%s)", tableName))) { - List columnDefinitions = schema.getColumns(); - logger.info("Executing SQL: " + String.format("PRAGMA table_info(%s)", tableName)); - try (ResultSet resultSet = preparedStatement.executeQuery()) { - while (resultSet.next()) { - String columnName = resultSet.getString("name"); - columnDefinitions.removeIf(column -> column.getName().equals(columnName)); - } - } - mustBeAdd.addAll(columnDefinitions); - } catch (SQLException exception) { - logger.info("Failed to get table info for migration: " + exception.getMessage()); - throw new DatabaseException("migration-table-info", tableName, exception); - } - } else { - for (ColumnDefinition column : schema.getColumns()) { - Schema columnExistQuery; - long result; - columnExistQuery = SchemaBuilder.selectCount("information_schema.COLUMNS") - .where("TABLE_NAME", tableName) - .where("TABLE_SCHEMA", databaseConnection.getDatabaseConfiguration().getDatabase()) - .where("COLUMN_NAME", column.getName()); - try { - result = columnExistQuery.executeSelectCount(databaseConnection, logger); - } catch (SQLException e) { - throw new RuntimeException(e); - } - if (result == 0) { - mustBeAdd.add(column); - } - } - } + SqlDialect dialect = SqlDialects.from(databaseConnection.getDatabaseConfiguration().getDatabaseType()); + mustBeAdd.addAll(dialect.missingColumns(databaseConnection, logger, tableName, schema.getColumns())); if (mustBeAdd.isEmpty()) { return; diff --git a/src/main/java/fr/maxlego08/sarah/MySqlConnection.java b/src/main/java/fr/maxlego08/sarah/MySqlConnection.java index 049b2b7..2f6f63f 100644 --- a/src/main/java/fr/maxlego08/sarah/MySqlConnection.java +++ b/src/main/java/fr/maxlego08/sarah/MySqlConnection.java @@ -1,5 +1,6 @@ package fr.maxlego08.sarah; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.logger.Logger; import java.sql.Connection; @@ -18,6 +19,7 @@ public Connection connectToDatabase() throws Exception { properties.setProperty("useSSL", "false"); properties.setProperty("user", databaseConfiguration.getUser()); properties.setProperty("password", databaseConfiguration.getPassword()); - return DriverManager.getConnection("jdbc:mysql://" + databaseConfiguration.getHost() + ":" + databaseConfiguration.getPort() + "/" + databaseConfiguration.getDatabase() + "?allowMultiQueries=true", properties); + String url = SqlDialects.from(databaseConfiguration.getDatabaseType()).jdbcUrl(databaseConfiguration); + return DriverManager.getConnection(url, properties); } } diff --git a/src/main/java/fr/maxlego08/sarah/PostgreSqlConnection.java b/src/main/java/fr/maxlego08/sarah/PostgreSqlConnection.java new file mode 100644 index 0000000..ae01342 --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/PostgreSqlConnection.java @@ -0,0 +1,24 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.dialect.SqlDialects; +import fr.maxlego08.sarah.logger.Logger; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.util.Properties; + +public class PostgreSqlConnection extends DatabaseConnection { + + public PostgreSqlConnection(DatabaseConfiguration databaseConfiguration, Logger logger) { + super(databaseConfiguration, logger); + } + + @Override + public Connection connectToDatabase() throws Exception { + Properties properties = new Properties(); + properties.setProperty("user", databaseConfiguration.getUser()); + properties.setProperty("password", databaseConfiguration.getPassword()); + String url = SqlDialects.from(databaseConfiguration.getDatabaseType()).jdbcUrl(databaseConfiguration); + return DriverManager.getConnection(url, properties); + } +} diff --git a/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java b/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java index 2932a44..37d654d 100644 --- a/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java +++ b/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java @@ -1,7 +1,9 @@ package fr.maxlego08.sarah; import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.conditions.ForeignKeyDefinition; import fr.maxlego08.sarah.conditions.JoinCondition; +import fr.maxlego08.sarah.conditions.OrderByCondition; import fr.maxlego08.sarah.conditions.SelectCondition; import fr.maxlego08.sarah.conditions.WhereCondition; import fr.maxlego08.sarah.database.DatabaseType; @@ -9,6 +11,8 @@ import fr.maxlego08.sarah.database.Migration; import fr.maxlego08.sarah.database.Schema; import fr.maxlego08.sarah.database.SchemaType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.SarahException; import fr.maxlego08.sarah.logger.Logger; import fr.maxlego08.sarah.requests.AlterRequest; @@ -55,12 +59,12 @@ public class SchemaBuilder implements Schema { private final SchemaType schemaType; private final List columns = new ArrayList<>(); private final List primaryKeys = new ArrayList<>(); - private final List foreignKeys = new ArrayList<>(); + private final List foreignKeys = new ArrayList<>(); private final List whereConditions = new ArrayList<>(); private final List joinConditions = new ArrayList<>(); private final List selectColumns = new ArrayList<>(); private String newTableName; - private String orderBy; + private OrderByCondition orderByCondition; private Migration migration; private boolean isDistinct; @@ -78,7 +82,7 @@ public static Schema copy(String tableName, SchemaType newSchemaType, Schema old schema.whereConditions.addAll(oldSchema.getWhereConditions()); schema.joinConditions.addAll(oldSchema.getJoinConditions()); schema.selectColumns.addAll(oldSchema.getSelectColumns()); - schema.orderBy = oldSchema.getOrderBy(); + schema.orderByCondition = oldSchema.getOrderByCondition(); schema.migration = oldSchema.getMigration(); schema.isDistinct = oldSchema.isDistinct(); schema.newTableName = oldSchema.getNewTableName(); @@ -372,8 +376,7 @@ public Schema foreignKey(String referenceTable) { if (this.columns.isEmpty()) throw new IllegalStateException("No column defined to apply foreign key."); ColumnDefinition lastColumn = this.columns.get(this.columns.size() - 1); - String fkDefinition = String.format("FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE", lastColumn.getSafeName(), safeTable(referenceTable), lastColumn.getSafeName()); - this.foreignKeys.add(fkDefinition); + this.foreignKeys.add(new ForeignKeyDefinition(lastColumn.getName(), referenceTable, lastColumn.getName(), true)); return this; } @@ -382,19 +385,10 @@ public Schema foreignKey(String referenceTable, String columnName, boolean onCas if (this.columns.isEmpty()) throw new IllegalStateException("No column defined to apply foreign key."); ColumnDefinition lastColumn = this.columns.get(this.columns.size() - 1); - String fkDefinition = String.format("FOREIGN KEY (%s) REFERENCES %s(`%s`)%s", lastColumn.getSafeName(), safeTable(referenceTable), columnName, onCascade ? " ON DELETE CASCADE" : ""); - this.foreignKeys.add(fkDefinition); + this.foreignKeys.add(new ForeignKeyDefinition(lastColumn.getName(), referenceTable, columnName, onCascade)); return this; } - /** - * Wraps a table name in backticks for safe SQL identifier quoting. - * Works with %prefix% placeholders since they are replaced after SQL generation. - */ - private String safeTable(String tableName) { - return "`" + tableName + "`"; - } - @Override public Schema createdAt() { ColumnDefinition column = new ColumnDefinition("created_at", "TIMESTAMP"); @@ -423,10 +417,11 @@ public Schema updatedAt() { ColumnDefinition column = new ColumnDefinition("updated_at", "TIMESTAMP"); DatabaseConfiguration configuration = MigrationManager.getDatabaseConfiguration(); - if (configuration.getDatabaseType() == DatabaseType.SQLITE) { + if (configuration == null) { column.setDefaultValue("CURRENT_TIMESTAMP"); } else { - column.setDefaultValue("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"); + SqlDialect dialect = SqlDialects.from(configuration.getDatabaseType()); + column.setDefaultValue(dialect.updatedAtDefaultValue()); } this.columns.add(column); return this; @@ -464,7 +459,7 @@ public Schema defaultCurrentTimestamp() { public Schema primary() { ColumnDefinition lastColumn = getLastColumn(); lastColumn.setPrimaryKey(true); - primaryKeys.add(lastColumn.getSafeName()); + primaryKeys.add(lastColumn.getName()); return this; } @@ -492,10 +487,15 @@ public String getTableName() { @Override public void whereConditions(StringBuilder sql) { + whereConditions(sql, SqlDialects.from(DatabaseType.MYSQL)); + } + + @Override + public void whereConditions(StringBuilder sql, SqlDialect dialect) { if (!this.whereConditions.isEmpty()) { List conditions = new ArrayList<>(); for (WhereCondition condition : this.whereConditions) { - conditions.add(condition.getCondition()); + conditions.add(condition.getCondition(dialect)); } sql.append(" WHERE ").append(String.join(" AND ", conditions)); } @@ -503,8 +503,9 @@ public void whereConditions(StringBuilder sql) { @Override public long executeSelectCount(DatabaseConnection databaseConnection, Logger logger) throws SQLException { + SqlDialect dialect = SqlDialects.from(databaseConnection.getDatabaseConfiguration().getDatabaseType()); StringBuilder selectQuery = new StringBuilder("SELECT COUNT(*) FROM " + tableName); - this.whereConditions(selectQuery); + this.whereConditions(selectQuery, dialect); String finalQuery = databaseConnection.getDatabaseConfiguration().replacePrefix(selectQuery.toString()); if (databaseConnection.getDatabaseConfiguration().isDebug()) { @@ -533,9 +534,10 @@ public List> executeSelect(DatabaseConnection databaseConnec List> results = new ArrayList<>(); String selectedValues = "*"; + SqlDialect dialect = SqlDialects.from(databaseConnection.getDatabaseConfiguration().getDatabaseType()); if (!this.selectColumns.isEmpty()) { selectedValues = this.selectColumns.stream() - .map(SelectCondition::getSelectColumn) + .map(select -> select.getSelectColumn(dialect)) .collect(Collectors.joining(",")); } @@ -548,14 +550,14 @@ public List> executeSelect(DatabaseConnection databaseConnec if (!this.joinConditions.isEmpty()) { for (JoinCondition join : this.joinConditions) { - selectQuery.append(" ").append(join.getJoinClause()); + selectQuery.append(" ").append(join.getJoinClause(dialect)); } } - this.whereConditions(selectQuery); + this.whereConditions(selectQuery, dialect); - if (this.orderBy != null) { - selectQuery.append(" ").append(this.orderBy); + if (this.orderByCondition != null) { + selectQuery.append(" ").append(this.orderByCondition.getOrderByClause(dialect)); } DatabaseConfiguration databaseConfiguration = databaseConnection.getDatabaseConfiguration(); @@ -764,7 +766,7 @@ public List getPrimaryKeys() { } @Override - public List getForeignKeys() { + public List getForeignKeys() { return foreignKeys; } @@ -775,17 +777,24 @@ public List getJoinConditions() { @Override public void orderBy(String columnName) { - this.orderBy = String.format("ORDER BY %s", columnName); + String[] parts = splitQualifiedName(columnName); + this.orderByCondition = new OrderByCondition(parts[0], parts[1], false); } @Override public void orderByDesc(String columnName) { - this.orderBy = String.format("ORDER BY %s DESC", columnName); + String[] parts = splitQualifiedName(columnName); + this.orderByCondition = new OrderByCondition(parts[0], parts[1], true); } @Override public String getOrderBy() { - return this.orderBy; + return this.orderByCondition == null ? null : this.orderByCondition.getOrderByClause(); + } + + @Override + public OrderByCondition getOrderByCondition() { + return this.orderByCondition; } @Override @@ -854,12 +863,12 @@ public void addSelect(String prefix, String selectedColumn) { @Override public void addSelect(String prefix, String selectedColumn, String aliases) { - this.selectColumns.add(new SelectCondition(null, selectedColumn, aliases, false, null)); + this.selectColumns.add(new SelectCondition(prefix, selectedColumn, aliases, false, null)); } @Override public void addSelect(String prefix, String selectedColumn, String aliases, Object defaultValue) { - this.selectColumns.add(new SelectCondition(null, selectedColumn, aliases, true, defaultValue)); + this.selectColumns.add(new SelectCondition(prefix, selectedColumn, aliases, true, defaultValue)); } @Override @@ -881,4 +890,19 @@ public List getSelectColumns() { public String getNewTableName() { return newTableName; } + + private String[] splitQualifiedName(String columnName) { + if (columnName == null) { + return new String[]{null, null}; + } + + int separatorIndex = columnName.indexOf('.'); + if (separatorIndex < 0) { + return new String[]{null, columnName}; + } + + String prefix = columnName.substring(0, separatorIndex); + String column = columnName.substring(separatorIndex + 1); + return new String[]{prefix, column}; + } } diff --git a/src/main/java/fr/maxlego08/sarah/conditions/ColumnDefinition.java b/src/main/java/fr/maxlego08/sarah/conditions/ColumnDefinition.java index 49e956c..fe99e65 100644 --- a/src/main/java/fr/maxlego08/sarah/conditions/ColumnDefinition.java +++ b/src/main/java/fr/maxlego08/sarah/conditions/ColumnDefinition.java @@ -2,6 +2,8 @@ import fr.maxlego08.sarah.DatabaseConfiguration; import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import java.util.Arrays; import java.util.List; @@ -38,47 +40,30 @@ public ColumnDefinition(String name) { * @return The SQL string representation of the column */ public String build(DatabaseConfiguration databaseConfiguration) { - // For SQLite autoincrement, use INTEGER instead of BIGINT/INT - String columnType = type; - if (isAutoIncrement && databaseConfiguration.getDatabaseType() == DatabaseType.SQLITE) { - if (type.equalsIgnoreCase("BIGINT") || type.equalsIgnoreCase("INT") || type.equalsIgnoreCase("INTEGER")) { - columnType = "INTEGER"; + return build(databaseConfiguration, SqlDialects.from(databaseConfiguration.getDatabaseType())); + } + + public String build(DatabaseConfiguration databaseConfiguration, SqlDialect dialect) { + if (isAutoIncrement && isPrimaryKey && databaseConfiguration.getDatabaseType() == DatabaseType.SQLITE) { + StringBuilder sqliteAutoIncrement = new StringBuilder(dialect.quoteIdentifier(name)).append(" INTEGER PRIMARY KEY AUTOINCREMENT"); + if (unique) { + sqliteAutoIncrement.append(" UNIQUE"); } + return sqliteAutoIncrement.toString(); } - StringBuilder columnSQL = new StringBuilder("`" + name + "` " + columnType); - - // Handle ENUM type with values - if (enumValues != null && !enumValues.isEmpty()) { - if (databaseConfiguration.getDatabaseType() == DatabaseType.SQLITE) { - // SQLite doesn't support ENUM, use TEXT instead - columnSQL = new StringBuilder("`" + name + "` TEXT"); - } else { - // MySQL/MariaDB ENUM syntax: ENUM('value1', 'value2', ...) - String values = enumValues.stream() - .map(v -> "'" + v.replace("'", "''") + "'") - .collect(Collectors.joining(", ")); - columnSQL = new StringBuilder("`" + name + "` ENUM(" + values + ")"); - } - } else if (length != 0 && decimal != 0) { - columnSQL.append("(").append(length).append(",").append(decimal).append(")"); - } else if (length != 0) { - columnSQL.append("(").append(length).append(")"); + StringBuilder columnSQL = new StringBuilder(dialect.quoteIdentifier(name)).append(" "); + + boolean isEnumColumn = enumValues != null && !enumValues.isEmpty(); + if (isEnumColumn) { + columnSQL.append(dialect.enumColumnType(this)); + } else { + columnSQL.append(dialect.columnType(this)); } - // For autoincrement columns with primary key - if (isAutoIncrement && isPrimaryKey) { - if (databaseConfiguration.getDatabaseType() == DatabaseType.SQLITE) { - // SQLite: INTEGER PRIMARY KEY AUTOINCREMENT (inline, no NOT NULL needed) - columnSQL.append(" PRIMARY KEY AUTOINCREMENT"); - if (unique) { - columnSQL.append(" UNIQUE"); - } - return columnSQL.toString(); - } else { - // MySQL/MariaDB: column will have AUTO_INCREMENT - columnSQL.append(" AUTO_INCREMENT"); - } + String autoIncrementKeyword = dialect.autoIncrementKeyword(this); + if (!autoIncrementKeyword.isEmpty()) { + columnSQL.append(" ").append(autoIncrementKeyword); } if (nullable) { @@ -137,6 +122,10 @@ public ColumnDefinition setDecimal(Integer decimal) { return this; } + public int getDecimal() { + return decimal; + } + public Boolean getNullable() { return nullable; } @@ -153,8 +142,9 @@ public boolean isPrimaryKey() { return isPrimaryKey; } - public void setPrimaryKey(boolean primaryKey) { + public ColumnDefinition setPrimaryKey(boolean primaryKey) { isPrimaryKey = primaryKey; + return this; } public String getReferenceTable() { diff --git a/src/main/java/fr/maxlego08/sarah/conditions/ForeignKeyDefinition.java b/src/main/java/fr/maxlego08/sarah/conditions/ForeignKeyDefinition.java new file mode 100644 index 0000000..5d98cf9 --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/conditions/ForeignKeyDefinition.java @@ -0,0 +1,40 @@ +package fr.maxlego08.sarah.conditions; + +import fr.maxlego08.sarah.dialect.SqlDialect; + +public class ForeignKeyDefinition { + + private final String sourceColumn; + private final String referenceTable; + private final String referenceColumn; + private final boolean cascade; + + public ForeignKeyDefinition(String sourceColumn, String referenceTable, String referenceColumn, boolean cascade) { + this.sourceColumn = sourceColumn; + this.referenceTable = referenceTable; + this.referenceColumn = referenceColumn; + this.cascade = cascade; + } + + public String getSourceColumn() { + return sourceColumn; + } + + public String getReferenceTable() { + return referenceTable; + } + + public String getReferenceColumn() { + return referenceColumn; + } + + public boolean isCascade() { + return cascade; + } + + public String render(SqlDialect dialect) { + return "FOREIGN KEY (" + dialect.quoteIdentifier(sourceColumn) + ") REFERENCES " + + dialect.quoteIdentifier(referenceTable) + "(" + dialect.quoteIdentifier(referenceColumn) + ")" + + (cascade ? " ON DELETE CASCADE" : ""); + } +} diff --git a/src/main/java/fr/maxlego08/sarah/conditions/JoinCondition.java b/src/main/java/fr/maxlego08/sarah/conditions/JoinCondition.java index 1a7ab06..388ff92 100644 --- a/src/main/java/fr/maxlego08/sarah/conditions/JoinCondition.java +++ b/src/main/java/fr/maxlego08/sarah/conditions/JoinCondition.java @@ -1,5 +1,7 @@ package fr.maxlego08.sarah.conditions; +import fr.maxlego08.sarah.dialect.SqlDialect; + public class JoinCondition { private final String primaryTable; private final String primaryTableAlias; @@ -43,10 +45,27 @@ public String getJoinClause() { return joinClause.toString(); } + public String getJoinClause(SqlDialect dialect) { + StringBuilder joinClause = new StringBuilder(); + joinClause.append(this.joinType.getSql()).append(" ") + .append(this.primaryTable).append(" AS ").append(this.primaryTableAlias) + .append(" ON ").append(dialect.qualifyIdentifier(this.primaryTableAlias, this.primaryColumn)) + .append(" = ").append(dialect.qualifyIdentifier(this.foreignTable, this.foreignColumn)); + + if (this.additionalCondition != null) { + joinClause.append(" AND ").append(this.additionalCondition.getCondition(dialect)); + } + return joinClause.toString(); + } + private String getCondition() { return this.primaryTableAlias + "." + this.primaryColumn + " = '" + this.foreignColumn + "'"; } + private String getCondition(SqlDialect dialect) { + return dialect.qualifyIdentifier(this.primaryTableAlias, this.primaryColumn) + " = '" + this.foreignColumn + "'"; + } + public enum JoinType { INNER("INNER JOIN"), LEFT("LEFT JOIN"), diff --git a/src/main/java/fr/maxlego08/sarah/conditions/OrderByCondition.java b/src/main/java/fr/maxlego08/sarah/conditions/OrderByCondition.java new file mode 100644 index 0000000..437d378 --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/conditions/OrderByCondition.java @@ -0,0 +1,43 @@ +package fr.maxlego08.sarah.conditions; + +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; + +public class OrderByCondition { + + private final String tablePrefix; + private final String column; + private final boolean descending; + + public OrderByCondition(String tablePrefix, String column, boolean descending) { + this.tablePrefix = tablePrefix; + this.column = column; + this.descending = descending; + } + + public String getOrderByClause() { + return this.buildClause(SqlDialects.from(DatabaseType.MYSQL)); + } + + public String getOrderByClause(SqlDialect dialect) { + return this.buildClause(dialect); + } + + public String getTablePrefix() { + return tablePrefix; + } + + public String getColumn() { + return column; + } + + public boolean isDescending() { + return descending; + } + + private String buildClause(SqlDialect dialect) { + String qualified = dialect.qualifyIdentifier(this.tablePrefix, this.column); + return this.descending ? "ORDER BY " + qualified + " DESC" : "ORDER BY " + qualified; + } +} diff --git a/src/main/java/fr/maxlego08/sarah/conditions/SelectCondition.java b/src/main/java/fr/maxlego08/sarah/conditions/SelectCondition.java index e5a5ba6..aeb5b7b 100644 --- a/src/main/java/fr/maxlego08/sarah/conditions/SelectCondition.java +++ b/src/main/java/fr/maxlego08/sarah/conditions/SelectCondition.java @@ -1,5 +1,9 @@ package fr.maxlego08.sarah.conditions; +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; + import java.util.Objects; public class SelectCondition { @@ -51,8 +55,16 @@ public String getSelectColumn() { return result; } + public String getSelectColumn(SqlDialect dialect) { + String quotedColumn = dialect.qualifyIdentifier(this.tablePrefix, this.column); + if (isCoalesce) { + return "COALESCE(" + quotedColumn + ", " + defaultValue + ")" + getAliases(); + } + return quotedColumn + getAliases(); + } + private String getColumnAndAliases() { - return "`" + this.column + "`" + getAliases(); + return SqlDialects.from(DatabaseType.MYSQL).quoteIdentifier(this.column) + getAliases(); } private String getAliases() { @@ -86,4 +98,4 @@ public String toString() { ", defaultValue=" + defaultValue + '}'; } -} \ No newline at end of file +} diff --git a/src/main/java/fr/maxlego08/sarah/conditions/WhereCondition.java b/src/main/java/fr/maxlego08/sarah/conditions/WhereCondition.java index 516b1c7..1d76458 100644 --- a/src/main/java/fr/maxlego08/sarah/conditions/WhereCondition.java +++ b/src/main/java/fr/maxlego08/sarah/conditions/WhereCondition.java @@ -1,11 +1,16 @@ package fr.maxlego08.sarah.conditions; +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; + import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; public class WhereCondition { + private final String tablePrefix; private final String column; private final Object value; private final String operator; @@ -14,14 +19,16 @@ public class WhereCondition { private final List values = new ArrayList<>(); public WhereCondition(String prefix, String column, String operator, Object value) { - this.column = (prefix == null ? "" : prefix + ".") + "`" + column + "`"; + this.tablePrefix = prefix; + this.column = column; this.operator = operator; this.value = value; this.whereAction = WhereAction.NORMAL; } public WhereCondition(String prefix, String column, List values) { - this.column = (prefix == null ? "" : prefix + ".") + "`" + column + "`"; + this.tablePrefix = prefix; + this.column = column; this.value = null; this.operator = null; this.values.addAll(values); @@ -29,6 +36,7 @@ public WhereCondition(String prefix, String column, List values) { } public WhereCondition(String column, WhereAction whereAction) { + this.tablePrefix = null; this.column = column; this.value = null; this.operator = null; @@ -39,9 +47,18 @@ public String getCondition() { if (this.whereAction == WhereAction.IS_NOT_NULL) return this.column + " IS NOT NULL"; if (this.whereAction == WhereAction.IS_NULL) return this.column + " IS NULL"; if (this.whereAction == WhereAction.IN) { - return this.column + " IN (" + values.stream().map(id -> "?").collect(Collectors.joining(",")) + ")"; + return this.legacyQualifiedColumn() + " IN (" + values.stream().map(id -> "?").collect(Collectors.joining(",")) + ")"; + } + return this.legacyQualifiedColumn() + " " + this.operator + " ?"; + } + + public String getCondition(SqlDialect dialect) { + if (this.whereAction == WhereAction.IS_NOT_NULL) return this.qualifiedColumn(dialect) + " IS NOT NULL"; + if (this.whereAction == WhereAction.IS_NULL) return this.qualifiedColumn(dialect) + " IS NULL"; + if (this.whereAction == WhereAction.IN) { + return this.qualifiedColumn(dialect) + " IN (" + values.stream().map(id -> "?").collect(Collectors.joining(",")) + ")"; } - return this.column + " " + this.operator + " ?"; + return this.qualifiedColumn(dialect) + " " + this.operator + " ?"; } public String getOperator() { @@ -56,6 +73,10 @@ public String getColumn() { return this.column; } + public String getTablePrefix() { + return tablePrefix; + } + public WhereAction getWhereAction() { return whereAction; } @@ -67,5 +88,14 @@ public List getValues() { public enum WhereAction { IS_NOT_NULL, IS_NULL, NORMAL, IN, } + + private String legacyQualifiedColumn() { + String quote = SqlDialects.from(DatabaseType.MYSQL).quoteIdentifier(this.column); + return this.tablePrefix == null ? quote : this.tablePrefix + "." + quote; + } + + private String qualifiedColumn(SqlDialect dialect) { + return dialect.qualifyIdentifier(this.tablePrefix, this.column); + } } diff --git a/src/main/java/fr/maxlego08/sarah/database/DatabaseType.java b/src/main/java/fr/maxlego08/sarah/database/DatabaseType.java index 8532656..99a1091 100644 --- a/src/main/java/fr/maxlego08/sarah/database/DatabaseType.java +++ b/src/main/java/fr/maxlego08/sarah/database/DatabaseType.java @@ -5,5 +5,6 @@ public enum DatabaseType { MYSQL, MARIADB, SQLITE, + POSTGRESQL, } diff --git a/src/main/java/fr/maxlego08/sarah/database/Schema.java b/src/main/java/fr/maxlego08/sarah/database/Schema.java index 8b03fa5..2de724b 100644 --- a/src/main/java/fr/maxlego08/sarah/database/Schema.java +++ b/src/main/java/fr/maxlego08/sarah/database/Schema.java @@ -2,9 +2,12 @@ import fr.maxlego08.sarah.DatabaseConnection; import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.conditions.ForeignKeyDefinition; import fr.maxlego08.sarah.conditions.JoinCondition; +import fr.maxlego08.sarah.conditions.OrderByCondition; import fr.maxlego08.sarah.conditions.SelectCondition; import fr.maxlego08.sarah.conditions.WhereCondition; +import fr.maxlego08.sarah.dialect.SqlDialect; import fr.maxlego08.sarah.logger.Logger; import java.sql.PreparedStatement; @@ -590,6 +593,14 @@ public interface Schema { */ void whereConditions(StringBuilder stringBuilder); + /** + * Appends WHERE conditions to the provided SQL query using the provided SQL dialect. + * + * @param stringBuilder the StringBuilder to append the WHERE conditions to + * @param dialect the SQL dialect used to render identifiers + */ + void whereConditions(StringBuilder stringBuilder, SqlDialect dialect); + /** * Applies the stored WHERE conditions to the provided PreparedStatement. * This method iterates over all the WHERE conditions configured in the schema, @@ -621,7 +632,7 @@ public interface Schema { * * @return the list of foreign keys defined in this schema */ - List getForeignKeys(); + List getForeignKeys(); /** * Retrieves the list of join conditions configured in this schema. @@ -654,6 +665,13 @@ public interface Schema { */ String getOrderBy(); + /** + * Gets the ORDER BY condition configured for this schema, if any. + * + * @return the order-by condition or null when no order-by is configured + */ + OrderByCondition getOrderByCondition(); + /** * Makes the query results of this schema distinct. * This has the same effect as adding the DISTINCT keyword to the SQL query. @@ -741,4 +759,4 @@ public interface Schema { * @return the new table name, or null if not set */ String getNewTableName(); -} \ No newline at end of file +} diff --git a/src/main/java/fr/maxlego08/sarah/dialect/AbstractSqlDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/AbstractSqlDialect.java new file mode 100644 index 0000000..320dfa1 --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/dialect/AbstractSqlDialect.java @@ -0,0 +1,272 @@ +package fr.maxlego08.sarah.dialect; + +import fr.maxlego08.sarah.DatabaseConfiguration; +import fr.maxlego08.sarah.DatabaseConnection; +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.exceptions.DatabaseException; +import fr.maxlego08.sarah.logger.Logger; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +public abstract class AbstractSqlDialect implements SqlDialect { + + private final String quote; + + protected AbstractSqlDialect(String quote) { + this.quote = quote; + } + + @Override + public String quoteIdentifier(String name) { + if (name == null || name.isEmpty()) { + throw new IllegalArgumentException("Identifier cannot be null or empty"); + } + + if (name.startsWith(quote) && name.endsWith(quote)) { + return name; + } + + return quote + name + quote; + } + + @Override + public String qualifyIdentifier(String prefix, String column) { + if (prefix == null || prefix.trim().isEmpty()) { + return quoteIdentifier(column); + } + return prefix + "." + quoteIdentifier(column); + } + + @Override + public String driverClassName() { + throw new UnsupportedOperationException("Driver class is not configured"); + } + + @Override + public String jdbcUrl(DatabaseConfiguration configuration) { + throw new UnsupportedOperationException("JDBC URL is not configured"); + } + + @Override + public String createTableSuffix() { + return ""; + } + + @Override + public String columnType(ColumnDefinition column) { + Objects.requireNonNull(column, "column"); + + if (column.getEnumValues() != null && !column.getEnumValues().isEmpty()) { + return enumColumnType(column); + } + + String baseType = column.getType(); + if (baseType == null || baseType.isEmpty()) { + baseType = "TEXT"; + } + + if (useIntegerTypeForAutoIncrementPrimaryKey(column) && isIntegerType(baseType)) { + baseType = "INTEGER"; + } + + Integer length = column.getLength(); + int decimal = getDecimal(column); + if (length != null && length > 0 && decimal > 0) { + return baseType + "(" + length + "," + decimal + ")"; + } + if (length != null && length > 0) { + return baseType + "(" + length + ")"; + } + + return baseType; + } + + @Override + public String autoIncrementKeyword(ColumnDefinition column) { + return ""; + } + + @Override + public String enumColumnType(ColumnDefinition column) { + return "TEXT"; + } + + @Override + public String updatedAtDefaultValue() { + return "CURRENT_TIMESTAMP"; + } + + @Override + public String upsertConflictClause(Schema schema) { + return " ON CONFLICT (" + String.join(", ", conflictColumns(schema)) + ") DO UPDATE SET "; + } + + @Override + public String upsertUpdateExpression(String quotedColumn, boolean batch) { + return quotedColumn + " = excluded." + quotedColumn; + } + + @Override + public boolean usesUpsertUpdateParameters(boolean batch) { + return false; + } + + @Override + public List missingColumns(DatabaseConnection connection, Logger logger, String tableName, List expectedColumns) { + Objects.requireNonNull(connection, "connection"); + Objects.requireNonNull(tableName, "tableName"); + Objects.requireNonNull(expectedColumns, "expectedColumns"); + + return queryMissingColumns( + connection, + logger, + "SELECT COUNT(*) FROM information_schema.COLUMNS WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ? AND COLUMN_NAME = ?", + tableName, + connection.getDatabaseConfiguration().getDatabase(), + expectedColumns + ); + } + + protected boolean useIntegerTypeForAutoIncrementPrimaryKey(ColumnDefinition column) { + return false; + } + + protected List conflictColumns(Schema schema) { + Objects.requireNonNull(schema, "schema"); + + List columns = new ArrayList(); + for (String primaryKey : schema.getPrimaryKeys()) { + ColumnDefinition column = findColumn(schema, primaryKey); + if (column != null && !column.isAutoIncrement()) { + columns.add(quoteIdentifier(column.getName())); + } + } + + if (columns.isEmpty()) { + for (ColumnDefinition column : schema.getColumns()) { + if (column.isUnique() && !column.isAutoIncrement()) { + columns.add(quoteIdentifier(column.getName())); + } + } + } + + if (columns.isEmpty()) { + throw new IllegalStateException("UPSERT requires at least one non-auto-increment primary key or unique constraint"); + } + + return columns; + } + + protected ColumnDefinition findColumn(Schema schema, String identifier) { + if (identifier == null) { + return null; + } + + String normalizedIdentifier = normalizeIdentifier(identifier); + for (ColumnDefinition column : schema.getColumns()) { + if (column.getName() != null && column.getName().equals(identifier)) { + return column; + } + if (column.getSafeName().equals(identifier)) { + return column; + } + if (column.getName() != null && normalizeIdentifier(column.getName()).equals(normalizedIdentifier)) { + return column; + } + } + return null; + } + + protected String normalizeIdentifier(String identifier) { + return identifier.replace("`", "").replace("\"", ""); + } + + protected boolean isIntegerType(String type) { + String normalized = type.toUpperCase(Locale.ROOT); + return "INT".equals(normalized) || "INTEGER".equals(normalized) || "BIGINT".equals(normalized); + } + + protected int getDecimal(ColumnDefinition column) { + return column.getDecimal(); + } + + protected String escapeSingleQuotes(String value) { + return value == null ? null : value.replace("'", "''"); + } + + protected List queryMissingColumns( + DatabaseConnection connection, + Logger logger, + String query, + String tableName, + String schemaName, + List expectedColumns + ) { + List missing = new ArrayList(); + + try (Connection sqlConnection = connection.getConnection()) { + for (ColumnDefinition column : expectedColumns) { + long count = countExistingColumn(sqlConnection, query, tableName, schemaName, column.getName()); + if (count == 0) { + missing.add(column); + } + } + } catch (SQLException exception) { + if (logger != null) { + logger.info("Failed to check column metadata for table '" + tableName + "': " + exception.getMessage()); + } + throw new DatabaseException("missing-columns", tableName, exception); + } + + return missing; + } + + protected long countExistingColumn(Connection connection, String query, String tableName, String schemaName, String columnName) throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(query)) { + statement.setString(1, tableName); + statement.setString(2, schemaName); + statement.setString(3, columnName); + try (ResultSet resultSet = statement.executeQuery()) { + if (resultSet.next()) { + return resultSet.getLong(1); + } + } + } + return 0L; + } + + protected List missingColumnsFromExistingNames(List expectedColumns, Set existingColumns, boolean normalizeToLowerCase) { + Set normalizedExistingColumns = existingColumns; + if (normalizeToLowerCase) { + normalizedExistingColumns = existingColumns.stream() + .filter(Objects::nonNull) + .map(value -> value.toLowerCase(Locale.ROOT)) + .collect(Collectors.toCollection(HashSet::new)); + } + + List missing = new ArrayList(); + for (ColumnDefinition column : expectedColumns) { + String name = column.getName(); + if (name == null) { + missing.add(column); + continue; + } + String lookup = normalizeToLowerCase ? name.toLowerCase(Locale.ROOT) : name; + if (!normalizedExistingColumns.contains(lookup)) { + missing.add(column); + } + } + return missing; + } +} diff --git a/src/main/java/fr/maxlego08/sarah/dialect/MariaDbDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/MariaDbDialect.java new file mode 100644 index 0000000..fc5c2db --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/dialect/MariaDbDialect.java @@ -0,0 +1,16 @@ +package fr.maxlego08.sarah.dialect; + +import fr.maxlego08.sarah.DatabaseConfiguration; + +public class MariaDbDialect extends MySqlDialect { + + @Override + public String driverClassName() { + return "org.mariadb.jdbc.Driver"; + } + + @Override + public String jdbcUrl(DatabaseConfiguration configuration) { + return "jdbc:mariadb://" + configuration.getHost() + ":" + configuration.getPort() + "/" + configuration.getDatabase() + "?allowMultiQueries=true"; + } +} diff --git a/src/main/java/fr/maxlego08/sarah/dialect/MySqlDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/MySqlDialect.java new file mode 100644 index 0000000..7bc01fd --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/dialect/MySqlDialect.java @@ -0,0 +1,70 @@ +package fr.maxlego08.sarah.dialect; + +import fr.maxlego08.sarah.DatabaseConfiguration; +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.database.Schema; + +import java.util.List; +import java.util.stream.Collectors; + +public class MySqlDialect extends AbstractSqlDialect { + + public MySqlDialect() { + super("`"); + } + + @Override + public String driverClassName() { + return "com.mysql.cj.jdbc.Driver"; + } + + @Override + public String jdbcUrl(DatabaseConfiguration configuration) { + return "jdbc:mysql://" + configuration.getHost() + ":" + configuration.getPort() + "/" + configuration.getDatabase() + "?allowMultiQueries=true"; + } + + @Override + public String createTableSuffix() { + return " ENGINE=InnoDB DEFAULT CHARSET=utf8mb4"; + } + + @Override + public String autoIncrementKeyword(ColumnDefinition column) { + return column != null && column.isAutoIncrement() ? "AUTO_INCREMENT" : ""; + } + + @Override + public String enumColumnType(ColumnDefinition column) { + List values = column.getEnumValues(); + if (values == null || values.isEmpty()) { + return "ENUM('')"; + } + + return "ENUM(" + values.stream() + .map(value -> "'" + escapeSingleQuotes(value) + "'") + .collect(Collectors.joining(", ")) + ")"; + } + + @Override + public String updatedAtDefaultValue() { + return "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"; + } + + @Override + public String upsertConflictClause(Schema schema) { + return " ON DUPLICATE KEY UPDATE "; + } + + @Override + public String upsertUpdateExpression(String quotedColumn, boolean batch) { + if (batch) { + return quotedColumn + " = VALUES(" + quotedColumn + ")"; + } + return quotedColumn + " = ?"; + } + + @Override + public boolean usesUpsertUpdateParameters(boolean batch) { + return !batch; + } +} diff --git a/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java new file mode 100644 index 0000000..ab13ff6 --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java @@ -0,0 +1,72 @@ +package fr.maxlego08.sarah.dialect; + +import fr.maxlego08.sarah.DatabaseConfiguration; +import fr.maxlego08.sarah.DatabaseConnection; +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.exceptions.DatabaseException; +import fr.maxlego08.sarah.logger.Logger; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +public class PostgreSqlDialect extends AbstractSqlDialect { + + public PostgreSqlDialect() { + super("\""); + } + + @Override + public String driverClassName() { + return "org.postgresql.Driver"; + } + + @Override + public String jdbcUrl(DatabaseConfiguration configuration) { + return "jdbc:postgresql://" + configuration.getHost() + ":" + configuration.getPort() + "/" + configuration.getDatabase(); + } + + @Override + public String autoIncrementKeyword(ColumnDefinition column) { + return column != null && column.isAutoIncrement() ? "GENERATED BY DEFAULT AS IDENTITY" : ""; + } + + @Override + public String enumColumnType(ColumnDefinition column) { + return "TEXT"; + } + + @Override + public List missingColumns(DatabaseConnection connection, Logger logger, String tableName, List expectedColumns) { + Objects.requireNonNull(connection, "connection"); + Objects.requireNonNull(tableName, "tableName"); + Objects.requireNonNull(expectedColumns, "expectedColumns"); + + List missing = new ArrayList(); + + try (Connection sqlConnection = connection.getConnection()) { + for (ColumnDefinition column : expectedColumns) { + String columnName = column.getName(); + long count = countExistingColumn( + sqlConnection, + "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = ? AND table_schema = ? AND column_name = ?", + tableName, + "public", + columnName + ); + if (count == 0) { + missing.add(column); + } + } + } catch (SQLException exception) { + if (logger != null) { + logger.info("Failed to check PostgreSQL table info for '" + tableName + "': " + exception.getMessage()); + } + throw new DatabaseException("missing-columns", tableName, exception); + } + + return missing; + } +} diff --git a/src/main/java/fr/maxlego08/sarah/dialect/SqlDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/SqlDialect.java new file mode 100644 index 0000000..6ede61e --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/dialect/SqlDialect.java @@ -0,0 +1,38 @@ +package fr.maxlego08.sarah.dialect; + +import fr.maxlego08.sarah.DatabaseConfiguration; +import fr.maxlego08.sarah.DatabaseConnection; +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.logger.Logger; + +import java.util.List; + +public interface SqlDialect { + + String quoteIdentifier(String name); + + String qualifyIdentifier(String prefix, String column); + + String driverClassName(); + + String jdbcUrl(DatabaseConfiguration configuration); + + String createTableSuffix(); + + String columnType(ColumnDefinition column); + + String autoIncrementKeyword(ColumnDefinition column); + + String enumColumnType(ColumnDefinition column); + + String updatedAtDefaultValue(); + + String upsertConflictClause(Schema schema); + + String upsertUpdateExpression(String quotedColumn, boolean batch); + + boolean usesUpsertUpdateParameters(boolean batch); + + List missingColumns(DatabaseConnection connection, Logger logger, String tableName, List expectedColumns); +} diff --git a/src/main/java/fr/maxlego08/sarah/dialect/SqlDialects.java b/src/main/java/fr/maxlego08/sarah/dialect/SqlDialects.java new file mode 100644 index 0000000..7bceb7e --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/dialect/SqlDialects.java @@ -0,0 +1,33 @@ +package fr.maxlego08.sarah.dialect; + +import fr.maxlego08.sarah.database.DatabaseType; + +public final class SqlDialects { + + private static final SqlDialect MYSQL = new MySqlDialect(); + private static final SqlDialect MARIADB = new MariaDbDialect(); + private static final SqlDialect SQLITE = new SqliteDialect(); + private static final SqlDialect POSTGRESQL = new PostgreSqlDialect(); + + private SqlDialects() { + } + + public static SqlDialect from(DatabaseType databaseType) { + if (databaseType == null) { + throw new IllegalArgumentException("Database type cannot be null"); + } + + switch (databaseType) { + case MYSQL: + return MYSQL; + case MARIADB: + return MARIADB; + case SQLITE: + return SQLITE; + case POSTGRESQL: + return POSTGRESQL; + default: + throw new IllegalArgumentException("Unsupported database type: " + databaseType); + } + } +} diff --git a/src/main/java/fr/maxlego08/sarah/dialect/SqliteDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/SqliteDialect.java new file mode 100644 index 0000000..1afa31d --- /dev/null +++ b/src/main/java/fr/maxlego08/sarah/dialect/SqliteDialect.java @@ -0,0 +1,68 @@ +package fr.maxlego08.sarah.dialect; + +import fr.maxlego08.sarah.DatabaseConfiguration; +import fr.maxlego08.sarah.DatabaseConnection; +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.exceptions.DatabaseException; +import fr.maxlego08.sarah.logger.Logger; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +public class SqliteDialect extends AbstractSqlDialect { + + public SqliteDialect() { + super("`"); + } + + @Override + public String driverClassName() { + return "org.sqlite.JDBC"; + } + + @Override + public String jdbcUrl(DatabaseConfiguration configuration) { + throw new UnsupportedOperationException("SQLite JDBC URL is file-based and resolved by SqliteConnection"); + } + + @Override + public String autoIncrementKeyword(ColumnDefinition column) { + return ""; + } + + @Override + protected boolean useIntegerTypeForAutoIncrementPrimaryKey(ColumnDefinition column) { + return column != null && column.isAutoIncrement(); + } + + @Override + public List missingColumns(DatabaseConnection connection, Logger logger, String tableName, List expectedColumns) { + Objects.requireNonNull(connection, "connection"); + Objects.requireNonNull(tableName, "tableName"); + Objects.requireNonNull(expectedColumns, "expectedColumns"); + + Set existingColumns = new HashSet(); + + String query = String.format("PRAGMA table_info(%s)", tableName); + try (Connection sqlConnection = connection.getConnection(); + PreparedStatement statement = sqlConnection.prepareStatement(query); + ResultSet resultSet = statement.executeQuery()) { + while (resultSet.next()) { + existingColumns.add(resultSet.getString("name")); + } + } catch (SQLException exception) { + if (logger != null) { + logger.info("Failed to check SQLite table info for '" + tableName + "': " + exception.getMessage()); + } + throw new DatabaseException("missing-columns", tableName, exception); + } + + return missingColumnsFromExistingNames(expectedColumns, existingColumns, false); + } +} diff --git a/src/main/java/fr/maxlego08/sarah/requests/AlterRequest.java b/src/main/java/fr/maxlego08/sarah/requests/AlterRequest.java index a25525d..da4506c 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/AlterRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/AlterRequest.java @@ -3,8 +3,11 @@ import fr.maxlego08.sarah.DatabaseConfiguration; import fr.maxlego08.sarah.DatabaseConnection; import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.conditions.ForeignKeyDefinition; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -24,22 +27,27 @@ public AlterRequest(Schema schema) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); StringBuilder alterTableSQL = new StringBuilder("ALTER TABLE "); - alterTableSQL.append(this.schema.getTableName()).append(" "); + alterTableSQL.append(dialect.quoteIdentifier(this.schema.getTableName())).append(" "); List columnSQLs = new ArrayList<>(); for (ColumnDefinition column : this.schema.getColumns()) { - columnSQLs.add("ADD COLUMN " + column.build(databaseConfiguration)); + columnSQLs.add("ADD COLUMN " + column.build(databaseConfiguration, dialect)); } alterTableSQL.append(String.join(", ", columnSQLs)); if (!this.schema.getPrimaryKeys().isEmpty()) { - alterTableSQL.append(", PRIMARY KEY (").append(String.join(", ", this.schema.getPrimaryKeys())).append(")"); + List primaryKeys = new ArrayList(); + for (String primaryKey : this.schema.getPrimaryKeys()) { + primaryKeys.add(dialect.quoteIdentifier(stripWrappingQuotes(primaryKey))); + } + alterTableSQL.append(", PRIMARY KEY (").append(String.join(", ", primaryKeys)).append(")"); } - for (String fk : this.schema.getForeignKeys()) { - alterTableSQL.append(", ADD ").append(fk); + for (ForeignKeyDefinition foreignKey : this.schema.getForeignKeys()) { + alterTableSQL.append(", ADD ").append(foreignKey.render(dialect)); } String finalQuery = databaseConfiguration.replacePrefix(alterTableSQL.toString()); @@ -56,4 +64,19 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration throw new DatabaseException("alter", this.schema.getTableName(), exception); } } + + private boolean isQuoted(String identifier) { + return (identifier.startsWith("`") && identifier.endsWith("`")) || + (identifier.startsWith("\"") && identifier.endsWith("\"")); + } + + private String stripWrappingQuotes(String identifier) { + if (identifier == null || identifier.length() < 2) { + return identifier; + } + if (isQuoted(identifier)) { + return identifier.substring(1, identifier.length() - 1); + } + return identifier; + } } diff --git a/src/main/java/fr/maxlego08/sarah/requests/CreateIndexRequest.java b/src/main/java/fr/maxlego08/sarah/requests/CreateIndexRequest.java index 68557bf..4cfc75f 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/CreateIndexRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/CreateIndexRequest.java @@ -5,6 +5,8 @@ import fr.maxlego08.sarah.conditions.ColumnDefinition; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -22,17 +24,18 @@ public CreateIndexRequest(Schema schema) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); StringBuilder indexTableSQL = new StringBuilder("CREATE INDEX "); String tableName = schema.getTableName(); ColumnDefinition column = schema.getColumns().get(0); String indexName = "idx_" + tableName + "_" + column.getName(); - indexTableSQL.append(indexName); + indexTableSQL.append(dialect.quoteIdentifier(indexName)); indexTableSQL.append(" ON "); - indexTableSQL.append(String.format("`%s`", tableName)); + indexTableSQL.append(dialect.quoteIdentifier(tableName)); indexTableSQL.append(" ("); - indexTableSQL.append(column.getSafeName()); + indexTableSQL.append(dialect.quoteIdentifier(column.getName())); indexTableSQL.append(" )"); String finalQuery = databaseConfiguration.replacePrefix(indexTableSQL.toString()); diff --git a/src/main/java/fr/maxlego08/sarah/requests/CreateRequest.java b/src/main/java/fr/maxlego08/sarah/requests/CreateRequest.java index 7c59854..23c82e9 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/CreateRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/CreateRequest.java @@ -3,9 +3,11 @@ import fr.maxlego08.sarah.DatabaseConfiguration; import fr.maxlego08.sarah.DatabaseConnection; import fr.maxlego08.sarah.conditions.ColumnDefinition; -import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.conditions.ForeignKeyDefinition; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -26,14 +28,15 @@ public CreateRequest(Schema schema) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); StringBuilder createTableSQL = new StringBuilder("CREATE TABLE IF NOT EXISTS "); - createTableSQL.append(this.schema.getTableName()).append(" ("); + createTableSQL.append(dialect.quoteIdentifier(this.schema.getTableName())).append(" ("); List columnSQLs = new ArrayList<>(); boolean hasInlinePrimaryKey = false; for (ColumnDefinition column : this.schema.getColumns()) { - columnSQLs.add(column.build(databaseConfiguration)); + columnSQLs.add(column.build(databaseConfiguration, dialect)); // Check if this column has inline PRIMARY KEY (SQLite autoincrement) if (column.isAutoIncrement() && column.isPrimaryKey() && databaseConfiguration.getDatabaseType() == fr.maxlego08.sarah.database.DatabaseType.SQLITE) { @@ -44,19 +47,20 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration // Only add separate PRIMARY KEY clause if there's no inline PRIMARY KEY if (!this.schema.getPrimaryKeys().isEmpty() && !hasInlinePrimaryKey) { - createTableSQL.append(", PRIMARY KEY (").append(String.join(", ", this.schema.getPrimaryKeys())).append(")"); + List primaryKeys = new ArrayList<>(); + for (String primaryKey : this.schema.getPrimaryKeys()) { + primaryKeys.add(dialect.quoteIdentifier(normalizeIdentifier(primaryKey))); + } + createTableSQL.append(", PRIMARY KEY (").append(String.join(", ", primaryKeys)).append(")"); } - for (String fk : this.schema.getForeignKeys()) { - createTableSQL.append(", ").append(fk); + for (ForeignKeyDefinition fk : this.schema.getForeignKeys()) { + createTableSQL.append(", ").append(fk.render(dialect)); } createTableSQL.append(")"); - // Force InnoDB engine for MySQL/MariaDB to ensure foreign key support - if (databaseConfiguration.getDatabaseType() != DatabaseType.SQLITE) { - createTableSQL.append(" ENGINE=InnoDB DEFAULT CHARSET=utf8mb4"); - } + createTableSQL.append(dialect.createTableSuffix()); String finalQuery = databaseConfiguration.replacePrefix(createTableSQL.toString()); if (databaseConfiguration.isDebug()) { @@ -71,4 +75,14 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration throw new DatabaseException("create", this.schema.getTableName(), exception); } } + + private String normalizeIdentifier(String identifier) { + if (identifier == null || identifier.length() < 2) { + return identifier; + } + if ((identifier.startsWith("`") && identifier.endsWith("`")) || (identifier.startsWith("\"") && identifier.endsWith("\""))) { + return identifier.substring(1, identifier.length() - 1); + } + return identifier; + } } diff --git a/src/main/java/fr/maxlego08/sarah/requests/DeleteRequest.java b/src/main/java/fr/maxlego08/sarah/requests/DeleteRequest.java index 1b9639c..95f9bd1 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/DeleteRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/DeleteRequest.java @@ -4,6 +4,8 @@ import fr.maxlego08.sarah.DatabaseConnection; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -21,8 +23,9 @@ public DeleteRequest(Schema schemaBuilder) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { - StringBuilder sql = new StringBuilder("DELETE FROM ").append(schemaBuilder.getTableName()); - schemaBuilder.whereConditions(sql); + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); + StringBuilder sql = new StringBuilder("DELETE FROM ").append(dialect.quoteIdentifier(schemaBuilder.getTableName())); + schemaBuilder.whereConditions(sql, dialect); String finalQuery = databaseConfiguration.replacePrefix(sql.toString()); if (databaseConfiguration.isDebug()) { diff --git a/src/main/java/fr/maxlego08/sarah/requests/DropTableRequest.java b/src/main/java/fr/maxlego08/sarah/requests/DropTableRequest.java index 1068d19..2a2050b 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/DropTableRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/DropTableRequest.java @@ -4,6 +4,8 @@ import fr.maxlego08.sarah.DatabaseConnection; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.logger.Logger; import java.sql.Connection; @@ -26,7 +28,8 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration return -1; } - String finalQuery = databaseConfiguration.replacePrefix("DROP TABLE IF EXISTS " + tableName); + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); + String finalQuery = databaseConfiguration.replacePrefix("DROP TABLE IF EXISTS " + dialect.quoteIdentifier(tableName)); if (databaseConfiguration.isDebug()) { logger.info("Executing SQL: " + finalQuery); } diff --git a/src/main/java/fr/maxlego08/sarah/requests/InsertAllRequest.java b/src/main/java/fr/maxlego08/sarah/requests/InsertAllRequest.java index 0f82ea4..31097f4 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/InsertAllRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/InsertAllRequest.java @@ -5,12 +5,16 @@ import fr.maxlego08.sarah.conditions.ColumnDefinition; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; public class InsertAllRequest implements Executor { @@ -24,29 +28,24 @@ public InsertAllRequest(Schema schema, String toTableName) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); - StringBuilder insertBuilder = new StringBuilder("INSERT INTO " + this.toTableName + " ("); - StringBuilder columns = new StringBuilder(); + StringBuilder insertBuilder = new StringBuilder("INSERT INTO ") + .append(dialect.quoteIdentifier(this.toTableName)) + .append(" ("); + List quotedColumns = new ArrayList(); - int columnIndex = 0; for (ColumnDefinition columnDefinition : this.schema.getColumns()) { - // Skip auto-increment columns if (columnDefinition.isAutoIncrement()) { continue; } - - if (columnIndex > 0) { - columns.append(","); - } - columns.append(columnDefinition.getSafeName()); - columnIndex++; + quotedColumns.add(dialect.quoteIdentifier(columnDefinition.getName())); } - insertBuilder.append(columns).append(") "); - - insertBuilder.append("SELECT ").append(columns); - insertBuilder.append(" FROM "); - insertBuilder.append(this.schema.getTableName()); + String columnsSql = String.join(", ", quotedColumns); + insertBuilder.append(columnsSql).append(") "); + insertBuilder.append("SELECT ").append(columnsSql); + insertBuilder.append(" FROM ").append(dialect.quoteIdentifier(this.schema.getTableName())); String insertQuery = databaseConfiguration.replacePrefix(insertBuilder.toString()); diff --git a/src/main/java/fr/maxlego08/sarah/requests/InsertBatchRequest.java b/src/main/java/fr/maxlego08/sarah/requests/InsertBatchRequest.java index 87f31d9..20c7f31 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/InsertBatchRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/InsertBatchRequest.java @@ -5,6 +5,8 @@ import fr.maxlego08.sarah.conditions.ColumnDefinition; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -30,8 +32,9 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration return 0; } + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); Schema firstSchema = schemas.get(0); - StringBuilder insertQuery = new StringBuilder("INSERT INTO " + firstSchema.getTableName() + " ("); + StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteIdentifier(firstSchema.getTableName()) + " ("); StringBuilder valuesQuery = new StringBuilder("VALUES "); List values = new ArrayList<>(); @@ -41,7 +44,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration // Skip auto-increment columns for (ColumnDefinition column : firstSchema.getColumns()) { if (!column.isAutoIncrement()) { - columnNames.add(column.getSafeName()); + columnNames.add(dialect.quoteIdentifier(column.getName())); } } diff --git a/src/main/java/fr/maxlego08/sarah/requests/InsertRequest.java b/src/main/java/fr/maxlego08/sarah/requests/InsertRequest.java index b22b015..5a3f0f4 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/InsertRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/InsertRequest.java @@ -5,6 +5,8 @@ import fr.maxlego08.sarah.conditions.ColumnDefinition; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -26,8 +28,9 @@ public InsertRequest(Schema schema) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); - StringBuilder insertQuery = new StringBuilder("INSERT INTO " + this.schema.getTableName() + " ("); + StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteIdentifier(this.schema.getTableName()) + " ("); StringBuilder valuesQuery = new StringBuilder("VALUES ("); List values = new ArrayList<>(); @@ -38,7 +41,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration if (columnDefinition.isAutoIncrement()) { continue; } - insertQuery.append(paramIndex > 0 ? ", " : "").append(columnDefinition.getSafeName()); + insertQuery.append(paramIndex > 0 ? ", " : "").append(dialect.quoteIdentifier(columnDefinition.getName())); valuesQuery.append(paramIndex > 0 ? ", " : "").append("?"); values.add(columnDefinition.getObject()); paramIndex++; diff --git a/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java b/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java index 48585bc..a8d8e61 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java +++ b/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java @@ -4,6 +4,8 @@ import fr.maxlego08.sarah.DatabaseConnection; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -21,11 +23,12 @@ public RenameExecutor(Schema schema) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); StringBuilder alterTableSQL = new StringBuilder("ALTER TABLE "); - alterTableSQL.append(this.schema.getTableName()); + alterTableSQL.append(dialect.quoteIdentifier(this.schema.getTableName())); alterTableSQL.append(" RENAME TO "); - alterTableSQL.append(this.schema.getNewTableName()); + alterTableSQL.append(dialect.quoteIdentifier(this.schema.getNewTableName())); String finalQuery = databaseConfiguration.replacePrefix(alterTableSQL.toString()); if (databaseConfiguration.isDebug()) { diff --git a/src/main/java/fr/maxlego08/sarah/requests/UpdateBatchRequest.java b/src/main/java/fr/maxlego08/sarah/requests/UpdateBatchRequest.java index fbc5cff..ed039f7 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/UpdateBatchRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/UpdateBatchRequest.java @@ -6,6 +6,8 @@ import fr.maxlego08.sarah.conditions.JoinCondition; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -26,12 +28,13 @@ public UpdateBatchRequest(List schemas) { public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { if (schemas.isEmpty()) return 0; + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); Schema firstSchema = schemas.get(0); - StringBuilder updateQuery = new StringBuilder("UPDATE " + firstSchema.getTableName()); + StringBuilder updateQuery = new StringBuilder("UPDATE " + dialect.quoteIdentifier(firstSchema.getTableName())); if (!firstSchema.getJoinConditions().isEmpty()) { for (JoinCondition join : firstSchema.getJoinConditions()) { - updateQuery.append(" ").append(join.getJoinClause()); + updateQuery.append(" ").append(join.getJoinClause(dialect)); } } @@ -40,10 +43,10 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration List columns = firstSchema.getColumns(); for (int i = 0; i < columns.size(); i++) { ColumnDefinition columnDefinition = columns.get(i); - updateQuery.append(i > 0 ? ", " : "").append(columnDefinition.getSafeName()).append(" = ?"); + updateQuery.append(i > 0 ? ", " : "").append(dialect.quoteIdentifier(columnDefinition.getName())).append(" = ?"); } - firstSchema.whereConditions(updateQuery); + firstSchema.whereConditions(updateQuery, dialect); String updateSql = databaseConfiguration.replacePrefix(updateQuery.toString()); if (databaseConfiguration.isDebug()) { diff --git a/src/main/java/fr/maxlego08/sarah/requests/UpdateRequest.java b/src/main/java/fr/maxlego08/sarah/requests/UpdateRequest.java index e74b244..d2fd0f8 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/UpdateRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/UpdateRequest.java @@ -6,6 +6,8 @@ import fr.maxlego08.sarah.conditions.JoinCondition; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -25,12 +27,13 @@ public UpdateRequest(Schema schema) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); - StringBuilder updateQuery = new StringBuilder("UPDATE " + this.schema.getTableName()); + StringBuilder updateQuery = new StringBuilder("UPDATE " + dialect.quoteIdentifier(this.schema.getTableName())); if (!this.schema.getJoinConditions().isEmpty()) { for (JoinCondition join : this.schema.getJoinConditions()) { - updateQuery.append(" ").append(join.getJoinClause()); + updateQuery.append(" ").append(join.getJoinClause(dialect)); } } @@ -40,11 +43,11 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration for (int i = 0; i < this.schema.getColumns().size(); i++) { ColumnDefinition columnDefinition = this.schema.getColumns().get(i); - updateQuery.append(i > 0 ? ", " : "").append(columnDefinition.getSafeName()).append(" = ?"); + updateQuery.append(i > 0 ? ", " : "").append(dialect.quoteIdentifier(columnDefinition.getName())).append(" = ?"); values.add(columnDefinition.getObject()); } - this.schema.whereConditions(updateQuery); + this.schema.whereConditions(updateQuery, dialect); String updateSql = databaseConfiguration.replacePrefix(updateQuery.toString()); if (databaseConfiguration.isDebug()) { diff --git a/src/main/java/fr/maxlego08/sarah/requests/UpsertBatchRequest.java b/src/main/java/fr/maxlego08/sarah/requests/UpsertBatchRequest.java index de65f72..c42f264 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/UpsertBatchRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/UpsertBatchRequest.java @@ -3,9 +3,10 @@ import fr.maxlego08.sarah.DatabaseConfiguration; import fr.maxlego08.sarah.DatabaseConnection; import fr.maxlego08.sarah.conditions.ColumnDefinition; -import fr.maxlego08.sarah.database.DatabaseType; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -29,20 +30,19 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration return 0; } - DatabaseType databaseType = databaseConfiguration.getDatabaseType(); + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); Schema firstSchema = schemas.get(0); - StringBuilder insertQuery = new StringBuilder("INSERT INTO " + firstSchema.getTableName() + " ("); + StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteIdentifier(firstSchema.getTableName()) + " ("); StringBuilder valuesQuery = new StringBuilder("VALUES "); - StringBuilder onUpdateQuery = new StringBuilder(); List values = new ArrayList<>(); List placeholders = new ArrayList<>(); - List insertColumnNames = new ArrayList<>(); + List insertColumnNames = new ArrayList(); // Build column list - skip auto-increment columns for (ColumnDefinition column : firstSchema.getColumns()) { if (!column.isAutoIncrement()) { - insertColumnNames.add(column.getSafeName()); + insertColumnNames.add(dialect.quoteIdentifier(column.getName())); } } @@ -62,29 +62,18 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration valuesQuery.append(String.join(", ", placeholders)); - if (databaseType == DatabaseType.SQLITE) { - StringBuilder onConflictQuery = new StringBuilder(" ON CONFLICT ("); - List primaryKeys = firstSchema.getPrimaryKeys(); - onConflictQuery.append(String.join(", ", primaryKeys)).append(") DO UPDATE SET "); - - // Skip auto-increment columns in UPDATE as well - for (int i = 0; i < insertColumnNames.size(); i++) { - if (i > 0) onUpdateQuery.append(", "); - onUpdateQuery.append(insertColumnNames.get(i)).append(" = excluded.").append(insertColumnNames.get(i)); - } - - insertQuery.append(valuesQuery).append(onConflictQuery).append(onUpdateQuery); - } else { - onUpdateQuery.append(" ON DUPLICATE KEY UPDATE "); - // Skip auto-increment columns in UPDATE as well - for (int i = 0; i < insertColumnNames.size(); i++) { - if (i > 0) onUpdateQuery.append(", "); - onUpdateQuery.append(insertColumnNames.get(i)).append(" = VALUES(").append(insertColumnNames.get(i)).append(")"); + StringBuilder onUpdateQuery = new StringBuilder(); + for (int i = 0; i < insertColumnNames.size(); i++) { + if (i > 0) { + onUpdateQuery.append(", "); } - - insertQuery.append(valuesQuery).append(onUpdateQuery); + onUpdateQuery.append(dialect.upsertUpdateExpression(insertColumnNames.get(i), true)); } + insertQuery.append(valuesQuery) + .append(dialect.upsertConflictClause(firstSchema)) + .append(onUpdateQuery); + String finalQuery = databaseConfiguration.replacePrefix(insertQuery.toString()); if (databaseConfiguration.isDebug()) { logger.info("Executing SQL: " + finalQuery); diff --git a/src/main/java/fr/maxlego08/sarah/requests/UpsertRequest.java b/src/main/java/fr/maxlego08/sarah/requests/UpsertRequest.java index a571f8d..c57677d 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/UpsertRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/UpsertRequest.java @@ -6,6 +6,8 @@ import fr.maxlego08.sarah.database.DatabaseType; import fr.maxlego08.sarah.database.Executor; import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; import fr.maxlego08.sarah.exceptions.DatabaseException; import fr.maxlego08.sarah.logger.Logger; @@ -25,8 +27,8 @@ public UpsertRequest(Schema schema) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { - DatabaseType databaseType = databaseConfiguration.getDatabaseType(); - StringBuilder insertQuery = new StringBuilder("INSERT INTO " + this.schema.getTableName() + " ("); + SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); + StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteIdentifier(this.schema.getTableName()) + " ("); StringBuilder valuesQuery = new StringBuilder("VALUES ("); StringBuilder onUpdateQuery = new StringBuilder(); @@ -38,21 +40,16 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration for (ColumnDefinition columnDefinition : this.schema.getColumns()) { // Skip auto-increment columns in INSERT part if (!columnDefinition.isAutoIncrement()) { - insertQuery.append(insertIndex > 0 ? ", " : "").append(columnDefinition.getSafeName()); + String quotedColumn = dialect.quoteIdentifier(columnDefinition.getName()); + insertQuery.append(insertIndex > 0 ? ", " : "").append(quotedColumn); valuesQuery.append(insertIndex > 0 ? ", " : "").append("?"); insertValues.add(columnDefinition.getObject()); insertIndex++; - } - - // Skip auto-increment columns in UPDATE part as well - if (!columnDefinition.isAutoIncrement()) { if (updateIndex > 0) { onUpdateQuery.append(", "); } - if (databaseType == DatabaseType.SQLITE) { - onUpdateQuery.append(columnDefinition.getSafeName()).append(" = excluded.").append(columnDefinition.getSafeName()); - } else { - onUpdateQuery.append(columnDefinition.getSafeName()).append(" = ?"); + onUpdateQuery.append(dialect.upsertUpdateExpression(quotedColumn, false)); + if (dialect.usesUpsertUpdateParameters(false)) { updateValues.add(columnDefinition.getObject()); } updateIndex++; @@ -62,21 +59,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration insertQuery.append(") "); valuesQuery.append(")"); - String upsertQuery; - - if (databaseType == DatabaseType.SQLITE) { - StringBuilder onConflictQuery = new StringBuilder(" ON CONFLICT ("); - List nonAutoIncrementPrimaryKeys = getNonAutoIncrementPrimaryKeys(); - - for (int i = 0; i < nonAutoIncrementPrimaryKeys.size(); i++) { - onConflictQuery.append(i > 0 ? ", " : "").append(nonAutoIncrementPrimaryKeys.get(i)); - } - onConflictQuery.append(") DO UPDATE SET "); - upsertQuery = insertQuery + valuesQuery.toString() + onConflictQuery + onUpdateQuery; - } else { - onUpdateQuery.insert(0, " ON DUPLICATE KEY UPDATE "); - upsertQuery = insertQuery + valuesQuery.toString() + onUpdateQuery; - } + String upsertQuery = insertQuery + valuesQuery.toString() + dialect.upsertConflictClause(schema) + onUpdateQuery; String finalQuery = databaseConfiguration.replacePrefix(upsertQuery); if (databaseConfiguration.isDebug()) { @@ -93,8 +76,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration preparedStatement.setObject(index++, value); } - // Setting values for UPDATE part (only if not SQLite, since SQLite uses "excluded" keyword) - if (databaseType != DatabaseType.SQLITE) { + if (dialect.usesUpsertUpdateParameters(false)) { for (Object value : updateValues) { preparedStatement.setObject(index++, value); } @@ -107,39 +89,4 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration } } - - private List getNonAutoIncrementPrimaryKeys() { - List conflictColumns = new ArrayList<>(); - - // First, try to find primary keys that are not auto-increment - List primaryKeys = schema.getPrimaryKeys(); - for (String primaryKey : primaryKeys) { - boolean isAutoIncrement = false; - for (ColumnDefinition col : schema.getColumns()) { - if (col.getSafeName().equals(primaryKey) && col.isAutoIncrement()) { - isAutoIncrement = true; - break; - } - } - if (!isAutoIncrement) { - conflictColumns.add(primaryKey); - } - } - - // If no non-auto-increment primary keys exist, look for UNIQUE columns - if (conflictColumns.isEmpty()) { - for (ColumnDefinition col : schema.getColumns()) { - // Check if column is unique and not auto-increment - if (col.isUnique() && !col.isAutoIncrement()) { - conflictColumns.add(col.getSafeName()); - } - } - } - - // If still no conflict columns found, throw error - if (conflictColumns.isEmpty()) { - throw new IllegalStateException("UPSERT requires at least one non-auto-increment primary key or unique constraint for SQLite"); - } - return conflictColumns; - } } diff --git a/src/test/java/fr/maxlego08/sarah/DialectTest.java b/src/test/java/fr/maxlego08/sarah/DialectTest.java new file mode 100644 index 0000000..1d0b133 --- /dev/null +++ b/src/test/java/fr/maxlego08/sarah/DialectTest.java @@ -0,0 +1,92 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialects; +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DialectTest { + + @Test + public void testDialectIdentifierQuoting() { + assertEquals("`user_name`", SqlDialects.from(DatabaseType.MYSQL).quoteIdentifier("user_name")); + assertEquals("`user_name`", SqlDialects.from(DatabaseType.MARIADB).quoteIdentifier("user_name")); + assertEquals("`user_name`", SqlDialects.from(DatabaseType.SQLITE).quoteIdentifier("user_name")); + assertEquals("\"user_name\"", SqlDialects.from(DatabaseType.POSTGRESQL).quoteIdentifier("user_name")); + } + + @Test + public void testQualifiedIdentifierQuoting() { + assertEquals("u.`name`", SqlDialects.from(DatabaseType.MYSQL).qualifyIdentifier("u", "name")); + assertEquals("u.\"name\"", SqlDialects.from(DatabaseType.POSTGRESQL).qualifyIdentifier("u", "name")); + assertEquals("\"name\"", SqlDialects.from(DatabaseType.POSTGRESQL).qualifyIdentifier(null, "name")); + } + + @Test + public void testNullDatabaseTypeFails() { + assertThrows(IllegalArgumentException.class, () -> SqlDialects.from(null)); + } + + @Test + public void testCreateTableSuffixes() { + assertEquals(" ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", SqlDialects.from(DatabaseType.MYSQL).createTableSuffix()); + assertEquals(" ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", SqlDialects.from(DatabaseType.MARIADB).createTableSuffix()); + assertEquals("", SqlDialects.from(DatabaseType.SQLITE).createTableSuffix()); + assertEquals("", SqlDialects.from(DatabaseType.POSTGRESQL).createTableSuffix()); + } + + @Test + public void testColumnTypesAndAutoIncrementKeywords() { + ColumnDefinition integerAutoIncrement = new ColumnDefinition("id", "INT").setAutoIncrement(true).setPrimaryKey(true); + ColumnDefinition decimal = new ColumnDefinition("amount", "DECIMAL").setLength(10).setDecimal(2); + ColumnDefinition enumColumn = new ColumnDefinition("state", "ENUM").setEnumValues("OPEN", "CLOSED"); + + assertEquals("AUTO_INCREMENT", SqlDialects.from(DatabaseType.MYSQL).autoIncrementKeyword(integerAutoIncrement)); + assertEquals("AUTO_INCREMENT", SqlDialects.from(DatabaseType.MARIADB).autoIncrementKeyword(integerAutoIncrement)); + assertEquals("", SqlDialects.from(DatabaseType.SQLITE).autoIncrementKeyword(integerAutoIncrement)); + assertEquals("GENERATED BY DEFAULT AS IDENTITY", SqlDialects.from(DatabaseType.POSTGRESQL).autoIncrementKeyword(integerAutoIncrement)); + + assertEquals("DECIMAL(10,2)", SqlDialects.from(DatabaseType.MYSQL).columnType(decimal)); + assertEquals("ENUM('OPEN', 'CLOSED')", SqlDialects.from(DatabaseType.MYSQL).enumColumnType(enumColumn)); + assertEquals("TEXT", SqlDialects.from(DatabaseType.SQLITE).enumColumnType(enumColumn)); + assertEquals("TEXT", SqlDialects.from(DatabaseType.POSTGRESQL).enumColumnType(enumColumn)); + assertEquals("INTEGER", SqlDialects.from(DatabaseType.SQLITE).columnType(integerAutoIncrement)); + } + + @Test + public void testUpdatedAtDefaults() { + assertEquals("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP", SqlDialects.from(DatabaseType.MYSQL).updatedAtDefaultValue()); + assertEquals("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP", SqlDialects.from(DatabaseType.MARIADB).updatedAtDefaultValue()); + assertEquals("CURRENT_TIMESTAMP", SqlDialects.from(DatabaseType.SQLITE).updatedAtDefaultValue()); + assertEquals("CURRENT_TIMESTAMP", SqlDialects.from(DatabaseType.POSTGRESQL).updatedAtDefaultValue()); + } + + @Test + public void testUpsertStrategies() { + Schema schema = SchemaBuilder.upsert("users", table -> { + table.autoIncrementBigInt("id"); + table.string("username", "sarah").unique(); + table.string("email", "sarah@example.com"); + }); + + assertEquals(" ON DUPLICATE KEY UPDATE ", SqlDialects.from(DatabaseType.MYSQL).upsertConflictClause(schema)); + assertEquals(" ON DUPLICATE KEY UPDATE ", SqlDialects.from(DatabaseType.MARIADB).upsertConflictClause(schema)); + assertEquals(" ON CONFLICT (`username`) DO UPDATE SET ", SqlDialects.from(DatabaseType.SQLITE).upsertConflictClause(schema)); + assertEquals(" ON CONFLICT (\"username\") DO UPDATE SET ", SqlDialects.from(DatabaseType.POSTGRESQL).upsertConflictClause(schema)); + + assertEquals("`email` = ?", SqlDialects.from(DatabaseType.MYSQL).upsertUpdateExpression("`email`", false)); + assertEquals("`email` = VALUES(`email`)", SqlDialects.from(DatabaseType.MYSQL).upsertUpdateExpression("`email`", true)); + assertEquals("\"email\" = excluded.\"email\"", SqlDialects.from(DatabaseType.POSTGRESQL).upsertUpdateExpression("\"email\"", false)); + + assertTrue(SqlDialects.from(DatabaseType.MYSQL).usesUpsertUpdateParameters(false)); + assertFalse(SqlDialects.from(DatabaseType.MYSQL).usesUpsertUpdateParameters(true)); + assertFalse(SqlDialects.from(DatabaseType.POSTGRESQL).usesUpsertUpdateParameters(false)); + assertFalse(SqlDialects.from(DatabaseType.SQLITE).usesUpsertUpdateParameters(true)); + } +} diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java new file mode 100644 index 0000000..a0f41c9 --- /dev/null +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java @@ -0,0 +1,63 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.conditions.JoinCondition; +import fr.maxlego08.sarah.conditions.OrderByCondition; +import fr.maxlego08.sarah.conditions.SelectCondition; +import fr.maxlego08.sarah.conditions.WhereCondition; +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class PostgreSqlConditionRenderingTest { + + private final SqlDialect postgres = SqlDialects.from(DatabaseType.POSTGRESQL); + + @Test + public void testWhereConditionUsesDialectQuoting() { + WhereCondition condition = new WhereCondition("u", "name", "=", "Sarah"); + assertEquals("u.\"name\" = ?", condition.getCondition(postgres)); + } + + @Test + public void testWhereInUsesDialectQuoting() { + WhereCondition condition = new WhereCondition("u", "id", Arrays.asList("1", "2", "3")); + assertEquals("u.\"id\" IN (?,?,?)", condition.getCondition(postgres)); + } + + @Test + public void testSelectConditionUsesDialectQuoting() { + SelectCondition select = new SelectCondition("u", "name", "username", false, null); + assertEquals("u.\"name\" as username", select.getSelectColumn(postgres)); + } + + @Test + public void testSelectCoalesceUsesDialectQuoting() { + SelectCondition select = new SelectCondition("u", "name", "username", true, "'N/A'"); + assertEquals("COALESCE(u.\"name\", 'N/A') as username", select.getSelectColumn(postgres)); + } + + @Test + public void testJoinConditionUsesDialectQuoting() { + JoinCondition join = new JoinCondition( + JoinCondition.JoinType.LEFT, + "orders", + "o", + "user_id", + "users", + "id", + null + ); + assertEquals("LEFT JOIN orders AS o ON o.\"user_id\" = users.\"id\"", join.getJoinClause(postgres)); + } + + @Test + public void testOrderByConditionUsesDialectQuoting() { + OrderByCondition orderByCondition = new OrderByCondition(null, "created_at", true); + assertEquals("ORDER BY \"created_at\" DESC", orderByCondition.getOrderByClause(postgres)); + } +} diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlConnectionTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlConnectionTest.java new file mode 100644 index 0000000..e906d9b --- /dev/null +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlConnectionTest.java @@ -0,0 +1,39 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialects; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class PostgreSqlConnectionTest { + + @Test + public void testPostgreSqlConfigurationFactory() { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("sarah", "secret", "localhost", "sarah_db"); + + assertEquals(DatabaseType.POSTGRESQL, configuration.getDatabaseType()); + assertEquals(5432, configuration.getPort()); + assertEquals("sarah", configuration.getUser()); + assertEquals("secret", configuration.getPassword()); + assertEquals("localhost", configuration.getHost()); + assertEquals("sarah_db", configuration.getDatabase()); + } + + @Test + public void testPostgreSqlJdbcUrlAndDriver() { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("sarah", "secret", 15432, "db.local", "sarah_db"); + + assertEquals("org.postgresql.Driver", SqlDialects.from(DatabaseType.POSTGRESQL).driverClassName()); + assertEquals("jdbc:postgresql://db.local:15432/sarah_db", SqlDialects.from(DatabaseType.POSTGRESQL).jdbcUrl(configuration)); + } + + @Test + public void testExistingJdbcUrlsStayStable() { + DatabaseConfiguration mysql = DatabaseConfiguration.create("u", "p", 3307, "host", "db"); + DatabaseConfiguration maria = DatabaseConfiguration.createMariaDb("u", "p", 3308, "host", "db"); + + assertEquals("jdbc:mysql://host:3307/db?allowMultiQueries=true", SqlDialects.from(DatabaseType.MYSQL).jdbcUrl(mysql)); + assertEquals("jdbc:mariadb://host:3308/db?allowMultiQueries=true", SqlDialects.from(DatabaseType.MARIADB).jdbcUrl(maria)); + } +} diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlCreateRequestTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlCreateRequestTest.java new file mode 100644 index 0000000..4ccdbdf --- /dev/null +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlCreateRequestTest.java @@ -0,0 +1,137 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; +import fr.maxlego08.sarah.logger.Logger; +import fr.maxlego08.sarah.requests.CreateRequest; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.sql.Connection; +import java.sql.PreparedStatement; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class PostgreSqlCreateRequestTest { + + private final SqlDialect postgres = SqlDialects.from(DatabaseType.POSTGRESQL); + private final DatabaseConfiguration config = new DatabaseConfiguration("", "u", "p", 5432, "localhost", "db", false, DatabaseType.POSTGRESQL); + + @Test + public void testPostgreSqlAutoIncrementBigIntColumn() { + ColumnDefinition column = new ColumnDefinition("id", "BIGINT").setAutoIncrement(true); + column.setPrimaryKey(true); + + assertEquals("\"id\" BIGINT GENERATED BY DEFAULT AS IDENTITY NOT NULL", column.build(config, postgres)); + } + + @Test + public void testPostgreSqlEnumFallsBackToText() { + ColumnDefinition column = new ColumnDefinition("status", "ENUM").setEnumValues("OPEN", "CLOSED"); + assertEquals("\"status\" TEXT NOT NULL", column.build(config, postgres)); + } + + @Test + public void testPostgreSqlCreateRequestRendersWithDialect() throws Exception { + final Recording recording = new Recording(); + + DatabaseConnection connection = new DatabaseConnection(config, new NoopLogger()) { + @Override + public Connection connectToDatabase() { + return createConnectionProxy(recording); + } + }; + + Schema schema = SchemaBuilder.create(null, "users", table -> { + table.autoIncrementBigInt("id"); + table.string("username", 50).unique(); + table.bigInt("group_id"); + table.foreignKey("groups", "id", true); + }); + + new CreateRequest(schema).execute(connection, config, new NoopLogger()); + + assertTrue(recording.query.startsWith("CREATE TABLE IF NOT EXISTS \"users\"")); + assertTrue(recording.query.contains("\"id\" BIGINT GENERATED BY DEFAULT AS IDENTITY NOT NULL")); + assertTrue(recording.query.contains("PRIMARY KEY (\"id\")")); + assertTrue(recording.query.contains("FOREIGN KEY (\"group_id\") REFERENCES \"groups\"(\"id\") ON DELETE CASCADE")); + assertTrue(recording.query.endsWith(")")); + } + + @Test + public void testPostgreSqlCreateTableHasNoMysqlSuffix() { + assertEquals("", postgres.createTableSuffix()); + } + + @Test + public void testMysqlCreateTableSuffixStaysStable() { + assertEquals(" ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", SqlDialects.from(DatabaseType.MYSQL).createTableSuffix()); + } + + private static Connection createConnectionProxy(final Recording recording) { + InvocationHandler handler = new InvocationHandler() { + @Override + public Object invoke(Object proxy, Method method, Object[] args) { + String methodName = method.getName(); + if ("prepareStatement".equals(methodName)) { + recording.query = (String) args[0]; + return createPreparedStatementProxy(); + } + if ("isClosed".equals(methodName)) { + return false; + } + if ("isValid".equals(methodName)) { + return true; + } + return defaultValue(method.getReturnType()); + } + }; + return (Connection) Proxy.newProxyInstance(Connection.class.getClassLoader(), new Class[]{Connection.class}, handler); + } + + private static PreparedStatement createPreparedStatementProxy() { + InvocationHandler handler = new InvocationHandler() { + @Override + public Object invoke(Object proxy, Method method, Object[] args) { + String methodName = method.getName(); + if ("execute".equals(methodName)) { + return true; + } + if ("getUpdateCount".equals(methodName)) { + return 0; + } + return defaultValue(method.getReturnType()); + } + }; + return (PreparedStatement) Proxy.newProxyInstance(PreparedStatement.class.getClassLoader(), new Class[]{PreparedStatement.class}, handler); + } + + private static Object defaultValue(Class returnType) { + if (returnType == Boolean.TYPE) { + return false; + } + if (returnType == Integer.TYPE) { + return 0; + } + if (returnType == Long.TYPE) { + return 0L; + } + return null; + } + + private static class Recording { + private String query; + } + + private static class NoopLogger implements Logger { + @Override + public void info(String string) { + } + } +} diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationCasePreservationTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationCasePreservationTest.java new file mode 100644 index 0000000..4391d3a --- /dev/null +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationCasePreservationTest.java @@ -0,0 +1,54 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; +import fr.maxlego08.sarah.logger.Logger; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class PostgreSqlMigrationCasePreservationTest { + + @Test + public void testMissingColumnsUsesExactCaseForQuotedIdentifiers() throws Exception { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + ResultSet resultSet = mock(ResultSet.class); + Logger logger = message -> { + }; + + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(databaseConnection.getDatabaseConfiguration()).thenReturn(configuration); + when(sqlConnection.prepareStatement(anyString())).thenReturn(preparedStatement); + when(preparedStatement.executeQuery()).thenReturn(resultSet); + when(resultSet.next()).thenReturn(true); + when(resultSet.getLong(1)).thenReturn(1L); + + SqlDialect dialect = SqlDialects.from(DatabaseType.POSTGRESQL); + List missingColumns = dialect.missingColumns( + databaseConnection, + logger, + "UsersTable", + Arrays.asList(new ColumnDefinition("groupId", "VARCHAR")) + ); + + verify(preparedStatement).setString(1, "UsersTable"); + verify(preparedStatement).setString(2, "public"); + verify(preparedStatement).setString(3, "groupId"); + assertTrue(missingColumns.isEmpty()); + } +} diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationDialectTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationDialectTest.java new file mode 100644 index 0000000..e962d76 --- /dev/null +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationDialectTest.java @@ -0,0 +1,20 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class PostgreSqlMigrationDialectTest { + + @Test + public void testPostgreSqlMissingColumnsRequiresConnection() { + SqlDialect dialect = SqlDialects.from(DatabaseType.POSTGRESQL); + assertThrows(NullPointerException.class, () -> dialect.missingColumns(null, null, "users", Arrays.asList(new ColumnDefinition("name", "VARCHAR")))); + } +} diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlRequestDialectRegressionTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlRequestDialectRegressionTest.java new file mode 100644 index 0000000..099dade --- /dev/null +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlRequestDialectRegressionTest.java @@ -0,0 +1,157 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.logger.Logger; +import fr.maxlego08.sarah.requests.DeleteRequest; +import fr.maxlego08.sarah.requests.InsertAllRequest; +import fr.maxlego08.sarah.requests.InsertBatchRequest; +import fr.maxlego08.sarah.requests.UpdateBatchRequest; +import fr.maxlego08.sarah.requests.UpdateRequest; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class PostgreSqlRequestDialectRegressionTest { + + private final Logger logger = message -> { + }; + + @Test + public void testDeleteRequestUsesPostgreSqlWhereQuoting() throws Exception { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + AtomicReference capturedSql = new AtomicReference(); + + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(sqlConnection.prepareStatement(anyString())).thenAnswer(invocation -> { + capturedSql.set(invocation.getArgument(0)); + return preparedStatement; + }); + when(preparedStatement.executeUpdate()).thenReturn(1); + + Schema schema = SchemaBuilder.delete("users"); + schema.where("email", "sarah@example.com"); + + new DeleteRequest(schema).execute(databaseConnection, configuration, logger); + + assertEquals("DELETE FROM \"users\" WHERE \"email\" = ?", capturedSql.get()); + verify(preparedStatement).setObject(1, "sarah@example.com"); + } + + @Test + public void testUpdateRequestUsesPostgreSqlWhereQuoting() throws Exception { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + AtomicReference capturedSql = new AtomicReference(); + + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(sqlConnection.prepareStatement(anyString())).thenAnswer(invocation -> { + capturedSql.set(invocation.getArgument(0)); + return preparedStatement; + }); + when(preparedStatement.executeUpdate()).thenReturn(1); + + Schema schema = SchemaBuilder.update("users", builder -> { + builder.string("name", "Sarah"); + builder.where("email", "sarah@example.com"); + }); + + new UpdateRequest(schema).execute(databaseConnection, configuration, logger); + + assertEquals("UPDATE \"users\" SET \"name\" = ? WHERE \"email\" = ?", capturedSql.get()); + verify(preparedStatement).setObject(1, "Sarah"); + verify(preparedStatement).setObject(2, "sarah@example.com"); + } + + @Test + public void testInsertBatchRequestUsesPostgreSqlIdentifierQuoting() throws Exception { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + ResultSet generatedKeys = mock(ResultSet.class); + AtomicReference capturedSql = new AtomicReference(); + + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(sqlConnection.prepareStatement(anyString(), anyInt())).thenAnswer(invocation -> { + capturedSql.set(invocation.getArgument(0)); + return preparedStatement; + }); + when(preparedStatement.executeUpdate()).thenReturn(2); + when(preparedStatement.getGeneratedKeys()).thenReturn(generatedKeys); + when(generatedKeys.next()).thenReturn(false); + + Schema first = SchemaBuilder.insert("users", builder -> builder.string("name", "Sarah")); + Schema second = SchemaBuilder.insert("users", builder -> builder.string("name", "Max")); + + new InsertBatchRequest(Arrays.asList(first, second)).execute(databaseConnection, configuration, logger); + + assertEquals("INSERT INTO \"users\" (\"name\") VALUES (?), (?)", capturedSql.get()); + } + + @Test + public void testInsertAllRequestUsesPostgreSqlIdentifierQuoting() throws Exception { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + AtomicReference capturedSql = new AtomicReference(); + + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(sqlConnection.prepareStatement(anyString())).thenAnswer(invocation -> { + capturedSql.set(invocation.getArgument(0)); + return preparedStatement; + }); + when(preparedStatement.executeUpdate()).thenReturn(1); + + Schema source = SchemaBuilder.create(null, "users", builder -> builder.string("name", 32)); + new InsertAllRequest(source, "users_tmp").execute(databaseConnection, configuration, logger); + + assertEquals("INSERT INTO \"users_tmp\" (\"name\") SELECT \"name\" FROM \"users\"", capturedSql.get()); + } + + @Test + public void testUpdateBatchRequestUsesPostgreSqlWhereAndColumnQuoting() throws Exception { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + AtomicReference capturedSql = new AtomicReference(); + + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(sqlConnection.prepareStatement(anyString())).thenAnswer(invocation -> { + capturedSql.set(invocation.getArgument(0)); + return preparedStatement; + }); + when(sqlConnection.getAutoCommit()).thenReturn(true); + when(preparedStatement.executeBatch()).thenReturn(new int[]{1, 1}); + + Schema first = SchemaBuilder.update("users", builder -> { + builder.string("name", "Sarah"); + builder.where("email", "sarah@example.com"); + }); + Schema second = SchemaBuilder.update("users", builder -> { + builder.string("name", "Max"); + builder.where("email", "max@example.com"); + }); + + new UpdateBatchRequest(Arrays.asList(first, second)).execute(databaseConnection, configuration, logger); + + assertEquals("UPDATE \"users\" SET \"name\" = ? WHERE \"email\" = ?", capturedSql.get()); + } +} diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlUpsertRequestTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlUpsertRequestTest.java new file mode 100644 index 0000000..fb714bc --- /dev/null +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlUpsertRequestTest.java @@ -0,0 +1,48 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.database.Schema; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class PostgreSqlUpsertRequestTest { + + @Test + public void testPostgreSqlConflictTargetSkipsAutoIncrementPrimaryKey() { + SqlDialect dialect = SqlDialects.from(DatabaseType.POSTGRESQL); + Schema schema = SchemaBuilder.upsert("users", table -> { + table.autoIncrementBigInt("id"); + table.string("username", "sarah").unique(); + table.string("email", "sarah@example.com"); + }); + + String conflictClause = dialect.upsertConflictClause(schema); + assertEquals(" ON CONFLICT (\"username\") DO UPDATE SET ", conflictClause); + } + + @Test + public void testPostgreSqlUpsertExpressionUsesExcluded() { + SqlDialect dialect = SqlDialects.from(DatabaseType.POSTGRESQL); + + assertEquals("\"email\" = excluded.\"email\"", + dialect.upsertUpdateExpression("\"email\"", false)); + assertEquals("\"email\" = excluded.\"email\"", + dialect.upsertUpdateExpression("\"email\"", true)); + assertFalse(dialect.usesUpsertUpdateParameters(false)); + assertFalse(dialect.usesUpsertUpdateParameters(true)); + } + + @Test + public void testMysqlSingleAndBatchUpsertDiffer() { + SqlDialect dialect = SqlDialects.from(DatabaseType.MYSQL); + + assertEquals("`email` = ?", + dialect.upsertUpdateExpression("`email`", false)); + assertEquals("`email` = VALUES(`email`)", + dialect.upsertUpdateExpression("`email`", true)); + } +} From 2381c5551ad701bfd9955a5ea1c47da2f1b81d36 Mon Sep 17 00:00:00 2001 From: hamdan <94024788+hmdnnrmn@users.noreply.github.com> Date: Tue, 26 May 2026 10:26:08 +0000 Subject: [PATCH 2/6] fix: map MySQL-specific types for PostgreSQL dialect - Override columnType() in PostgreSqlDialect to remap LONGTEXT, MEDIUMTEXT, TINYTEXT -> TEXT and BLOB, LONGBLOB, MEDIUMBLOB, TINYBLOB -> BYTEA for PostgreSQL compatibility - Gate MySQL/MariaDB-specific HikariCP data source properties (useSSL, useUnicode, characterEncoding) to MySQL/MariaDB only - Add PostgreSQL-specific socketTimeout (in seconds, not ms) - Add PostgreSqlTypeMappingTest with 9 test cases --- .../sarah/HikariDatabaseConnection.java | 15 +++- .../sarah/dialect/PostgreSqlDialect.java | 28 +++++++ .../sarah/PostgreSqlTypeMappingTest.java | 80 +++++++++++++++++++ 3 files changed, 119 insertions(+), 4 deletions(-) create mode 100644 src/test/java/fr/maxlego08/sarah/PostgreSqlTypeMappingTest.java diff --git a/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java b/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java index 12c5b3d..4e535ea 100644 --- a/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java +++ b/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java @@ -70,10 +70,13 @@ private void initializeDataSource() { config.setLeakDetectionThreshold(LEAK_DETECTION_THRESHOLD); Map commonProps = new HashMap<>(); - commonProps.put("useSSL", "false"); - commonProps.put("useUnicode", "true"); - commonProps.put("characterEncoding", "utf8"); - commonProps.put("socketTimeout", String.valueOf(TimeUnit.SECONDS.toMillis(30))); + + if (databaseType == DatabaseType.MYSQL || databaseType == DatabaseType.MARIADB) { + commonProps.put("useSSL", "false"); + commonProps.put("useUnicode", "true"); + commonProps.put("characterEncoding", "utf8"); + commonProps.put("socketTimeout", String.valueOf(TimeUnit.SECONDS.toMillis(30))); + } if (databaseType == DatabaseType.MYSQL) { commonProps.put("cachePrepStmts", "true"); @@ -90,6 +93,10 @@ private void initializeDataSource() { commonProps.put("cacheCallableStmts", "true"); } + if (databaseType == DatabaseType.POSTGRESQL) { + commonProps.put("socketTimeout", "30"); + } + for (Map.Entry e : commonProps.entrySet()) { config.addDataSourceProperty(e.getKey(), e.getValue()); } diff --git a/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java index ab13ff6..b718baf 100644 --- a/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java +++ b/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java @@ -38,6 +38,34 @@ public String enumColumnType(ColumnDefinition column) { return "TEXT"; } + @Override + public String columnType(ColumnDefinition column) { + String baseType = column.getType(); + if (baseType != null) { + String mapped = mapType(baseType); + if (!mapped.equals(baseType)) { + column.setType(mapped); + } + } + return super.columnType(column); + } + + private String mapType(String type) { + switch (type.toUpperCase(java.util.Locale.ROOT)) { + case "LONGTEXT": + case "MEDIUMTEXT": + case "TINYTEXT": + return "TEXT"; + case "BLOB": + case "LONGBLOB": + case "MEDIUMBLOB": + case "TINYBLOB": + return "BYTEA"; + default: + return type; + } + } + @Override public List missingColumns(DatabaseConnection connection, Logger logger, String tableName, List expectedColumns) { Objects.requireNonNull(connection, "connection"); diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlTypeMappingTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlTypeMappingTest.java new file mode 100644 index 0000000..1b9e400 --- /dev/null +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlTypeMappingTest.java @@ -0,0 +1,80 @@ +package fr.maxlego08.sarah; + +import fr.maxlego08.sarah.conditions.ColumnDefinition; +import fr.maxlego08.sarah.database.DatabaseType; +import fr.maxlego08.sarah.dialect.SqlDialect; +import fr.maxlego08.sarah.dialect.SqlDialects; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class PostgreSqlTypeMappingTest { + + private final SqlDialect postgres = SqlDialects.from(DatabaseType.POSTGRESQL); + private final DatabaseConfiguration config = new DatabaseConfiguration("", "u", "p", 5432, "localhost", "db", false, DatabaseType.POSTGRESQL); + + @Test + public void testLongTextMappedToText() { + ColumnDefinition column = new ColumnDefinition("data", "LONGTEXT"); + String result = column.build(config, postgres); + assertEquals("\"data\" TEXT NOT NULL", result); + } + + @Test + public void testMediumTextMappedToText() { + ColumnDefinition column = new ColumnDefinition("content", "MEDIUMTEXT"); + String result = column.build(config, postgres); + assertEquals("\"content\" TEXT NOT NULL", result); + } + + @Test + public void testTinyTextMappedToText() { + ColumnDefinition column = new ColumnDefinition("note", "TINYTEXT"); + String result = column.build(config, postgres); + assertEquals("\"note\" TEXT NOT NULL", result); + } + + @Test + public void testBlobMappedToBytea() { + ColumnDefinition column = new ColumnDefinition("payload", "BLOB"); + String result = column.build(config, postgres); + assertEquals("\"payload\" BYTEA NOT NULL", result); + } + + @Test + public void testLongBlobMappedToBytea() { + ColumnDefinition column = new ColumnDefinition("payload", "LONGBLOB"); + String result = column.build(config, postgres); + assertEquals("\"payload\" BYTEA NOT NULL", result); + } + + @Test + public void testTextStaysText() { + ColumnDefinition column = new ColumnDefinition("body", "TEXT"); + String result = column.build(config, postgres); + assertEquals("\"body\" TEXT NOT NULL", result); + } + + @Test + public void testVarcharUnchanged() { + ColumnDefinition column = new ColumnDefinition("name", "VARCHAR").setLength(255); + String result = column.build(config, postgres); + assertEquals("\"name\" VARCHAR(255) NOT NULL", result); + } + + @Test + public void testIntegerUnchanged() { + ColumnDefinition column = new ColumnDefinition("count", "INTEGER"); + String result = column.build(config, postgres); + assertEquals("\"count\" INTEGER NOT NULL", result); + } + + @Test + public void testMysqlLongTextIsUnchangedOnMysql() { + SqlDialect mysql = SqlDialects.from(DatabaseType.MYSQL); + DatabaseConfiguration mysqlConfig = new DatabaseConfiguration("", "u", "p", 3306, "localhost", "db", false, DatabaseType.MYSQL); + ColumnDefinition column = new ColumnDefinition("data", "LONGTEXT"); + String result = column.build(mysqlConfig, mysql); + assertEquals("`data` LONGTEXT NOT NULL", result); + } +} From 48de7fff0acdac110d17b23487b161b744b975f9 Mon Sep 17 00:00:00 2001 From: hamdan Date: Sat, 30 May 2026 11:42:45 +0800 Subject: [PATCH 3/6] Fix SELECT quoting and SQLite/Hikari guard --- .../sarah/HikariDatabaseConnection.java | 4 ++ .../fr/maxlego08/sarah/SchemaBuilder.java | 41 +++++++++++++++++-- .../sarah/conditions/WhereCondition.java | 4 +- .../fr/maxlego08/sarah/SelectRequestTest.java | 23 ++++++++++- 4 files changed, 66 insertions(+), 6 deletions(-) diff --git a/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java b/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java index 4e535ea..93236e5 100644 --- a/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java +++ b/src/main/java/fr/maxlego08/sarah/HikariDatabaseConnection.java @@ -42,6 +42,10 @@ private void initializeDataSource() { DatabaseType databaseType = databaseConfiguration.getDatabaseType(); SqlDialect dialect = SqlDialects.from(databaseType); + if (databaseType == DatabaseType.SQLITE) { + throw new UnsupportedOperationException("HikariDatabaseConnection does not support SQLITE. Use SqliteConnection for file-based SQLite databases."); + } + // URL + Driver config.setJdbcUrl(dialect.jdbcUrl(databaseConfiguration)); config.setDriverClassName(dialect.driverClassName()); diff --git a/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java b/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java index 37d654d..6108142 100644 --- a/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java +++ b/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java @@ -504,7 +504,7 @@ public void whereConditions(StringBuilder sql, SqlDialect dialect) { @Override public long executeSelectCount(DatabaseConnection databaseConnection, Logger logger) throws SQLException { SqlDialect dialect = SqlDialects.from(databaseConnection.getDatabaseConfiguration().getDatabaseType()); - StringBuilder selectQuery = new StringBuilder("SELECT COUNT(*) FROM " + tableName); + StringBuilder selectQuery = new StringBuilder("SELECT COUNT(*) FROM " + quoteTableReference(dialect, tableName)); this.whereConditions(selectQuery, dialect); String finalQuery = databaseConnection.getDatabaseConfiguration().replacePrefix(selectQuery.toString()); @@ -543,9 +543,9 @@ public List> executeSelect(DatabaseConnection databaseConnec StringBuilder selectQuery; if (this.isDistinct) { - selectQuery = new StringBuilder("SELECT DISTINCT " + selectedValues + " FROM " + this.tableName); + selectQuery = new StringBuilder("SELECT DISTINCT " + selectedValues + " FROM " + quoteTableReference(dialect, this.tableName)); } else { - selectQuery = new StringBuilder("SELECT " + selectedValues + " FROM " + this.tableName); + selectQuery = new StringBuilder("SELECT " + selectedValues + " FROM " + quoteTableReference(dialect, this.tableName)); } if (!this.joinConditions.isEmpty()) { @@ -589,6 +589,41 @@ public List> executeSelect(DatabaseConnection databaseConnec return results; } + private String quoteTableReference(SqlDialect dialect, String tableReference) { + if (tableReference == null) { + throw new IllegalArgumentException("Table reference cannot be null"); + } + + String trimmed = tableReference.trim(); + if (trimmed.isEmpty()) { + throw new IllegalArgumentException("Table reference cannot be empty"); + } + + int firstWhitespace = -1; + for (int i = 0; i < trimmed.length(); i++) { + if (Character.isWhitespace(trimmed.charAt(i))) { + firstWhitespace = i; + break; + } + } + + String base = firstWhitespace == -1 ? trimmed : trimmed.substring(0, firstWhitespace); + String remainder = firstWhitespace == -1 ? "" : trimmed.substring(firstWhitespace).trim(); + + String quotedBase; + if (base.indexOf('.') != -1) { + String[] parts = base.split("\\."); + quotedBase = Arrays.stream(parts) + .filter(part -> part != null && !part.isEmpty()) + .map(part -> dialect.quoteIdentifier(part)) + .collect(Collectors.joining(".")); + } else { + quotedBase = dialect.quoteIdentifier(base); + } + + return remainder.isEmpty() ? quotedBase : quotedBase + " " + remainder; + } + @Override public void applyWhereConditions(PreparedStatement preparedStatement, int index) throws SQLException { for (WhereCondition condition : this.whereConditions) { diff --git a/src/main/java/fr/maxlego08/sarah/conditions/WhereCondition.java b/src/main/java/fr/maxlego08/sarah/conditions/WhereCondition.java index 1d76458..40ec4d1 100644 --- a/src/main/java/fr/maxlego08/sarah/conditions/WhereCondition.java +++ b/src/main/java/fr/maxlego08/sarah/conditions/WhereCondition.java @@ -44,8 +44,8 @@ public WhereCondition(String column, WhereAction whereAction) { } public String getCondition() { - if (this.whereAction == WhereAction.IS_NOT_NULL) return this.column + " IS NOT NULL"; - if (this.whereAction == WhereAction.IS_NULL) return this.column + " IS NULL"; + if (this.whereAction == WhereAction.IS_NOT_NULL) return this.legacyQualifiedColumn() + " IS NOT NULL"; + if (this.whereAction == WhereAction.IS_NULL) return this.legacyQualifiedColumn() + " IS NULL"; if (this.whereAction == WhereAction.IN) { return this.legacyQualifiedColumn() + " IN (" + values.stream().map(id -> "?").collect(Collectors.joining(",")) + ")"; } diff --git a/src/test/java/fr/maxlego08/sarah/SelectRequestTest.java b/src/test/java/fr/maxlego08/sarah/SelectRequestTest.java index e7470a6..bb0f9b6 100644 --- a/src/test/java/fr/maxlego08/sarah/SelectRequestTest.java +++ b/src/test/java/fr/maxlego08/sarah/SelectRequestTest.java @@ -72,6 +72,20 @@ public void testSelectAll() { assertEquals(3, results.size()); } + @Test + public void testSelectAllWithAliasInTableName() { + List> results = requestHelper.select("test_users u", schema -> {}); + + assertEquals(3, results.size()); + } + + @Test + public void testSelectAllWithSchemaQualifiedTableName() { + List> results = requestHelper.select("main.test_users", schema -> {}); + + assertEquals(3, results.size()); + } + @Test public void testSelectWithWhere() { List> results = requestHelper.select("test_users", schema -> { @@ -144,6 +158,13 @@ public void testSelectCount() { assertEquals(3, count); } + @Test + public void testSelectCountWithAliasInTableName() { + long count = requestHelper.count("test_users u", schema -> {}); + + assertEquals(3, count); + } + @Test public void testSelectCountWithWhere() { long count = requestHelper.count("test_users", schema -> { @@ -203,4 +224,4 @@ public void testSelectNoMatches() { assertEquals(0, results.size()); } -} \ No newline at end of file +} From 938095581cc1946b2af555c4964406a42dac2bbc Mon Sep 17 00:00:00 2001 From: hamdan Date: Sat, 30 May 2026 12:07:19 +0800 Subject: [PATCH 4/6] Fix PostgreSQL schema-aware migrations and quoting --- .../fr/maxlego08/sarah/SchemaBuilder.java | 41 +------- .../conditions/ForeignKeyDefinition.java | 2 +- .../sarah/dialect/AbstractSqlDialect.java | 46 +++++++++ .../sarah/dialect/PostgreSqlDialect.java | 93 +++++++++++++++---- .../maxlego08/sarah/dialect/SqlDialect.java | 15 +++ .../sarah/requests/AlterRequest.java | 2 +- .../sarah/requests/CreateIndexRequest.java | 2 +- .../sarah/requests/CreateRequest.java | 2 +- .../sarah/requests/DeleteRequest.java | 2 +- .../sarah/requests/DropTableRequest.java | 2 +- .../sarah/requests/InsertAllRequest.java | 4 +- .../sarah/requests/InsertBatchRequest.java | 2 +- .../sarah/requests/InsertRequest.java | 2 +- .../sarah/requests/RenameExecutor.java | 4 +- .../sarah/requests/UpdateBatchRequest.java | 2 +- .../sarah/requests/UpdateRequest.java | 2 +- .../sarah/requests/UpsertBatchRequest.java | 2 +- .../sarah/requests/UpsertRequest.java | 2 +- .../sarah/PostgreSqlMigrationDialectTest.java | 55 +++++++++++ ...ostgreSqlRequestDialectRegressionTest.java | 48 ++++++++++ .../sarah/PostgreSqlTypeMappingTest.java | 7 ++ 21 files changed, 267 insertions(+), 70 deletions(-) diff --git a/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java b/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java index 6108142..3beece0 100644 --- a/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java +++ b/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java @@ -504,7 +504,7 @@ public void whereConditions(StringBuilder sql, SqlDialect dialect) { @Override public long executeSelectCount(DatabaseConnection databaseConnection, Logger logger) throws SQLException { SqlDialect dialect = SqlDialects.from(databaseConnection.getDatabaseConfiguration().getDatabaseType()); - StringBuilder selectQuery = new StringBuilder("SELECT COUNT(*) FROM " + quoteTableReference(dialect, tableName)); + StringBuilder selectQuery = new StringBuilder("SELECT COUNT(*) FROM " + dialect.quoteTableReference(tableName)); this.whereConditions(selectQuery, dialect); String finalQuery = databaseConnection.getDatabaseConfiguration().replacePrefix(selectQuery.toString()); @@ -543,9 +543,9 @@ public List> executeSelect(DatabaseConnection databaseConnec StringBuilder selectQuery; if (this.isDistinct) { - selectQuery = new StringBuilder("SELECT DISTINCT " + selectedValues + " FROM " + quoteTableReference(dialect, this.tableName)); + selectQuery = new StringBuilder("SELECT DISTINCT " + selectedValues + " FROM " + dialect.quoteTableReference(this.tableName)); } else { - selectQuery = new StringBuilder("SELECT " + selectedValues + " FROM " + quoteTableReference(dialect, this.tableName)); + selectQuery = new StringBuilder("SELECT " + selectedValues + " FROM " + dialect.quoteTableReference(this.tableName)); } if (!this.joinConditions.isEmpty()) { @@ -589,41 +589,6 @@ public List> executeSelect(DatabaseConnection databaseConnec return results; } - private String quoteTableReference(SqlDialect dialect, String tableReference) { - if (tableReference == null) { - throw new IllegalArgumentException("Table reference cannot be null"); - } - - String trimmed = tableReference.trim(); - if (trimmed.isEmpty()) { - throw new IllegalArgumentException("Table reference cannot be empty"); - } - - int firstWhitespace = -1; - for (int i = 0; i < trimmed.length(); i++) { - if (Character.isWhitespace(trimmed.charAt(i))) { - firstWhitespace = i; - break; - } - } - - String base = firstWhitespace == -1 ? trimmed : trimmed.substring(0, firstWhitespace); - String remainder = firstWhitespace == -1 ? "" : trimmed.substring(firstWhitespace).trim(); - - String quotedBase; - if (base.indexOf('.') != -1) { - String[] parts = base.split("\\."); - quotedBase = Arrays.stream(parts) - .filter(part -> part != null && !part.isEmpty()) - .map(part -> dialect.quoteIdentifier(part)) - .collect(Collectors.joining(".")); - } else { - quotedBase = dialect.quoteIdentifier(base); - } - - return remainder.isEmpty() ? quotedBase : quotedBase + " " + remainder; - } - @Override public void applyWhereConditions(PreparedStatement preparedStatement, int index) throws SQLException { for (WhereCondition condition : this.whereConditions) { diff --git a/src/main/java/fr/maxlego08/sarah/conditions/ForeignKeyDefinition.java b/src/main/java/fr/maxlego08/sarah/conditions/ForeignKeyDefinition.java index 5d98cf9..da3b7cf 100644 --- a/src/main/java/fr/maxlego08/sarah/conditions/ForeignKeyDefinition.java +++ b/src/main/java/fr/maxlego08/sarah/conditions/ForeignKeyDefinition.java @@ -34,7 +34,7 @@ public boolean isCascade() { public String render(SqlDialect dialect) { return "FOREIGN KEY (" + dialect.quoteIdentifier(sourceColumn) + ") REFERENCES " + - dialect.quoteIdentifier(referenceTable) + "(" + dialect.quoteIdentifier(referenceColumn) + ")" + + dialect.quoteTableReference(referenceTable) + "(" + dialect.quoteIdentifier(referenceColumn) + ")" + (cascade ? " ON DELETE CASCADE" : ""); } } diff --git a/src/main/java/fr/maxlego08/sarah/dialect/AbstractSqlDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/AbstractSqlDialect.java index 320dfa1..ed3fad3 100644 --- a/src/main/java/fr/maxlego08/sarah/dialect/AbstractSqlDialect.java +++ b/src/main/java/fr/maxlego08/sarah/dialect/AbstractSqlDialect.java @@ -40,6 +40,46 @@ public String quoteIdentifier(String name) { return quote + name + quote; } + @Override + public String quoteTableReference(String tableReference) { + if (tableReference == null) { + throw new IllegalArgumentException("Table reference cannot be null"); + } + + String trimmed = tableReference.trim(); + if (trimmed.isEmpty()) { + throw new IllegalArgumentException("Table reference cannot be empty"); + } + + int firstWhitespace = -1; + for (int i = 0; i < trimmed.length(); i++) { + if (Character.isWhitespace(trimmed.charAt(i))) { + firstWhitespace = i; + break; + } + } + + String base = firstWhitespace == -1 ? trimmed : trimmed.substring(0, firstWhitespace); + String remainder = firstWhitespace == -1 ? "" : trimmed.substring(firstWhitespace).trim(); + + String quotedBase; + int dotIndex = base.indexOf('.'); + if (dotIndex != -1) { + String[] parts = base.split("\\."); + List quotedParts = new ArrayList(parts.length); + for (String part : parts) { + if (part != null && !part.isEmpty()) { + quotedParts.add(quoteIdentifier(part)); + } + } + quotedBase = String.join(".", quotedParts); + } else { + quotedBase = quoteIdentifier(base); + } + + return remainder.isEmpty() ? quotedBase : quotedBase + " " + remainder; + } + @Override public String qualifyIdentifier(String prefix, String column) { if (prefix == null || prefix.trim().isEmpty()) { @@ -76,6 +116,8 @@ public String columnType(ColumnDefinition column) { baseType = "TEXT"; } + baseType = mapColumnType(column, baseType); + if (useIntegerTypeForAutoIncrementPrimaryKey(column) && isIntegerType(baseType)) { baseType = "INTEGER"; } @@ -92,6 +134,10 @@ public String columnType(ColumnDefinition column) { return baseType; } + protected String mapColumnType(ColumnDefinition column, String baseType) { + return baseType; + } + @Override public String autoIncrementKeyword(ColumnDefinition column) { return ""; diff --git a/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java index b718baf..a8272a5 100644 --- a/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java +++ b/src/main/java/fr/maxlego08/sarah/dialect/PostgreSqlDialect.java @@ -7,9 +7,12 @@ import fr.maxlego08.sarah.logger.Logger; import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Objects; public class PostgreSqlDialect extends AbstractSqlDialect { @@ -39,19 +42,8 @@ public String enumColumnType(ColumnDefinition column) { } @Override - public String columnType(ColumnDefinition column) { - String baseType = column.getType(); - if (baseType != null) { - String mapped = mapType(baseType); - if (!mapped.equals(baseType)) { - column.setType(mapped); - } - } - return super.columnType(column); - } - - private String mapType(String type) { - switch (type.toUpperCase(java.util.Locale.ROOT)) { + protected String mapColumnType(ColumnDefinition column, String baseType) { + switch (baseType.toUpperCase(Locale.ROOT)) { case "LONGTEXT": case "MEDIUMTEXT": case "TINYTEXT": @@ -62,7 +54,7 @@ private String mapType(String type) { case "TINYBLOB": return "BYTEA"; default: - return type; + return baseType; } } @@ -75,13 +67,14 @@ public List missingColumns(DatabaseConnection connection, Logg List missing = new ArrayList(); try (Connection sqlConnection = connection.getConnection()) { + TableRef ref = resolveTableRef(sqlConnection, tableName); for (ColumnDefinition column : expectedColumns) { String columnName = column.getName(); long count = countExistingColumn( sqlConnection, "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = ? AND table_schema = ? AND column_name = ?", - tableName, - "public", + ref.table, + ref.schema, columnName ); if (count == 0) { @@ -97,4 +90,72 @@ public List missingColumns(DatabaseConnection connection, Logg return missing; } + + private TableRef resolveTableRef(Connection sqlConnection, String tableReference) throws SQLException { + String trimmed = tableReference == null ? "" : tableReference.trim(); + if (trimmed.isEmpty()) { + throw new IllegalArgumentException("tableName cannot be null or empty"); + } + + String base = trimmed; + int firstWhitespace = -1; + for (int i = 0; i < trimmed.length(); i++) { + if (Character.isWhitespace(trimmed.charAt(i))) { + firstWhitespace = i; + break; + } + } + if (firstWhitespace != -1) { + base = trimmed.substring(0, firstWhitespace); + } + + String schema = null; + String table = base; + int dotIndex = base.lastIndexOf('.'); + if (dotIndex != -1) { + schema = base.substring(0, dotIndex); + table = base.substring(dotIndex + 1); + } + + if (schema == null || schema.isEmpty()) { + schema = safeCurrentSchema(sqlConnection); + } + if (schema == null || schema.isEmpty()) { + schema = "public"; + } + + return new TableRef(schema, table); + } + + private String safeCurrentSchema(Connection sqlConnection) throws SQLException { + try { + String schema = sqlConnection.getSchema(); + if (schema != null && !schema.trim().isEmpty()) { + return schema.trim(); + } + } catch (Throwable ignored) { + // Some drivers may not support getSchema consistently. + } + + try (PreparedStatement statement = sqlConnection.prepareStatement("SELECT current_schema()")) { + try (ResultSet resultSet = statement.executeQuery()) { + if (resultSet.next()) { + String schema = resultSet.getString(1); + return schema == null ? null : schema.trim(); + } + } + } + + return null; + } + + private static final class TableRef { + private final String schema; + private final String table; + + private TableRef(String schema, String table) { + this.schema = schema; + this.table = table; + } + } } diff --git a/src/main/java/fr/maxlego08/sarah/dialect/SqlDialect.java b/src/main/java/fr/maxlego08/sarah/dialect/SqlDialect.java index 6ede61e..b3f5782 100644 --- a/src/main/java/fr/maxlego08/sarah/dialect/SqlDialect.java +++ b/src/main/java/fr/maxlego08/sarah/dialect/SqlDialect.java @@ -12,6 +12,21 @@ public interface SqlDialect { String quoteIdentifier(String name); + /** + * Quotes a table reference, supporting schema-qualified names and optional aliases. + *

+ * Examples: + *

    + *
  • {@code users} -> {@code "users"}
  • + *
  • {@code main.users} -> {@code "main"."users"}
  • + *
  • {@code users u} -> {@code "users" u}
  • + *
+ * + * @param tableReference The raw table reference as provided by callers + * @return The quoted table reference + */ + String quoteTableReference(String tableReference); + String qualifyIdentifier(String prefix, String column); String driverClassName(); diff --git a/src/main/java/fr/maxlego08/sarah/requests/AlterRequest.java b/src/main/java/fr/maxlego08/sarah/requests/AlterRequest.java index da4506c..6e4812b 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/AlterRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/AlterRequest.java @@ -30,7 +30,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); StringBuilder alterTableSQL = new StringBuilder("ALTER TABLE "); - alterTableSQL.append(dialect.quoteIdentifier(this.schema.getTableName())).append(" "); + alterTableSQL.append(dialect.quoteTableReference(this.schema.getTableName())).append(" "); List columnSQLs = new ArrayList<>(); for (ColumnDefinition column : this.schema.getColumns()) { diff --git a/src/main/java/fr/maxlego08/sarah/requests/CreateIndexRequest.java b/src/main/java/fr/maxlego08/sarah/requests/CreateIndexRequest.java index 4cfc75f..acea56a 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/CreateIndexRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/CreateIndexRequest.java @@ -33,7 +33,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration indexTableSQL.append(dialect.quoteIdentifier(indexName)); indexTableSQL.append(" ON "); - indexTableSQL.append(dialect.quoteIdentifier(tableName)); + indexTableSQL.append(dialect.quoteTableReference(tableName)); indexTableSQL.append(" ("); indexTableSQL.append(dialect.quoteIdentifier(column.getName())); indexTableSQL.append(" )"); diff --git a/src/main/java/fr/maxlego08/sarah/requests/CreateRequest.java b/src/main/java/fr/maxlego08/sarah/requests/CreateRequest.java index 23c82e9..68d200e 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/CreateRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/CreateRequest.java @@ -30,7 +30,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); StringBuilder createTableSQL = new StringBuilder("CREATE TABLE IF NOT EXISTS "); - createTableSQL.append(dialect.quoteIdentifier(this.schema.getTableName())).append(" ("); + createTableSQL.append(dialect.quoteTableReference(this.schema.getTableName())).append(" ("); List columnSQLs = new ArrayList<>(); boolean hasInlinePrimaryKey = false; diff --git a/src/main/java/fr/maxlego08/sarah/requests/DeleteRequest.java b/src/main/java/fr/maxlego08/sarah/requests/DeleteRequest.java index 95f9bd1..dec6653 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/DeleteRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/DeleteRequest.java @@ -24,7 +24,7 @@ public DeleteRequest(Schema schemaBuilder) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); - StringBuilder sql = new StringBuilder("DELETE FROM ").append(dialect.quoteIdentifier(schemaBuilder.getTableName())); + StringBuilder sql = new StringBuilder("DELETE FROM ").append(dialect.quoteTableReference(schemaBuilder.getTableName())); schemaBuilder.whereConditions(sql, dialect); String finalQuery = databaseConfiguration.replacePrefix(sql.toString()); diff --git a/src/main/java/fr/maxlego08/sarah/requests/DropTableRequest.java b/src/main/java/fr/maxlego08/sarah/requests/DropTableRequest.java index 2a2050b..87baf37 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/DropTableRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/DropTableRequest.java @@ -29,7 +29,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration } SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); - String finalQuery = databaseConfiguration.replacePrefix("DROP TABLE IF EXISTS " + dialect.quoteIdentifier(tableName)); + String finalQuery = databaseConfiguration.replacePrefix("DROP TABLE IF EXISTS " + dialect.quoteTableReference(tableName)); if (databaseConfiguration.isDebug()) { logger.info("Executing SQL: " + finalQuery); } diff --git a/src/main/java/fr/maxlego08/sarah/requests/InsertAllRequest.java b/src/main/java/fr/maxlego08/sarah/requests/InsertAllRequest.java index 31097f4..ee4f848 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/InsertAllRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/InsertAllRequest.java @@ -31,7 +31,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); StringBuilder insertBuilder = new StringBuilder("INSERT INTO ") - .append(dialect.quoteIdentifier(this.toTableName)) + .append(dialect.quoteTableReference(this.toTableName)) .append(" ("); List quotedColumns = new ArrayList(); @@ -45,7 +45,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration String columnsSql = String.join(", ", quotedColumns); insertBuilder.append(columnsSql).append(") "); insertBuilder.append("SELECT ").append(columnsSql); - insertBuilder.append(" FROM ").append(dialect.quoteIdentifier(this.schema.getTableName())); + insertBuilder.append(" FROM ").append(dialect.quoteTableReference(this.schema.getTableName())); String insertQuery = databaseConfiguration.replacePrefix(insertBuilder.toString()); diff --git a/src/main/java/fr/maxlego08/sarah/requests/InsertBatchRequest.java b/src/main/java/fr/maxlego08/sarah/requests/InsertBatchRequest.java index 20c7f31..6e1ec33 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/InsertBatchRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/InsertBatchRequest.java @@ -34,7 +34,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); Schema firstSchema = schemas.get(0); - StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteIdentifier(firstSchema.getTableName()) + " ("); + StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteTableReference(firstSchema.getTableName()) + " ("); StringBuilder valuesQuery = new StringBuilder("VALUES "); List values = new ArrayList<>(); diff --git a/src/main/java/fr/maxlego08/sarah/requests/InsertRequest.java b/src/main/java/fr/maxlego08/sarah/requests/InsertRequest.java index 5a3f0f4..9343947 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/InsertRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/InsertRequest.java @@ -30,7 +30,7 @@ public InsertRequest(Schema schema) { public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); - StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteIdentifier(this.schema.getTableName()) + " ("); + StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteTableReference(this.schema.getTableName()) + " ("); StringBuilder valuesQuery = new StringBuilder("VALUES ("); List values = new ArrayList<>(); diff --git a/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java b/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java index a8d8e61..0f9a7c0 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java +++ b/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java @@ -26,9 +26,9 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); StringBuilder alterTableSQL = new StringBuilder("ALTER TABLE "); - alterTableSQL.append(dialect.quoteIdentifier(this.schema.getTableName())); + alterTableSQL.append(dialect.quoteTableReference(this.schema.getTableName())); alterTableSQL.append(" RENAME TO "); - alterTableSQL.append(dialect.quoteIdentifier(this.schema.getNewTableName())); + alterTableSQL.append(dialect.quoteTableReference(this.schema.getNewTableName())); String finalQuery = databaseConfiguration.replacePrefix(alterTableSQL.toString()); if (databaseConfiguration.isDebug()) { diff --git a/src/main/java/fr/maxlego08/sarah/requests/UpdateBatchRequest.java b/src/main/java/fr/maxlego08/sarah/requests/UpdateBatchRequest.java index ed039f7..037893f 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/UpdateBatchRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/UpdateBatchRequest.java @@ -30,7 +30,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); Schema firstSchema = schemas.get(0); - StringBuilder updateQuery = new StringBuilder("UPDATE " + dialect.quoteIdentifier(firstSchema.getTableName())); + StringBuilder updateQuery = new StringBuilder("UPDATE " + dialect.quoteTableReference(firstSchema.getTableName())); if (!firstSchema.getJoinConditions().isEmpty()) { for (JoinCondition join : firstSchema.getJoinConditions()) { diff --git a/src/main/java/fr/maxlego08/sarah/requests/UpdateRequest.java b/src/main/java/fr/maxlego08/sarah/requests/UpdateRequest.java index d2fd0f8..58ed50f 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/UpdateRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/UpdateRequest.java @@ -29,7 +29,7 @@ public UpdateRequest(Schema schema) { public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); - StringBuilder updateQuery = new StringBuilder("UPDATE " + dialect.quoteIdentifier(this.schema.getTableName())); + StringBuilder updateQuery = new StringBuilder("UPDATE " + dialect.quoteTableReference(this.schema.getTableName())); if (!this.schema.getJoinConditions().isEmpty()) { for (JoinCondition join : this.schema.getJoinConditions()) { diff --git a/src/main/java/fr/maxlego08/sarah/requests/UpsertBatchRequest.java b/src/main/java/fr/maxlego08/sarah/requests/UpsertBatchRequest.java index c42f264..3906be9 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/UpsertBatchRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/UpsertBatchRequest.java @@ -32,7 +32,7 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); Schema firstSchema = schemas.get(0); - StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteIdentifier(firstSchema.getTableName()) + " ("); + StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteTableReference(firstSchema.getTableName()) + " ("); StringBuilder valuesQuery = new StringBuilder("VALUES "); List values = new ArrayList<>(); diff --git a/src/main/java/fr/maxlego08/sarah/requests/UpsertRequest.java b/src/main/java/fr/maxlego08/sarah/requests/UpsertRequest.java index c57677d..af39966 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/UpsertRequest.java +++ b/src/main/java/fr/maxlego08/sarah/requests/UpsertRequest.java @@ -28,7 +28,7 @@ public UpsertRequest(Schema schema) { @Override public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration databaseConfiguration, Logger logger) { SqlDialect dialect = SqlDialects.from(databaseConfiguration.getDatabaseType()); - StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteIdentifier(this.schema.getTableName()) + " ("); + StringBuilder insertQuery = new StringBuilder("INSERT INTO " + dialect.quoteTableReference(this.schema.getTableName()) + " ("); StringBuilder valuesQuery = new StringBuilder("VALUES ("); StringBuilder onUpdateQuery = new StringBuilder(); diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationDialectTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationDialectTest.java index e962d76..343f34e 100644 --- a/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationDialectTest.java +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlMigrationDialectTest.java @@ -6,9 +6,17 @@ import fr.maxlego08.sarah.dialect.SqlDialects; import org.junit.jupiter.api.Test; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.util.Arrays; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class PostgreSqlMigrationDialectTest { @@ -17,4 +25,51 @@ public void testPostgreSqlMissingColumnsRequiresConnection() { SqlDialect dialect = SqlDialects.from(DatabaseType.POSTGRESQL); assertThrows(NullPointerException.class, () -> dialect.missingColumns(null, null, "users", Arrays.asList(new ColumnDefinition("name", "VARCHAR")))); } + + @Test + public void testPostgreSqlMissingColumnsUsesSchemaFromQualifiedTableName() throws Exception { + SqlDialect dialect = SqlDialects.from(DatabaseType.POSTGRESQL); + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + ResultSet resultSet = mock(ResultSet.class); + + when(databaseConnection.getDatabaseConfiguration()).thenReturn(configuration); + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(sqlConnection.prepareStatement(anyString())).thenReturn(preparedStatement); + when(preparedStatement.executeQuery()).thenReturn(resultSet); + when(resultSet.next()).thenReturn(true); + when(resultSet.getLong(1)).thenReturn(1L); + + dialect.missingColumns(databaseConnection, null, "myschema.users", Arrays.asList(new ColumnDefinition("name", "VARCHAR"))); + + verify(preparedStatement).setString(1, "users"); + verify(preparedStatement).setString(2, "myschema"); + verify(preparedStatement).setString(3, "name"); + } + + @Test + public void testPostgreSqlMissingColumnsUsesCurrentSchemaWhenNotQualified() throws Exception { + SqlDialect dialect = SqlDialects.from(DatabaseType.POSTGRESQL); + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + ResultSet resultSet = mock(ResultSet.class); + + when(databaseConnection.getDatabaseConfiguration()).thenReturn(configuration); + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(sqlConnection.getSchema()).thenReturn("tenant1"); + when(sqlConnection.prepareStatement(anyString())).thenReturn(preparedStatement); + when(preparedStatement.executeQuery()).thenReturn(resultSet); + when(resultSet.next()).thenReturn(true); + when(resultSet.getLong(1)).thenReturn(1L); + + assertTrue(dialect.missingColumns(databaseConnection, null, "users u", Arrays.asList(new ColumnDefinition("name", "VARCHAR"))).isEmpty()); + + verify(preparedStatement).setString(1, "users"); + verify(preparedStatement).setString(2, "tenant1"); + verify(preparedStatement).setString(3, "name"); + } } diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlRequestDialectRegressionTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlRequestDialectRegressionTest.java index 099dade..e40189f 100644 --- a/src/test/java/fr/maxlego08/sarah/PostgreSqlRequestDialectRegressionTest.java +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlRequestDialectRegressionTest.java @@ -5,6 +5,7 @@ import fr.maxlego08.sarah.requests.DeleteRequest; import fr.maxlego08.sarah.requests.InsertAllRequest; import fr.maxlego08.sarah.requests.InsertBatchRequest; +import fr.maxlego08.sarah.requests.InsertRequest; import fr.maxlego08.sarah.requests.UpdateBatchRequest; import fr.maxlego08.sarah.requests.UpdateRequest; import org.junit.jupiter.api.Test; @@ -51,6 +52,29 @@ public void testDeleteRequestUsesPostgreSqlWhereQuoting() throws Exception { verify(preparedStatement).setObject(1, "sarah@example.com"); } + @Test + public void testDeleteRequestQuotesAliasWithoutQuotingWhitespace() throws Exception { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + AtomicReference capturedSql = new AtomicReference(); + + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(sqlConnection.prepareStatement(anyString())).thenAnswer(invocation -> { + capturedSql.set(invocation.getArgument(0)); + return preparedStatement; + }); + when(preparedStatement.executeUpdate()).thenReturn(1); + + Schema schema = SchemaBuilder.delete("users u"); + schema.where("email", "sarah@example.com"); + + new DeleteRequest(schema).execute(databaseConnection, configuration, logger); + + assertEquals("DELETE FROM \"users\" u WHERE \"email\" = ?", capturedSql.get()); + } + @Test public void testUpdateRequestUsesPostgreSqlWhereQuoting() throws Exception { DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); @@ -104,6 +128,30 @@ public void testInsertBatchRequestUsesPostgreSqlIdentifierQuoting() throws Excep assertEquals("INSERT INTO \"users\" (\"name\") VALUES (?), (?)", capturedSql.get()); } + @Test + public void testInsertRequestQuotesSchemaQualifiedTableName() throws Exception { + DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); + DatabaseConnection databaseConnection = mock(DatabaseConnection.class); + Connection sqlConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + ResultSet generatedKeys = mock(ResultSet.class); + AtomicReference capturedSql = new AtomicReference(); + + when(databaseConnection.getConnection()).thenReturn(sqlConnection); + when(sqlConnection.prepareStatement(anyString(), anyInt())).thenAnswer(invocation -> { + capturedSql.set(invocation.getArgument(0)); + return preparedStatement; + }); + when(preparedStatement.executeUpdate()).thenReturn(1); + when(preparedStatement.getGeneratedKeys()).thenReturn(generatedKeys); + when(generatedKeys.next()).thenReturn(false); + + Schema schema = SchemaBuilder.insert("main.users", builder -> builder.string("name", "Sarah")); + new InsertRequest(schema).execute(databaseConnection, configuration, logger); + + assertEquals("INSERT INTO \"main\".\"users\" (\"name\") VALUES (?)", capturedSql.get()); + } + @Test public void testInsertAllRequestUsesPostgreSqlIdentifierQuoting() throws Exception { DatabaseConfiguration configuration = DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db"); diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlTypeMappingTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlTypeMappingTest.java index 1b9e400..da2fd14 100644 --- a/src/test/java/fr/maxlego08/sarah/PostgreSqlTypeMappingTest.java +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlTypeMappingTest.java @@ -20,6 +20,13 @@ public void testLongTextMappedToText() { assertEquals("\"data\" TEXT NOT NULL", result); } + @Test + public void testMappingDoesNotMutateOriginalType() { + ColumnDefinition column = new ColumnDefinition("data", "LONGTEXT"); + column.build(config, postgres); + assertEquals("LONGTEXT", column.getType()); + } + @Test public void testMediumTextMappedToText() { ColumnDefinition column = new ColumnDefinition("content", "MEDIUMTEXT"); From 84e4f64c9ed7ab5bd9070f59382f5cf4ed1b5c8f Mon Sep 17 00:00:00 2001 From: hamdan Date: Sat, 30 May 2026 12:28:22 +0800 Subject: [PATCH 5/6] Fix legacy whereConditions dialect --- .../fr/maxlego08/sarah/SchemaBuilder.java | 4 +++- .../PostgreSqlConditionRenderingTest.java | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java b/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java index 3beece0..7e70d19 100644 --- a/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java +++ b/src/main/java/fr/maxlego08/sarah/SchemaBuilder.java @@ -487,7 +487,9 @@ public String getTableName() { @Override public void whereConditions(StringBuilder sql) { - whereConditions(sql, SqlDialects.from(DatabaseType.MYSQL)); + DatabaseConfiguration configuration = MigrationManager.getDatabaseConfiguration(); + SqlDialect dialect = configuration == null ? SqlDialects.from(DatabaseType.MYSQL) : SqlDialects.from(configuration.getDatabaseType()); + whereConditions(sql, dialect); } @Override diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java index a0f41c9..149c6d5 100644 --- a/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java @@ -4,6 +4,7 @@ import fr.maxlego08.sarah.conditions.OrderByCondition; import fr.maxlego08.sarah.conditions.SelectCondition; import fr.maxlego08.sarah.conditions.WhereCondition; +import fr.maxlego08.sarah.database.Schema; import fr.maxlego08.sarah.database.DatabaseType; import fr.maxlego08.sarah.dialect.SqlDialect; import fr.maxlego08.sarah.dialect.SqlDialects; @@ -29,6 +30,24 @@ public void testWhereInUsesDialectQuoting() { assertEquals("u.\"id\" IN (?,?,?)", condition.getCondition(postgres)); } + @Test + public void testSchemaBuilderLegacyWhereConditionsUsesMigrationDialect() { + DatabaseConfiguration previous = MigrationManager.getDatabaseConfiguration(); + try { + MigrationManager.setDatabaseConfiguration(DatabaseConfiguration.createPostgreSql("u", "p", 5432, "localhost", "db")); + + Schema schema = SchemaBuilder.delete("users"); + schema.where("name", "Sarah"); + + StringBuilder sql = new StringBuilder(); + schema.whereConditions(sql); + + assertEquals(" WHERE \"name\" = ?", sql.toString()); + } finally { + MigrationManager.setDatabaseConfiguration(previous); + } + } + @Test public void testSelectConditionUsesDialectQuoting() { SelectCondition select = new SelectCondition("u", "name", "username", false, null); From 410131954313117ebf0307fd87faa3cdd8781565 Mon Sep 17 00:00:00 2001 From: hamdan Date: Sat, 30 May 2026 12:50:23 +0800 Subject: [PATCH 6/6] Fix NPE, join quoting, and rename schema prefix in PostgreSQL dialect --- src/main/java/fr/maxlego08/sarah/MigrationManager.java | 2 +- .../java/fr/maxlego08/sarah/conditions/JoinCondition.java | 8 ++++---- .../java/fr/maxlego08/sarah/requests/RenameExecutor.java | 7 ++++++- .../maxlego08/sarah/PostgreSqlConditionRenderingTest.java | 2 +- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/main/java/fr/maxlego08/sarah/MigrationManager.java b/src/main/java/fr/maxlego08/sarah/MigrationManager.java index 78fc57f..61b11f4 100644 --- a/src/main/java/fr/maxlego08/sarah/MigrationManager.java +++ b/src/main/java/fr/maxlego08/sarah/MigrationManager.java @@ -105,7 +105,7 @@ public static void execute(DatabaseConnection databaseConnection, Logger logger) List mustBeAdd = new ArrayList<>(); String tableName = schema.getTableName(); - tableName = tableName.replace("%prefix%", databaseConnection.getDatabaseConfiguration().getTablePrefix()); + tableName = databaseConnection.getDatabaseConfiguration().replacePrefix(tableName); SqlDialect dialect = SqlDialects.from(databaseConnection.getDatabaseConfiguration().getDatabaseType()); mustBeAdd.addAll(dialect.missingColumns(databaseConnection, logger, tableName, schema.getColumns())); diff --git a/src/main/java/fr/maxlego08/sarah/conditions/JoinCondition.java b/src/main/java/fr/maxlego08/sarah/conditions/JoinCondition.java index 388ff92..8d7c997 100644 --- a/src/main/java/fr/maxlego08/sarah/conditions/JoinCondition.java +++ b/src/main/java/fr/maxlego08/sarah/conditions/JoinCondition.java @@ -48,9 +48,9 @@ public String getJoinClause() { public String getJoinClause(SqlDialect dialect) { StringBuilder joinClause = new StringBuilder(); joinClause.append(this.joinType.getSql()).append(" ") - .append(this.primaryTable).append(" AS ").append(this.primaryTableAlias) - .append(" ON ").append(dialect.qualifyIdentifier(this.primaryTableAlias, this.primaryColumn)) - .append(" = ").append(dialect.qualifyIdentifier(this.foreignTable, this.foreignColumn)); + .append(dialect.quoteTableReference(this.primaryTable)).append(" AS ").append(dialect.quoteIdentifier(this.primaryTableAlias)) + .append(" ON ").append(dialect.qualifyIdentifier(dialect.quoteIdentifier(this.primaryTableAlias), this.primaryColumn)) + .append(" = ").append(dialect.qualifyIdentifier(dialect.quoteTableReference(this.foreignTable), this.foreignColumn)); if (this.additionalCondition != null) { joinClause.append(" AND ").append(this.additionalCondition.getCondition(dialect)); @@ -63,7 +63,7 @@ private String getCondition() { } private String getCondition(SqlDialect dialect) { - return dialect.qualifyIdentifier(this.primaryTableAlias, this.primaryColumn) + " = '" + this.foreignColumn + "'"; + return dialect.qualifyIdentifier(dialect.quoteIdentifier(this.primaryTableAlias), this.primaryColumn) + " = '" + this.foreignColumn + "'"; } public enum JoinType { diff --git a/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java b/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java index 0f9a7c0..b80ce90 100644 --- a/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java +++ b/src/main/java/fr/maxlego08/sarah/requests/RenameExecutor.java @@ -28,7 +28,12 @@ public int execute(DatabaseConnection databaseConnection, DatabaseConfiguration StringBuilder alterTableSQL = new StringBuilder("ALTER TABLE "); alterTableSQL.append(dialect.quoteTableReference(this.schema.getTableName())); alterTableSQL.append(" RENAME TO "); - alterTableSQL.append(dialect.quoteTableReference(this.schema.getNewTableName())); + String newTable = this.schema.getNewTableName(); + int lastDot = newTable.lastIndexOf('.'); + if (lastDot != -1) { + newTable = newTable.substring(lastDot + 1); + } + alterTableSQL.append(dialect.quoteIdentifier(newTable)); String finalQuery = databaseConfiguration.replacePrefix(alterTableSQL.toString()); if (databaseConfiguration.isDebug()) { diff --git a/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java b/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java index 149c6d5..00b0f45 100644 --- a/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java +++ b/src/test/java/fr/maxlego08/sarah/PostgreSqlConditionRenderingTest.java @@ -71,7 +71,7 @@ public void testJoinConditionUsesDialectQuoting() { "id", null ); - assertEquals("LEFT JOIN orders AS o ON o.\"user_id\" = users.\"id\"", join.getJoinClause(postgres)); + assertEquals("LEFT JOIN \"orders\" AS \"o\" ON \"o\".\"user_id\" = \"users\".\"id\"", join.getJoinClause(postgres)); } @Test