From b2a59c4e3495cad2d007c6ff8c74dfa4f6652dab Mon Sep 17 00:00:00 2001 From: harry Date: Sun, 1 Mar 2026 11:59:33 +0100 Subject: [PATCH 1/2] introduce abstract TableSink operator with Java and Spark platform implementation and tests --- wayang-commons/wayang-basic/pom.xml | 5 + .../wayang/basic/operators/TableSink.java | 96 +++++ .../wayang/basic/util/SqlTypeUtils.java | 187 ++++++++++ .../wayang/basic/util/SqlTypeUtilsTest.java | 122 +++++++ wayang-platforms/wayang-java/pom.xml | 12 + .../wayang/java/operators/JavaTableSink.java | 246 +++++++++++++ .../java/operators/JavaTableSinkTest.java | 328 ++++++++++++++++++ wayang-platforms/wayang-spark/pom.xml | 13 + .../spark/operators/SparkTableSink.java | 185 ++++++++++ .../spark/operators/SparkTableSinkTest.java | 281 +++++++++++++++ 10 files changed, 1475 insertions(+) create mode 100644 wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java create mode 100644 wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java create mode 100644 wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java create mode 100644 wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java create mode 100644 wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java create mode 100644 wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java create mode 100644 wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java diff --git a/wayang-commons/wayang-basic/pom.xml b/wayang-commons/wayang-basic/pom.xml index 1d1b460ae..f8ce0fe0e 100644 --- a/wayang-commons/wayang-basic/pom.xml +++ b/wayang-commons/wayang-basic/pom.xml @@ -120,6 +120,11 @@ 20231013 + + org.apache.calcite + calcite-core + ${calcite.version} + com.azure azure-storage-blob diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java new file mode 100644 index 000000000..0b556519f --- /dev/null +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.plan.wayangplan.UnarySink; +import org.apache.wayang.core.types.DataSetType; + +import java.util.Properties; + +/** + * {@link UnarySink} that writes Records to a database table. + */ + +public class TableSink extends UnarySink { + private final String tableName; + + private String[] columnNames; + + private final Properties props; + + private String mode; + + /** + * Creates a new instance. + * + * @param props database connection properties + * @param mode write mode + * @param tableName name of the table to be written + * @param columnNames names of the columns in the tables + */ + public TableSink(Properties props, String mode, String tableName, String... columnNames) { + this(props, mode, tableName, columnNames, (DataSetType) DataSetType.createDefault(Record.class)); + } + + public TableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + super(type); + this.tableName = tableName; + this.columnNames = columnNames; + this.props = props; + this.mode = mode; + } + + /** + * Copies an instance (exclusive of broadcasts). + * + * @param that that should be copied + */ + public TableSink(TableSink that) { + super(that); + this.tableName = that.getTableName(); + this.columnNames = that.getColumnNames(); + this.props = that.getProperties(); + this.mode = that.getMode(); + } + + public String getTableName() { + return this.tableName; + } + + protected void setColumnNames(String[] columnNames) { + this.columnNames = columnNames; + } + + public String[] getColumnNames() { + return this.columnNames; + } + + public Properties getProperties() { + return this.props; + } + + public String getMode() { + return mode; + } + + public void setMode(String mode) { + this.mode = mode; + } +} diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java new file mode 100644 index 000000000..541600b71 --- /dev/null +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.util; + +import org.apache.calcite.sql.SqlDialect; +import org.apache.wayang.basic.data.Record; + +import java.lang.reflect.Field; +import java.sql.Date; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Utility for mapping Java types to SQL types across different dialects. + */ +public class SqlTypeUtils { + + private static final Map, String>> dialectTypeMaps = new HashMap<>(); + + static { + // Default mappings (Standard SQL) + Map, String> defaultMap = new HashMap<>(); + defaultMap.put(Integer.class, "INT"); + defaultMap.put(int.class, "INT"); + defaultMap.put(Long.class, "BIGINT"); + defaultMap.put(long.class, "BIGINT"); + defaultMap.put(Double.class, "DOUBLE"); + defaultMap.put(double.class, "DOUBLE"); + defaultMap.put(Float.class, "FLOAT"); + defaultMap.put(float.class, "FLOAT"); + defaultMap.put(Boolean.class, "BOOLEAN"); + defaultMap.put(boolean.class, "BOOLEAN"); + defaultMap.put(String.class, "VARCHAR(255)"); + defaultMap.put(Date.class, "DATE"); + defaultMap.put(LocalDate.class, "DATE"); + defaultMap.put(Timestamp.class, "TIMESTAMP"); + defaultMap.put(LocalDateTime.class, "TIMESTAMP"); + + dialectTypeMaps.put(SqlDialect.DatabaseProduct.UNKNOWN, defaultMap); + + // PostgreSQL Overrides + Map, String> pgMap = new HashMap<>(defaultMap); + pgMap.put(Double.class, "DOUBLE PRECISION"); + pgMap.put(double.class, "DOUBLE PRECISION"); + dialectTypeMaps.put(SqlDialect.DatabaseProduct.POSTGRESQL, pgMap); + + // Add more dialects here as needed (MySQL, Oracle, etc.) + } + + /** + * Detects the database product from a JDBC URL. + * + * @param url JDBC URL + * @return detected DatabaseProduct + */ + public static SqlDialect.DatabaseProduct detectProduct(String url) { + if (url == null) + return SqlDialect.DatabaseProduct.UNKNOWN; + String lowerUrl = url.toLowerCase(); + if (lowerUrl.contains("postgresql") || lowerUrl.contains("postgres")) + return SqlDialect.DatabaseProduct.POSTGRESQL; + if (lowerUrl.contains("mysql")) + return SqlDialect.DatabaseProduct.MYSQL; + if (lowerUrl.contains("oracle")) + return SqlDialect.DatabaseProduct.ORACLE; + if (lowerUrl.contains("sqlite")) { + try { + return SqlDialect.DatabaseProduct.valueOf("SQLITE"); + } catch (Exception e) { + return SqlDialect.DatabaseProduct.UNKNOWN; + } + } + if (lowerUrl.contains("h2")) + return SqlDialect.DatabaseProduct.H2; + if (lowerUrl.contains("derby")) + return SqlDialect.DatabaseProduct.DERBY; + if (lowerUrl.contains("mssql") || lowerUrl.contains("sqlserver")) + return SqlDialect.DatabaseProduct.MSSQL; + return SqlDialect.DatabaseProduct.UNKNOWN; + } + + /** + * Returns the SQL type for a given Java class and database product. + * + * @param cls Java class + * @param product database product + * @return SQL type string + */ + public static String getSqlType(Class cls, SqlDialect.DatabaseProduct product) { + Map, String> typeMap = dialectTypeMaps.getOrDefault(product, + dialectTypeMaps.get(SqlDialect.DatabaseProduct.UNKNOWN)); + return typeMap.getOrDefault(cls, "VARCHAR(255)"); + } + + /** + * Extracts schema information from a POJO class or a Record. + * + * @param cls POJO class + * @param product database product + * @return a list of schema fields + */ + public static List getSchema(Class cls, SqlDialect.DatabaseProduct product) { + List schema = new ArrayList<>(); + if (cls == Record.class) { + // For Record.class without an instance, we can't derive names/types easily + // Users should use the instance-based getSchema or provide columnNames + return schema; + } + + for (Field field : cls.getDeclaredFields()) { + if (java.lang.reflect.Modifier.isStatic(field.getModifiers())) { + continue; + } + schema.add(new SchemaField(field.getName(), field.getType(), getSqlType(field.getType(), product))); + } + return schema; + } + + /** + * Extracts schema information from a Record instance by inspecting its fields. + * + * @param record representative record + * @param product database product + * @param userNames optional user-provided column names + * @return a list of schema fields + */ + public static List getSchema(Record record, SqlDialect.DatabaseProduct product, String[] userNames) { + List schema = new ArrayList<>(); + if (record == null) + return schema; + + int size = record.size(); + for (int i = 0; i < size; i++) { + String name = (userNames != null && i < userNames.length) ? userNames[i] : "c_" + i; + Object val = record.getField(i); + Class typeClass = val != null ? val.getClass() : String.class; + String type = getSqlType(typeClass, product); + schema.add(new SchemaField(name, typeClass, type)); + } + return schema; + } + + public static class SchemaField { + private final String name; + private final Class javaClass; + private final String sqlType; + + public SchemaField(String name, Class javaClass, String sqlType) { + this.name = name; + this.javaClass = javaClass; + this.sqlType = sqlType; + } + + public String getName() { + return name; + } + + public Class getJavaClass() { + return javaClass; + } + + public String getSqlType() { + return sqlType; + } + } +} diff --git a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java new file mode 100644 index 000000000..28e043e12 --- /dev/null +++ b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.util; + +import org.apache.calcite.sql.SqlDialect; +import org.apache.wayang.basic.data.Record; +import org.junit.jupiter.api.Test; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SqlTypeUtilsTest { + + @Test + public void testDetectProduct() { + assertEquals(SqlDialect.DatabaseProduct.POSTGRESQL, + SqlTypeUtils.detectProduct("jdbc:postgresql://localhost:5432/db")); + assertEquals(SqlDialect.DatabaseProduct.MYSQL, SqlTypeUtils.detectProduct("jdbc:mysql://localhost:3306/db")); + assertEquals(SqlDialect.DatabaseProduct.ORACLE, + SqlTypeUtils.detectProduct("jdbc:oracle:thin:@localhost:1521:xe")); + assertEquals(SqlDialect.DatabaseProduct.H2, SqlTypeUtils.detectProduct("jdbc:h2:mem:test")); + assertEquals(SqlDialect.DatabaseProduct.DERBY, + SqlTypeUtils.detectProduct("jdbc:derby:memory:test;create=true")); + assertEquals(SqlDialect.DatabaseProduct.MSSQL, + SqlTypeUtils.detectProduct("jdbc:sqlserver://localhost:1433;databaseName=db")); + assertEquals(SqlDialect.DatabaseProduct.UNKNOWN, SqlTypeUtils.detectProduct("jdbc:unknown:db")); + } + + @Test + public void testGetSqlTypeDefault() { + SqlDialect.DatabaseProduct product = SqlDialect.DatabaseProduct.UNKNOWN; + assertEquals("INT", SqlTypeUtils.getSqlType(Integer.class, product)); + assertEquals("INT", SqlTypeUtils.getSqlType(int.class, product)); + assertEquals("BIGINT", SqlTypeUtils.getSqlType(Long.class, product)); + assertEquals("DOUBLE", SqlTypeUtils.getSqlType(Double.class, product)); + assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, product)); + assertEquals("DATE", SqlTypeUtils.getSqlType(Date.class, product)); + assertEquals("TIMESTAMP", SqlTypeUtils.getSqlType(Timestamp.class, product)); + } + + @Test + public void testGetSqlTypePostgres() { + SqlDialect.DatabaseProduct product = SqlDialect.DatabaseProduct.POSTGRESQL; + assertEquals("INT", SqlTypeUtils.getSqlType(Integer.class, product)); + assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(Double.class, product)); + assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(double.class, product)); + assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, product)); + } + + @Test + public void testGetSchema() { + List schema = SqlTypeUtils.getSchema(TestPojo.class, + SqlDialect.DatabaseProduct.POSTGRESQL); + assertEquals(3, schema.size()); + + assertEquals("id", schema.get(0).getName()); + assertEquals("INT", schema.get(0).getSqlType()); + + assertEquals("name", schema.get(1).getName()); + assertEquals("VARCHAR(255)", schema.get(1).getSqlType()); + + assertEquals("value", schema.get(2).getName()); + assertEquals("DOUBLE PRECISION", schema.get(2).getSqlType()); + } + + @Test + public void testGetSchemaRecord() { + Record record = new Record(1, "test", 1.5); + List schema = SqlTypeUtils.getSchema(record, SqlDialect.DatabaseProduct.POSTGRESQL, + null); + + assertEquals(3, schema.size()); + assertEquals("c_0", schema.get(0).getName()); + assertEquals("INT", schema.get(0).getSqlType()); + assertEquals(Integer.class, schema.get(0).getJavaClass()); + + assertEquals("c_1", schema.get(1).getName()); + assertEquals("VARCHAR(255)", schema.get(1).getSqlType()); + assertEquals(String.class, schema.get(1).getJavaClass()); + + assertEquals("c_2", schema.get(2).getName()); + assertEquals("DOUBLE PRECISION", schema.get(2).getSqlType()); + assertEquals(Double.class, schema.get(2).getJavaClass()); + } + + @Test + public void testGetSchemaRecordWithNames() { + Record record = new Record(1, "test"); + String[] names = { "id", "description" }; + List schema = SqlTypeUtils.getSchema(record, SqlDialect.DatabaseProduct.POSTGRESQL, + names); + + assertEquals(2, schema.size()); + assertEquals("id", schema.get(0).getName()); + assertEquals("description", schema.get(1).getName()); + } + + public static class TestPojo { + public int id; + public String name; + public Double value; + } +} diff --git a/wayang-platforms/wayang-java/pom.xml b/wayang-platforms/wayang-java/pom.xml index 9c58a78fb..70966b92d 100644 --- a/wayang-platforms/wayang-java/pom.xml +++ b/wayang-platforms/wayang-java/pom.xml @@ -78,7 +78,19 @@ log4j-slf4j-impl 2.20.0 + + org.postgresql + postgresql + 42.7.2 + test + + + com.h2database + h2 + 2.2.224 + test + org.mockito mockito-core diff --git a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java new file mode 100644 index 000000000..8c2564551 --- /dev/null +++ b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.java.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.basic.util.SqlTypeUtils; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.util.ReflectionUtils; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.java.channels.CollectionChannel; +import org.apache.wayang.java.channels.JavaChannelInstance; +import org.apache.wayang.java.channels.StreamChannel; +import org.apache.wayang.java.execution.JavaExecutor; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Properties; + +public class JavaTableSink extends TableSink implements JavaExecutionOperator { + + private void setRecordValue(PreparedStatement ps, int index, Object value) throws SQLException { + if (value == null) { + ps.setNull(index, java.sql.Types.NULL); + } else if (value instanceof Integer) { + ps.setInt(index, (Integer) value); + } else if (value instanceof Long) { + ps.setLong(index, (Long) value); + } else if (value instanceof Double) { + ps.setDouble(index, (Double) value); + } else if (value instanceof Float) { + ps.setFloat(index, (Float) value); + } else if (value instanceof Boolean) { + ps.setBoolean(index, (Boolean) value); + } else if (value instanceof java.sql.Date) { + ps.setDate(index, (java.sql.Date) value); + } else if (value instanceof java.sql.Timestamp) { + ps.setTimestamp(index, (java.sql.Timestamp) value); + } else { + ps.setString(index, value.toString()); + } + } + + public JavaTableSink(Properties props, String mode, String tableName) { + this(props, mode, tableName, null); + } + + public JavaTableSink(Properties props, String mode, String tableName, String... columnNames) { + super(props, mode, tableName, columnNames); + + } + + public JavaTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + super(props, mode, tableName, columnNames, type); + + } + + public JavaTableSink(TableSink that) { + super(that); + } + + @Override + public Tuple, Collection> evaluate( + ChannelInstance[] inputs, + ChannelInstance[] outputs, + JavaExecutor javaExecutor, + OptimizationContext.OperatorContext operatorContext) { + assert inputs.length == 1; + assert outputs.length == 0; + JavaChannelInstance input = (JavaChannelInstance) inputs[0]; + + // The stream is converted to an Iterator so that we can read the first element + // w/o consuming the entire stream. + Iterator recordIterator = input.provideStream().iterator(); + + if (!recordIterator.hasNext()) { + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + + // We read the first element to derive the Record schema. + T firstElement = recordIterator.next(); + Class typeClass = this.getType().getDataUnitType().getTypeClass(); + + String url = this.getProperties().getProperty("url"); + org.apache.calcite.sql.SqlDialect.DatabaseProduct product = SqlTypeUtils.detectProduct(url); + + List schemaFields; + if (typeClass != Record.class) { + schemaFields = SqlTypeUtils.getSchema(typeClass, product); + } else { + schemaFields = SqlTypeUtils.getSchema((Record) firstElement, product, this.getColumnNames()); + } + + String[] currentColumnNames = this.getColumnNames(); + if (currentColumnNames == null || currentColumnNames.length == 0) { + currentColumnNames = new String[schemaFields.size()]; + for (int i = 0; i < schemaFields.size(); i++) { + currentColumnNames[i] = schemaFields.get(i).getName(); + } + this.setColumnNames(currentColumnNames); + } + + String[] sqlTypes = new String[currentColumnNames.length]; + for (int i = 0; i < currentColumnNames.length; i++) { + sqlTypes[i] = "VARCHAR(255)"; // Default + for (SqlTypeUtils.SchemaField field : schemaFields) { + if (field.getName().equals(currentColumnNames[i])) { + sqlTypes[i] = field.getSqlType(); + break; + } + } + } + + final String[] finalColumnNames = currentColumnNames; + final String[] finalSqlTypes = sqlTypes; + + this.getProperties().setProperty("streamingBatchInsert", "True"); + + Connection conn; + try { + Class.forName(this.getProperties().getProperty("driver")); + conn = DriverManager.getConnection(this.getProperties().getProperty("url"), this.getProperties()); + conn.setAutoCommit(false); + + Statement stmt = conn.createStatement(); + + // Drop existing table if the mode is 'overwrite'. + if (this.getMode().equals("overwrite")) { + stmt.execute("DROP TABLE IF EXISTS " + this.getTableName()); + } + + // Create a new table if the specified table name does not exist yet. + StringBuilder sb = new StringBuilder(); + sb.append("CREATE TABLE IF NOT EXISTS ").append(this.getTableName()).append(" ("); + String separator = ""; + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append("\"").append(finalColumnNames[i]).append("\" ").append(finalSqlTypes[i]); + separator = ", "; + } + sb.append(")"); + stmt.execute(sb.toString()); + + // Create a prepared statement to insert value from the recordIterator. + sb = new StringBuilder(); + sb.append("INSERT INTO ").append(this.getTableName()).append(" ("); + separator = ""; + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append("\"").append(finalColumnNames[i]).append("\""); + separator = ", "; + } + sb.append(") VALUES ("); + separator = ""; + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append("?"); + separator = ", "; + } + sb.append(")"); + PreparedStatement ps = conn.prepareStatement(sb.toString()); + + // The schema Record has to be pushed to the database too. + this.pushToStatement(ps, firstElement, typeClass, finalColumnNames); + ps.addBatch(); + + // Iterate through all remaining records and add them to the prepared statement + recordIterator.forEachRemaining( + r -> { + try { + this.pushToStatement(ps, r, typeClass, finalColumnNames); + ps.addBatch(); + } catch (SQLException e) { + e.printStackTrace(); + } + }); + + ps.executeBatch(); + conn.commit(); + conn.close(); + } catch (ClassNotFoundException e) { + System.out.println("Please specify a correct database driver."); + e.printStackTrace(); + } catch (SQLException e) { + e.printStackTrace(); + } + + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + + private void pushToStatement(PreparedStatement ps, T element, Class typeClass, String[] columnNames) + throws SQLException { + if (typeClass == Record.class) { + Record r = (Record) element; + for (int i = 0; i < columnNames.length; i++) { + setRecordValue(ps, i + 1, r.getField(i)); + } + } else { + for (int i = 0; i < columnNames.length; i++) { + Object val = ReflectionUtils.getProperty(element, columnNames[i]); + setRecordValue(ps, i + 1, val); + } + } + } + + @Override + public String getLoadProfileEstimatorConfigurationKey() { + return "wayang.java.tablesink.load"; + } + + @Override + public List getSupportedInputChannels(int index) { + return Arrays.asList(CollectionChannel.DESCRIPTOR, StreamChannel.DESCRIPTOR); + } + + @Override + public List getSupportedOutputChannels(int index) { + throw new UnsupportedOperationException("This operator has no outputs."); + } + +} diff --git a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java new file mode 100644 index 000000000..02b719e0f --- /dev/null +++ b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.java.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.java.channels.StreamChannel; +import org.apache.wayang.java.execution.JavaExecutor; +import org.apache.wayang.java.platform.JavaPlatform; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Properties; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Test suite for {@link JavaTableSink}. + */ +class JavaTableSinkTest extends JavaExecutionOperatorTestBase { + + private static final String JDBC_URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; + private static final String DRIVER = "org.h2.Driver"; + private static final String TABLE_NAME = "test_table"; + + private Connection connection; + + @BeforeEach + void setupTest() throws Exception { + Class.forName(DRIVER); + connection = DriverManager.getConnection(JDBC_URL, "sa", ""); + } + + @AfterEach + void teardownTest() throws Exception { + if (connection != null && !connection.isClosed()) { + try (Statement stmt = connection.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + } + connection.close(); + } + } + + @Test + void testWritingRecordToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name", "value" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + Record record1 = new Record(1, "Alice", 100.5); + Record record2 = new Record(2, "Bob", 200.75); + + inputChannelInstance.accept(Stream.of(record1, record2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testWritingPojoToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + null, // schema detected via reflection + DataSetType.createDefault(TestPojo.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + TestPojo p1 = new TestPojo(1, "Alice"); + TestPojo p2 = new TestPojo(2, "Bob"); + + inputChannelInstance.accept(Stream.of(p1, p2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " ORDER BY \"id\"")) { + rs.next(); + assertEquals(1, rs.getInt("id")); + assertEquals("Alice", rs.getString("name")); + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals("Bob", rs.getString("name")); + } + } + + @Test + void testAppendMode() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. Initial write (overwrite) + JavaTableSink sink1 = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input1 = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + input1.accept(Stream.of(new Record(1, "Alice"))); + evaluate(sink1, new ChannelInstance[] { input1 }, new ChannelInstance[0]); + + // 2. Append write + JavaTableSink sink2 = new JavaTableSink<>(dbProps, "append", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + Job job2 = mock(Job.class); + when(job2.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor2 = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job2); + + StreamChannel.Instance input2 = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor2, mock(OptimizationContext.OperatorContext.class), 0); + input2.accept(Stream.of(new Record(2, "Bob"))); + evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testOverwriteWithSchemaMismatch() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. Create table with old schema (id, name) + try (Statement stmt = connection.createStatement()) { + stmt.execute("CREATE TABLE " + TABLE_NAME + " (id INT, name VARCHAR(255))"); + stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')"); + } + + // 2. Overwrite with new schema (id, age, city) + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "age", "city" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + input.accept(Stream.of(new Record(2, 30, "Berlin"))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals(30, rs.getInt("age")); + assertEquals("Berlin", rs.getString("city")); + + // Verify 'name' column is gone + boolean hasName = false; + for (int i = 1; i <= rs.getMetaData().getColumnCount(); i++) { + if ("name".equalsIgnoreCase(rs.getMetaData().getColumnName(i))) { + hasName = true; + } + } + assertFalse(hasName, "Column 'name' should have been dropped"); + } + } + + @Test + void testNullValues() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + input.accept(Stream.of(new Record(1, null))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + rs.next(); + assertEquals(null, rs.getString(1)); + assertTrue(rs.wasNull()); + } + } + + @Test + void testSupportedTypes() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "is_active", "salary", "score" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + input.accept(Stream.of(new Record(1, true, 5000.50, 95.5f))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + rs.next(); + assertTrue(rs.getBoolean("is_active")); + assertEquals(5000.50, rs.getDouble("salary"), 0.001); + assertEquals(95.5f, rs.getFloat("score"), 0.001f); + } + } + + public static class TestPojo { + private int id; + private String name; + + public TestPojo() { + } + + public TestPojo(int id, String name) { + this.id = id; + this.name = name; + } + + public int getId() { + return id; + } + + public String getName() { + return name; + } + } +} \ No newline at end of file diff --git a/wayang-platforms/wayang-spark/pom.xml b/wayang-platforms/wayang-spark/pom.xml index 1e89fd15e..abdd225d6 100644 --- a/wayang-platforms/wayang-spark/pom.xml +++ b/wayang-platforms/wayang-spark/pom.xml @@ -121,5 +121,18 @@ 4.8 + + org.postgresql + postgresql + 42.7.2 + compile + + + + com.h2database + h2 + 2.2.224 + test + diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java new file mode 100644 index 000000000..9d5edf62e --- /dev/null +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.wayang.spark.operators; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.basic.util.SqlTypeUtils; +import org.apache.wayang.core.api.exception.WayangException; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.execution.SparkExecutor; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Properties; + +public class SparkTableSink extends TableSink implements SparkExecutionOperator { + + private SaveMode mode; + + public SparkTableSink(Properties props, String mode, String tableName, String... columnNames) { + super(props, mode, tableName, columnNames); + this.setMode(mode); + } + + public SparkTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + super(props, mode, tableName, columnNames, type); + this.setMode(mode); + } + + public SparkTableSink(TableSink that) { + super(that); + this.setMode(that.getMode()); + } + + @Override + public Tuple, Collection> evaluate( + ChannelInstance[] inputs, + ChannelInstance[] outputs, + SparkExecutor sparkExecutor, + OptimizationContext.OperatorContext operatorContext) { + assert inputs.length == 1; + assert outputs.length == 0; + + JavaRDD recordRDD = ((RddChannel.Instance) inputs[0]).provideRdd(); + Class typeClass = (Class) this.getType().getDataUnitType().getTypeClass(); + SparkSession sparkSession = SparkSession.builder().sparkContext(sparkExecutor.sc.sc()).getOrCreate(); + SQLContext sqlContext = sparkSession.sqlContext(); + + Dataset df; + if (typeClass == Record.class) { + // Records need manual schema handling + if (recordRDD.isEmpty()) { + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + Record first = (Record) recordRDD.first(); + + // Centralized Schema Derivation + List schemaFields = SqlTypeUtils.getSchema(first, + SqlTypeUtils.detectProduct(this.getProperties().getProperty("url")), + this.getColumnNames()); + + // Map Record to Row + JavaRDD rowRDD = recordRDD.map(rec -> RowFactory.create(((Record) rec).getValues())); + + // Build Spark Schema + StructField[] fields = new StructField[schemaFields.size()]; + for (int i = 0; i < schemaFields.size(); i++) { + SqlTypeUtils.SchemaField sf = schemaFields.get(i); + org.apache.spark.sql.types.DataType sparkType = getSparkDataType(sf.getJavaClass()); + fields[i] = new StructField(sf.getName(), sparkType, true, Metadata.empty()); + } + + // Update column names in the operator if they were generated + String[] newColNames = schemaFields.stream().map(SqlTypeUtils.SchemaField::getName).toArray(String[]::new); + this.setColumnNames(newColNames); + + df = sqlContext.createDataFrame(rowRDD, new StructType(fields)); + } else { + // POJO Case: Let Spark handle it natively + df = sqlContext.createDataFrame(recordRDD, typeClass); + // If columnNames are provided, we should probably select/rename them, + // but usually createDataFrame(rdd, beanClass) maps fields to columns. + if (this.getColumnNames() != null && this.getColumnNames().length > 0) { + // Optionally filter or reorder columns to match this.getColumnNames() + // For now, Spark's native mapping is preferred. + } + } + + this.getProperties().setProperty("batchSize", "250000"); + df.write() + .mode(this.mode) + .jdbc(this.getProperties().getProperty("url"), this.getTableName(), this.getProperties()); + + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + + private org.apache.spark.sql.types.DataType getSparkDataType(Class cls) { + if (cls == Integer.class || cls == int.class) + return DataTypes.IntegerType; + if (cls == Long.class || cls == long.class) + return DataTypes.LongType; + if (cls == Double.class || cls == double.class) + return DataTypes.DoubleType; + if (cls == Float.class || cls == float.class) + return DataTypes.FloatType; + if (cls == Boolean.class || cls == boolean.class) + return DataTypes.BooleanType; + if (cls == java.sql.Date.class || cls == java.time.LocalDate.class) + return DataTypes.DateType; + if (cls == java.sql.Timestamp.class || cls == java.time.LocalDateTime.class) + return DataTypes.TimestampType; + return DataTypes.StringType; + } + + public void setMode(String mode) { + if (mode == null) { + throw new WayangException("Unspecified write mode for SparkTableSink."); + } else if (mode.equals("append")) { + this.mode = SaveMode.Append; + } else if (mode.equals("overwrite")) { + this.mode = SaveMode.Overwrite; + } else if (mode.equals("errorIfExists")) { + this.mode = SaveMode.ErrorIfExists; + } else if (mode.equals("ignore")) { + this.mode = SaveMode.Ignore; + } else { + throw new WayangException( + String.format("Specified write mode for SparkTableSink does not exist: %s", mode)); + } + } + + @Override + public List getSupportedInputChannels(int index) { + return Arrays.asList(RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR); + } + + @Override + public List getSupportedOutputChannels(int index) { + throw new UnsupportedOperationException("This operator has no outputs."); + } + + @Override + public boolean containsAction() { + return true; + } + + @Override + public String getLoadProfileEstimatorConfigurationKey() { + return "wayang.spark.tablesink.load"; + } +} diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java new file mode 100644 index 000000000..0197c3749 --- /dev/null +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.wayang.spark.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.platform.SparkPlatform; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Test suite for {@link SparkTableSink}. + */ +class SparkTableSinkTest extends SparkOperatorTestBase { + + private static final String JDBC_URL = "jdbc:h2:mem:sparktestdb;DB_CLOSE_DELAY=-1"; + private static final String DRIVER = "org.h2.Driver"; + private static final String TABLE_NAME = "spark_test_table"; + + private Connection connection; + + @BeforeEach + void setupTest() throws Exception { + Class.forName(DRIVER); + connection = DriverManager.getConnection(JDBC_URL, "sa", ""); + } + + @AfterEach + void teardownTest() throws Exception { + if (connection != null && !connection.isClosed()) { + try (Statement stmt = connection.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + } + connection.close(); + } + } + + @Test + void testWritingRecordToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name", "value" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + Record record1 = new Record(1, "Alice", 100.5); + Record record2 = new Record(2, "Bob", 200.75); + + RddChannel.Instance inputChannelInstance = this.createRddChannelInstance( + Arrays.asList(record1, record2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testWritingPojoToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + null, // schema detected via reflection + DataSetType.createDefault(TestPojo.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + TestPojo p1 = new TestPojo(1, "Alice"); + TestPojo p2 = new TestPojo(2, "Bob"); + + RddChannel.Instance inputChannelInstance = this.createRddChannelInstance( + Arrays.asList(p1, p2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " ORDER BY \"id\"")) { + rs.next(); + assertEquals(1, rs.getInt("id")); + assertEquals("Alice", rs.getString("name")); + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals("Bob", rs.getString("name")); + } + } + + @Test + void testAppendMode() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. Initial write (overwrite) + SparkTableSink sink1 = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input1 = this.createRddChannelInstance(Arrays.asList(new Record(1, "Alice"))); + evaluate(sink1, new ChannelInstance[] { input1 }, new ChannelInstance[0]); + + // 2. Append write + SparkTableSink sink2 = new SparkTableSink<>(dbProps, "append", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input2 = this.createRddChannelInstance(Arrays.asList(new Record(2, "Bob"))); + evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testOverwriteWithSchemaMismatch() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. Create table with old schema (id, name) + try (Statement stmt = connection.createStatement()) { + stmt.execute("CREATE TABLE " + TABLE_NAME + " (\"id\" INT, \"name\" VARCHAR(255))"); + stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')"); + } + + // 2. Overwrite with new schema (id, age, city) + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "age", "city" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(2, 30, "Berlin"))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals(30, rs.getInt("age")); + assertEquals("Berlin", rs.getString("city")); + + // Verify 'name' column is gone + boolean hasName = false; + for (int i = 1; i <= rs.getMetaData().getColumnCount(); i++) { + if ("name".equalsIgnoreCase(rs.getMetaData().getColumnName(i))) { + hasName = true; + } + } + assertFalse(hasName, "Column 'name' should have been dropped"); + } + } + + @Test + void testNullValues() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(1, null))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + rs.next(); + assertEquals(null, rs.getString(1)); + assertTrue(rs.wasNull()); + } + } + + @Test + void testSupportedTypes() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "is_active", "salary", "score" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(1, true, 5000.50, 95.5f))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + rs.next(); + assertTrue(rs.getBoolean("is_active")); + assertEquals(5000.50, rs.getDouble("salary"), 0.001); + assertEquals(95.5f, rs.getFloat("score"), 0.001f); + } + } + + public static class TestPojo implements java.io.Serializable { + private int id; + private String name; + + public TestPojo() { + } + + public TestPojo(int id, String name) { + this.id = id; + this.name = name; + } + + public int getId() { + return id; + } + + public String getName() { + return name; + } + } +} \ No newline at end of file From 2d453ed8a142624d3baf2d5eef3bb024a73bab44 Mon Sep 17 00:00:00 2001 From: harry Date: Sat, 7 Mar 2026 15:10:53 +0100 Subject: [PATCH 2/2] hardening table sink abstraction, implementation and testing --- wayang-commons/wayang-basic/pom.xml | 5 - .../wayang/basic/operators/TableSink.java | 15 +- .../wayang/basic/util/DatabaseProduct.java | 34 ++++ .../wayang/basic/util/SqlTypeUtils.java | 64 ++++---- .../wayang/basic/util/SqlTypeUtilsTest.java | 97 +++++++----- wayang-platforms/wayang-java/pom.xml | 6 - .../wayang/java/operators/JavaTableSink.java | 148 ++++++++++-------- .../java/operators/JavaTableSinkTest.java | 20 +-- .../operators/JavaTextFileSourceTest.java | 7 +- wayang-platforms/wayang-spark/pom.xml | 7 - .../spark/operators/SparkTableSink.java | 47 ++++-- .../spark/operators/SparkTableSinkTest.java | 41 ++--- 12 files changed, 280 insertions(+), 211 deletions(-) create mode 100644 wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/DatabaseProduct.java diff --git a/wayang-commons/wayang-basic/pom.xml b/wayang-commons/wayang-basic/pom.xml index f8ce0fe0e..1d1b460ae 100644 --- a/wayang-commons/wayang-basic/pom.xml +++ b/wayang-commons/wayang-basic/pom.xml @@ -120,11 +120,6 @@ 20231013 - - org.apache.calcite - calcite-core - ${calcite.version} - com.azure azure-storage-blob diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java index 0b556519f..e5e4b54af 100644 --- a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java @@ -52,8 +52,11 @@ public TableSink(Properties props, String mode, String tableName, String... colu public TableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { super(type); this.tableName = tableName; - this.columnNames = columnNames; - this.props = props; + this.columnNames = columnNames == null ? null : java.util.Arrays.copyOf(columnNames, columnNames.length); + this.props = new Properties(); + if (props != null) { + this.props.putAll(props); + } this.mode = mode; } @@ -75,15 +78,17 @@ public String getTableName() { } protected void setColumnNames(String[] columnNames) { - this.columnNames = columnNames; + this.columnNames = columnNames == null ? null : java.util.Arrays.copyOf(columnNames, columnNames.length); } public String[] getColumnNames() { - return this.columnNames; + return this.columnNames == null ? null : java.util.Arrays.copyOf(this.columnNames, this.columnNames.length); } public Properties getProperties() { - return this.props; + Properties copy = new Properties(); + copy.putAll(this.props); + return copy; } public String getMode() { diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/DatabaseProduct.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/DatabaseProduct.java new file mode 100644 index 000000000..d87eeb900 --- /dev/null +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/DatabaseProduct.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.util; + +/** + * Internal representation of database products to avoid external dependencies + * in wayang-basic. + */ +public enum DatabaseProduct { + POSTGRESQL, + MYSQL, + ORACLE, + SQLITE, + H2, + DERBY, + MSSQL, + UNKNOWN +} diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java index 541600b71..b32f84f57 100644 --- a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java @@ -18,10 +18,7 @@ package org.apache.wayang.basic.util; -import org.apache.calcite.sql.SqlDialect; import org.apache.wayang.basic.data.Record; - -import java.lang.reflect.Field; import java.sql.Date; import java.sql.Timestamp; import java.time.LocalDate; @@ -36,7 +33,7 @@ */ public class SqlTypeUtils { - private static final Map, String>> dialectTypeMaps = new HashMap<>(); + private static final Map, String>> dialectTypeMaps = new HashMap<>(); static { // Default mappings (Standard SQL) @@ -57,13 +54,13 @@ public class SqlTypeUtils { defaultMap.put(Timestamp.class, "TIMESTAMP"); defaultMap.put(LocalDateTime.class, "TIMESTAMP"); - dialectTypeMaps.put(SqlDialect.DatabaseProduct.UNKNOWN, defaultMap); + dialectTypeMaps.put(DatabaseProduct.UNKNOWN, defaultMap); // PostgreSQL Overrides Map, String> pgMap = new HashMap<>(defaultMap); pgMap.put(Double.class, "DOUBLE PRECISION"); pgMap.put(double.class, "DOUBLE PRECISION"); - dialectTypeMaps.put(SqlDialect.DatabaseProduct.POSTGRESQL, pgMap); + dialectTypeMaps.put(DatabaseProduct.POSTGRESQL, pgMap); // Add more dialects here as needed (MySQL, Oracle, etc.) } @@ -74,30 +71,26 @@ public class SqlTypeUtils { * @param url JDBC URL * @return detected DatabaseProduct */ - public static SqlDialect.DatabaseProduct detectProduct(String url) { + public static DatabaseProduct detectProduct(String url) { if (url == null) - return SqlDialect.DatabaseProduct.UNKNOWN; + return DatabaseProduct.UNKNOWN; String lowerUrl = url.toLowerCase(); if (lowerUrl.contains("postgresql") || lowerUrl.contains("postgres")) - return SqlDialect.DatabaseProduct.POSTGRESQL; + return DatabaseProduct.POSTGRESQL; if (lowerUrl.contains("mysql")) - return SqlDialect.DatabaseProduct.MYSQL; + return DatabaseProduct.MYSQL; if (lowerUrl.contains("oracle")) - return SqlDialect.DatabaseProduct.ORACLE; + return DatabaseProduct.ORACLE; if (lowerUrl.contains("sqlite")) { - try { - return SqlDialect.DatabaseProduct.valueOf("SQLITE"); - } catch (Exception e) { - return SqlDialect.DatabaseProduct.UNKNOWN; - } + return DatabaseProduct.SQLITE; } if (lowerUrl.contains("h2")) - return SqlDialect.DatabaseProduct.H2; + return DatabaseProduct.H2; if (lowerUrl.contains("derby")) - return SqlDialect.DatabaseProduct.DERBY; + return DatabaseProduct.DERBY; if (lowerUrl.contains("mssql") || lowerUrl.contains("sqlserver")) - return SqlDialect.DatabaseProduct.MSSQL; - return SqlDialect.DatabaseProduct.UNKNOWN; + return DatabaseProduct.MSSQL; + return DatabaseProduct.UNKNOWN; } /** @@ -107,9 +100,9 @@ public static SqlDialect.DatabaseProduct detectProduct(String url) { * @param product database product * @return SQL type string */ - public static String getSqlType(Class cls, SqlDialect.DatabaseProduct product) { + public static String getSqlType(Class cls, DatabaseProduct product) { Map, String> typeMap = dialectTypeMaps.getOrDefault(product, - dialectTypeMaps.get(SqlDialect.DatabaseProduct.UNKNOWN)); + dialectTypeMaps.get(DatabaseProduct.UNKNOWN)); return typeMap.getOrDefault(cls, "VARCHAR(255)"); } @@ -120,7 +113,7 @@ public static String getSqlType(Class cls, SqlDialect.DatabaseProduct product * @param product database product * @return a list of schema fields */ - public static List getSchema(Class cls, SqlDialect.DatabaseProduct product) { + public static List getSchema(Class cls, DatabaseProduct product) { List schema = new ArrayList<>(); if (cls == Record.class) { // For Record.class without an instance, we can't derive names/types easily @@ -128,12 +121,29 @@ public static List getSchema(Class cls, SqlDialect.DatabaseProdu return schema; } - for (Field field : cls.getDeclaredFields()) { - if (java.lang.reflect.Modifier.isStatic(field.getModifiers())) { + for (java.lang.reflect.Method method : cls.getMethods()) { + if (java.lang.reflect.Modifier.isStatic(method.getModifiers()) || + method.getParameterCount() > 0 || + method.getReturnType() == void.class || + method.getName().equals("getClass")) { continue; } - schema.add(new SchemaField(field.getName(), field.getType(), getSqlType(field.getType(), product))); + + String name = method.getName(); + String propertyName = null; + if (name.startsWith("get") && name.length() > 3) { + propertyName = Character.toLowerCase(name.charAt(3)) + name.substring(4); + } else if (name.startsWith("is") && name.length() > 2 + && (method.getReturnType() == boolean.class || method.getReturnType() == Boolean.class)) { + propertyName = Character.toLowerCase(name.charAt(2)) + name.substring(3); + } + + if (propertyName != null) { + schema.add(new SchemaField(propertyName, method.getReturnType(), + getSqlType(method.getReturnType(), product))); + } } + schema.sort(java.util.Comparator.comparing(SchemaField::getName)); return schema; } @@ -145,7 +155,7 @@ public static List getSchema(Class cls, SqlDialect.DatabaseProdu * @param userNames optional user-provided column names * @return a list of schema fields */ - public static List getSchema(Record record, SqlDialect.DatabaseProduct product, String[] userNames) { + public static List getSchema(Record record, DatabaseProduct product, String[] userNames) { List schema = new ArrayList<>(); if (record == null) return schema; diff --git a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java index 28e043e12..dfa5a8e5f 100644 --- a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java +++ b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java @@ -18,12 +18,9 @@ package org.apache.wayang.basic.util; -import org.apache.calcite.sql.SqlDialect; import org.apache.wayang.basic.data.Record; import org.junit.jupiter.api.Test; -import java.sql.Date; -import java.sql.Timestamp; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -32,60 +29,68 @@ public class SqlTypeUtilsTest { @Test public void testDetectProduct() { - assertEquals(SqlDialect.DatabaseProduct.POSTGRESQL, + assertEquals(DatabaseProduct.POSTGRESQL, SqlTypeUtils.detectProduct("jdbc:postgresql://localhost:5432/db")); - assertEquals(SqlDialect.DatabaseProduct.MYSQL, SqlTypeUtils.detectProduct("jdbc:mysql://localhost:3306/db")); - assertEquals(SqlDialect.DatabaseProduct.ORACLE, + assertEquals(DatabaseProduct.MYSQL, + SqlTypeUtils.detectProduct("jdbc:mysql://localhost:3306/db")); + assertEquals(DatabaseProduct.ORACLE, SqlTypeUtils.detectProduct("jdbc:oracle:thin:@localhost:1521:xe")); - assertEquals(SqlDialect.DatabaseProduct.H2, SqlTypeUtils.detectProduct("jdbc:h2:mem:test")); - assertEquals(SqlDialect.DatabaseProduct.DERBY, + assertEquals(DatabaseProduct.SQLITE, + SqlTypeUtils.detectProduct("jdbc:sqlite:test.db")); + assertEquals(DatabaseProduct.H2, + SqlTypeUtils.detectProduct("jdbc:h2:mem:test")); + assertEquals(DatabaseProduct.DERBY, SqlTypeUtils.detectProduct("jdbc:derby:memory:test;create=true")); - assertEquals(SqlDialect.DatabaseProduct.MSSQL, + assertEquals(DatabaseProduct.MSSQL, SqlTypeUtils.detectProduct("jdbc:sqlserver://localhost:1433;databaseName=db")); - assertEquals(SqlDialect.DatabaseProduct.UNKNOWN, SqlTypeUtils.detectProduct("jdbc:unknown:db")); + assertEquals(DatabaseProduct.UNKNOWN, SqlTypeUtils.detectProduct("jdbc:unknown:db")); } @Test public void testGetSqlTypeDefault() { - SqlDialect.DatabaseProduct product = SqlDialect.DatabaseProduct.UNKNOWN; - assertEquals("INT", SqlTypeUtils.getSqlType(Integer.class, product)); - assertEquals("INT", SqlTypeUtils.getSqlType(int.class, product)); - assertEquals("BIGINT", SqlTypeUtils.getSqlType(Long.class, product)); - assertEquals("DOUBLE", SqlTypeUtils.getSqlType(Double.class, product)); - assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, product)); - assertEquals("DATE", SqlTypeUtils.getSqlType(Date.class, product)); - assertEquals("TIMESTAMP", SqlTypeUtils.getSqlType(Timestamp.class, product)); + assertEquals("INT", SqlTypeUtils.getSqlType(Integer.class, DatabaseProduct.UNKNOWN)); + assertEquals("INT", SqlTypeUtils.getSqlType(int.class, DatabaseProduct.UNKNOWN)); + assertEquals("BIGINT", SqlTypeUtils.getSqlType(Long.class, DatabaseProduct.UNKNOWN)); + assertEquals("DOUBLE", SqlTypeUtils.getSqlType(Double.class, DatabaseProduct.UNKNOWN)); + assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, DatabaseProduct.UNKNOWN)); + assertEquals("DATE", SqlTypeUtils.getSqlType(java.sql.Date.class, DatabaseProduct.UNKNOWN)); + assertEquals("TIMESTAMP", SqlTypeUtils.getSqlType(java.sql.Timestamp.class, DatabaseProduct.UNKNOWN)); } @Test - public void testGetSqlTypePostgres() { - SqlDialect.DatabaseProduct product = SqlDialect.DatabaseProduct.POSTGRESQL; - assertEquals("INT", SqlTypeUtils.getSqlType(Integer.class, product)); - assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(Double.class, product)); - assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(double.class, product)); - assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, product)); + public void testPostgresqlOverrides() { + assertEquals("INT", SqlTypeUtils.getSqlType(Integer.class, DatabaseProduct.POSTGRESQL)); + assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(Double.class, DatabaseProduct.POSTGRESQL)); + assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(double.class, DatabaseProduct.POSTGRESQL)); + assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, DatabaseProduct.POSTGRESQL)); } @Test public void testGetSchema() { List schema = SqlTypeUtils.getSchema(TestPojo.class, - SqlDialect.DatabaseProduct.POSTGRESQL); - assertEquals(3, schema.size()); + DatabaseProduct.POSTGRESQL); + // id, name, value, active (from getter/is) + assertEquals(4, schema.size()); - assertEquals("id", schema.get(0).getName()); - assertEquals("INT", schema.get(0).getSqlType()); + schema.sort((f1, f2) -> f1.getName().compareTo(f2.getName())); - assertEquals("name", schema.get(1).getName()); - assertEquals("VARCHAR(255)", schema.get(1).getSqlType()); + assertEquals("active", schema.get(0).getName()); + assertEquals("BOOLEAN", schema.get(0).getSqlType()); - assertEquals("value", schema.get(2).getName()); - assertEquals("DOUBLE PRECISION", schema.get(2).getSqlType()); + assertEquals("id", schema.get(1).getName()); + assertEquals("INT", schema.get(1).getSqlType()); + + assertEquals("name", schema.get(2).getName()); + assertEquals("VARCHAR(255)", schema.get(2).getSqlType()); + + assertEquals("value", schema.get(3).getName()); + assertEquals("DOUBLE PRECISION", schema.get(3).getSqlType()); } @Test public void testGetSchemaRecord() { Record record = new Record(1, "test", 1.5); - List schema = SqlTypeUtils.getSchema(record, SqlDialect.DatabaseProduct.POSTGRESQL, + List schema = SqlTypeUtils.getSchema(record, DatabaseProduct.POSTGRESQL, null); assertEquals(3, schema.size()); @@ -106,7 +111,7 @@ public void testGetSchemaRecord() { public void testGetSchemaRecordWithNames() { Record record = new Record(1, "test"); String[] names = { "id", "description" }; - List schema = SqlTypeUtils.getSchema(record, SqlDialect.DatabaseProduct.POSTGRESQL, + List schema = SqlTypeUtils.getSchema(record, DatabaseProduct.POSTGRESQL, names); assertEquals(2, schema.size()); @@ -115,8 +120,26 @@ public void testGetSchemaRecordWithNames() { } public static class TestPojo { - public int id; - public String name; - public Double value; + private int id; + private String name; + private Double value; + private boolean active; + private String hidden; + + public int getId() { + return id; + } + + public String getName() { + return name; + } + + public Double getValue() { + return value; + } + + public boolean isActive() { + return active; + } } } diff --git a/wayang-platforms/wayang-java/pom.xml b/wayang-platforms/wayang-java/pom.xml index 70966b92d..48ef59a84 100644 --- a/wayang-platforms/wayang-java/pom.xml +++ b/wayang-platforms/wayang-java/pom.xml @@ -78,12 +78,6 @@ log4j-slf4j-impl 2.20.0 - - org.postgresql - postgresql - 42.7.2 - test - com.h2database diff --git a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java index 8c2564551..7e12b1403 100644 --- a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java +++ b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java @@ -20,6 +20,7 @@ import org.apache.wayang.basic.data.Record; import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.basic.util.DatabaseProduct; import org.apache.wayang.basic.util.SqlTypeUtils; import org.apache.wayang.core.optimizer.OptimizationContext; import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; @@ -97,20 +98,26 @@ public Tuple, Collection> eval assert outputs.length == 0; JavaChannelInstance input = (JavaChannelInstance) inputs[0]; - // The stream is converted to an Iterator so that we can read the first element - // w/o consuming the entire stream. Iterator recordIterator = input.provideStream().iterator(); if (!recordIterator.hasNext()) { return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); } - // We read the first element to derive the Record schema. T firstElement = recordIterator.next(); Class typeClass = this.getType().getDataUnitType().getTypeClass(); + if (typeClass == Record.class && this.getColumnNames() != null) { + Record r = (Record) firstElement; + if (r.size() < this.getColumnNames().length) { + throw new org.apache.wayang.core.api.exception.WayangException( + String.format("Record length (%d) is less than expected column count (%d)", + r.size(), this.getColumnNames().length)); + } + } + String url = this.getProperties().getProperty("url"); - org.apache.calcite.sql.SqlDialect.DatabaseProduct product = SqlTypeUtils.detectProduct(url); + DatabaseProduct product = SqlTypeUtils.detectProduct(url); List schemaFields; if (typeClass != Record.class) { @@ -130,7 +137,7 @@ public Tuple, Collection> eval String[] sqlTypes = new String[currentColumnNames.length]; for (int i = 0; i < currentColumnNames.length; i++) { - sqlTypes[i] = "VARCHAR(255)"; // Default + sqlTypes[i] = "VARCHAR(255)"; for (SqlTypeUtils.SchemaField field : schemaFields) { if (field.getName().equals(currentColumnNames[i])) { sqlTypes[i] = field.getSqlType(); @@ -142,72 +149,80 @@ public Tuple, Collection> eval final String[] finalColumnNames = currentColumnNames; final String[] finalSqlTypes = sqlTypes; - this.getProperties().setProperty("streamingBatchInsert", "True"); + Properties writeProps = this.getProperties(); + writeProps.setProperty("streamingBatchInsert", "True"); - Connection conn; try { - Class.forName(this.getProperties().getProperty("driver")); - conn = DriverManager.getConnection(this.getProperties().getProperty("url"), this.getProperties()); - conn.setAutoCommit(false); - - Statement stmt = conn.createStatement(); - - // Drop existing table if the mode is 'overwrite'. - if (this.getMode().equals("overwrite")) { - stmt.execute("DROP TABLE IF EXISTS " + this.getTableName()); - } - - // Create a new table if the specified table name does not exist yet. - StringBuilder sb = new StringBuilder(); - sb.append("CREATE TABLE IF NOT EXISTS ").append(this.getTableName()).append(" ("); - String separator = ""; - for (int i = 0; i < finalColumnNames.length; i++) { - sb.append(separator).append("\"").append(finalColumnNames[i]).append("\" ").append(finalSqlTypes[i]); - separator = ", "; - } - sb.append(")"); - stmt.execute(sb.toString()); - - // Create a prepared statement to insert value from the recordIterator. - sb = new StringBuilder(); - sb.append("INSERT INTO ").append(this.getTableName()).append(" ("); - separator = ""; - for (int i = 0; i < finalColumnNames.length; i++) { - sb.append(separator).append("\"").append(finalColumnNames[i]).append("\""); - separator = ", "; - } - sb.append(") VALUES ("); - separator = ""; - for (int i = 0; i < finalColumnNames.length; i++) { - sb.append(separator).append("?"); - separator = ", "; - } - sb.append(")"); - PreparedStatement ps = conn.prepareStatement(sb.toString()); - - // The schema Record has to be pushed to the database too. - this.pushToStatement(ps, firstElement, typeClass, finalColumnNames); - ps.addBatch(); - - // Iterate through all remaining records and add them to the prepared statement - recordIterator.forEachRemaining( - r -> { + Class.forName(writeProps.getProperty("driver")); + String quote = (product == DatabaseProduct.MYSQL) ? "`" + : (product == DatabaseProduct.MSSQL) ? "[" : "\""; + String closingQuote = (product == DatabaseProduct.MSSQL) ? "]" : quote; + + try (Connection conn = DriverManager.getConnection(writeProps.getProperty("url"), writeProps)) { + conn.setAutoCommit(false); + try (Statement stmt = conn.createStatement()) { + if (this.getMode().equals("overwrite")) { + stmt.execute("DROP TABLE IF EXISTS " + quote + this.getTableName() + closingQuote); + } + + StringBuilder sb = new StringBuilder(); + sb.append("CREATE TABLE IF NOT EXISTS ").append(quote).append(this.getTableName()) + .append(closingQuote).append(" ("); + String separator = ""; + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append(quote).append(finalColumnNames[i]).append(closingQuote).append(" ") + .append(finalSqlTypes[i]); + separator = ", "; + } + sb.append(")"); + stmt.execute(sb.toString()); + + sb = new StringBuilder(); + sb.append("INSERT INTO ").append(quote).append(this.getTableName()).append(closingQuote) + .append(" ("); + separator = ""; + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append(quote).append(finalColumnNames[i]).append(closingQuote); + separator = ", "; + } + sb.append(") VALUES ("); + separator = ""; + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append("?"); + separator = ", "; + } + sb.append(")"); + + try (PreparedStatement ps = conn.prepareStatement(sb.toString())) { try { - this.pushToStatement(ps, r, typeClass, finalColumnNames); + this.pushToStatement(ps, firstElement, typeClass, finalColumnNames); ps.addBatch(); - } catch (SQLException e) { - e.printStackTrace(); - } - }); - ps.executeBatch(); - conn.commit(); - conn.close(); + recordIterator.forEachRemaining( + r -> { + try { + this.pushToStatement(ps, r, typeClass, finalColumnNames); + ps.addBatch(); + } catch (SQLException e) { + throw new RuntimeException("Failed to process record for batch insert", e); + } + }); + + ps.executeBatch(); + conn.commit(); + } catch (Exception e) { + conn.rollback(); + throw e; + } + } + } + } } catch (ClassNotFoundException e) { - System.out.println("Please specify a correct database driver."); - e.printStackTrace(); + throw new org.apache.wayang.core.api.exception.WayangException("Could not find database driver", e); } catch (SQLException e) { - e.printStackTrace(); + throw new org.apache.wayang.core.api.exception.WayangException("Database operation failed", e); + } catch (Exception e) { + throw new org.apache.wayang.core.api.exception.WayangException("Failed to evaluate JavaTableSink", e); } return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); @@ -217,6 +232,11 @@ private void pushToStatement(PreparedStatement ps, T element, Class typeClass throws SQLException { if (typeClass == Record.class) { Record r = (Record) element; + if (r.size() < columnNames.length) { + throw new org.apache.wayang.core.api.exception.WayangException( + String.format("Record length (%d) is less than expected column count (%d)", r.size(), + columnNames.length)); + } for (int i = 0; i < columnNames.length; i++) { setRecordValue(ps, i + 1, r.getField(i)); } diff --git a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java index 02b719e0f..d56b8d838 100644 --- a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java +++ b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java @@ -101,7 +101,7 @@ void testWritingRecordToH2() throws Exception { evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"" + TABLE_NAME + "\"")) { rs.next(); assertEquals(2, rs.getInt(1)); } @@ -117,7 +117,7 @@ void testWritingPojoToH2() throws Exception { dbProps.setProperty("driver", DRIVER); JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, - null, // schema detected via reflection + null, DataSetType.createDefault(TestPojo.class)); Job job = mock(Job.class); @@ -136,7 +136,7 @@ void testWritingPojoToH2() throws Exception { evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " ORDER BY \"id\"")) { + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\" ORDER BY \"id\"")) { rs.next(); assertEquals(1, rs.getInt("id")); assertEquals("Alice", rs.getString("name")); @@ -155,7 +155,7 @@ void testAppendMode() throws Exception { dbProps.setProperty("password", ""); dbProps.setProperty("driver", DRIVER); - // 1. Initial write (overwrite) + // 1. Initial write JavaTableSink sink1 = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, new String[] { "id", "name" }, DataSetType.createDefault(Record.class)); @@ -186,7 +186,7 @@ void testAppendMode() throws Exception { evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"" + TABLE_NAME + "\"")) { rs.next(); assertEquals(2, rs.getInt(1)); } @@ -201,13 +201,13 @@ void testOverwriteWithSchemaMismatch() throws Exception { dbProps.setProperty("password", ""); dbProps.setProperty("driver", DRIVER); - // 1. Create table with old schema (id, name) + // 1. Create table with old schema try (Statement stmt = connection.createStatement()) { stmt.execute("CREATE TABLE " + TABLE_NAME + " (id INT, name VARCHAR(255))"); stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')"); } - // 2. Overwrite with new schema (id, age, city) + // 2. Overwrite with new schema JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, new String[] { "id", "age", "city" }, DataSetType.createDefault(Record.class)); @@ -223,7 +223,7 @@ void testOverwriteWithSchemaMismatch() throws Exception { evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME)) { + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\"")) { rs.next(); assertEquals(2, rs.getInt("id")); assertEquals(30, rs.getInt("age")); @@ -265,7 +265,7 @@ void testNullValues() throws Exception { evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM \"" + TABLE_NAME + "\" WHERE \"id\" = 1")) { rs.next(); assertEquals(null, rs.getString(1)); assertTrue(rs.wasNull()); @@ -297,7 +297,7 @@ void testSupportedTypes() throws Exception { evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\" WHERE \"id\" = 1")) { rs.next(); assertTrue(rs.getBoolean("is_active")); assertEquals(5000.50, rs.getDouble("salary"), 0.001); diff --git a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTextFileSourceTest.java b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTextFileSourceTest.java index f354819fa..d461927e8 100644 --- a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTextFileSourceTest.java +++ b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTextFileSourceTest.java @@ -30,6 +30,7 @@ import java.net.URL; import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertEquals; /** @@ -71,7 +72,7 @@ void testReadLocalFile() { evaluate(source, inputs, outputs); // Verify the outcome. - final List result = outputs[0].provideStream().toList(); + final List result = outputs[0].provideStream().collect(Collectors.toList()); assertEquals(63, result.size()); } @@ -100,7 +101,7 @@ void testReadRemoteFileHTTP() throws Exception { evaluate(source, inputs, outputs); // Verify the outcome. - final List result = outputs[0].provideStream().toList(); + final List result = outputs[0].provideStream().collect(Collectors.toList()); assertEquals(225, result.size()); } finally { if (javaExecutor != null) @@ -123,7 +124,7 @@ void testReadRemoteFileHTTPS() throws Exception { evaluate(source, inputs, outputs); // Verify the outcome. - final List result = outputs[0].provideStream().toList(); + final List result = outputs[0].provideStream().collect(Collectors.toList()); assertEquals(64, result.size()); } catch (final Exception e) { Assumptions.assumeTrue(false, "Skipping test due to possible network error: " + e.getMessage()); diff --git a/wayang-platforms/wayang-spark/pom.xml b/wayang-platforms/wayang-spark/pom.xml index abdd225d6..02055479d 100644 --- a/wayang-platforms/wayang-spark/pom.xml +++ b/wayang-platforms/wayang-spark/pom.xml @@ -121,13 +121,6 @@ 4.8 - - org.postgresql - postgresql - 42.7.2 - compile - - com.h2database h2 diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java index 9d5edf62e..c2bcc143a 100644 --- a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java @@ -30,6 +30,7 @@ import org.apache.spark.sql.types.StructType; import org.apache.wayang.basic.data.Record; import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.basic.util.DatabaseProduct; import org.apache.wayang.basic.util.SqlTypeUtils; import org.apache.wayang.core.api.exception.WayangException; import org.apache.wayang.core.optimizer.OptimizationContext; @@ -82,21 +83,18 @@ public Tuple, Collection> eval Dataset df; if (typeClass == Record.class) { - // Records need manual schema handling - if (recordRDD.isEmpty()) { + List sample = recordRDD.take(1); + if (sample.isEmpty()) { return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); } - Record first = (Record) recordRDD.first(); + Record first = (Record) sample.get(0); - // Centralized Schema Derivation List schemaFields = SqlTypeUtils.getSchema(first, SqlTypeUtils.detectProduct(this.getProperties().getProperty("url")), this.getColumnNames()); - // Map Record to Row JavaRDD rowRDD = recordRDD.map(rec -> RowFactory.create(((Record) rec).getValues())); - // Build Spark Schema StructField[] fields = new StructField[schemaFields.size()]; for (int i = 0; i < schemaFields.size(); i++) { SqlTypeUtils.SchemaField sf = schemaFields.get(i); @@ -104,26 +102,41 @@ public Tuple, Collection> eval fields[i] = new StructField(sf.getName(), sparkType, true, Metadata.empty()); } - // Update column names in the operator if they were generated - String[] newColNames = schemaFields.stream().map(SqlTypeUtils.SchemaField::getName).toArray(String[]::new); - this.setColumnNames(newColNames); + // We skip updating column names in the operator to avoid mutating shared state. + // Inferred names are used locally for df creation. df = sqlContext.createDataFrame(rowRDD, new StructType(fields)); } else { - // POJO Case: Let Spark handle it natively df = sqlContext.createDataFrame(recordRDD, typeClass); - // If columnNames are provided, we should probably select/rename them, - // but usually createDataFrame(rdd, beanClass) maps fields to columns. - if (this.getColumnNames() != null && this.getColumnNames().length > 0) { - // Optionally filter or reorder columns to match this.getColumnNames() - // For now, Spark's native mapping is preferred. + // For POJOs, we currently do not support custom columnNames to avoid + // ambiguous or misleading mappings. Fail fast if they are provided. + String[] columnNames = this.getColumnNames(); + if (columnNames != null && columnNames.length > 0) { + throw new WayangException( + "columnNames are not supported for POJO inputs in SparkTableSink. " + + "Either omit columnNames or use Record inputs if you need custom column mapping."); } } - this.getProperties().setProperty("batchSize", "250000"); + Properties writeProps = new Properties(); + writeProps.putAll(this.getProperties()); + if (!writeProps.containsKey("batchSize")) { + writeProps.setProperty("batchSize", "250000"); + } + + String targetTable = this.getTableName(); + if (targetTable != null && !targetTable.startsWith("(") && !targetTable.startsWith("\"") + && !targetTable.startsWith("`")) { + DatabaseProduct product = SqlTypeUtils.detectProduct(this.getProperties().getProperty("url")); + String quote = (product == DatabaseProduct.MYSQL) ? "`" + : (product == DatabaseProduct.MSSQL) ? "[" : "\""; + String closingQuote = (product == DatabaseProduct.MSSQL) ? "]" : quote; + targetTable = quote + targetTable + closingQuote; + } + df.write() .mode(this.mode) - .jdbc(this.getProperties().getProperty("url"), this.getTableName(), this.getProperties()); + .jdbc(this.getProperties().getProperty("url"), targetTable, writeProps); return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); } diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java index 0197c3749..bd9ca40de 100644 --- a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java @@ -18,14 +18,9 @@ package org.apache.wayang.spark.operators; import org.apache.wayang.basic.data.Record; -import org.apache.wayang.core.api.Configuration; -import org.apache.wayang.core.api.Job; -import org.apache.wayang.core.optimizer.OptimizationContext; -import org.apache.wayang.core.plan.wayangplan.OutputSlot; import org.apache.wayang.core.platform.ChannelInstance; import org.apache.wayang.core.types.DataSetType; import org.apache.wayang.spark.channels.RddChannel; -import org.apache.wayang.spark.platform.SparkPlatform; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -40,8 +35,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Test suite for {@link SparkTableSink}. @@ -64,7 +57,7 @@ void setupTest() throws Exception { void teardownTest() throws Exception { if (connection != null && !connection.isClosed()) { try (Statement stmt = connection.createStatement()) { - stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + stmt.execute("DROP TABLE IF EXISTS \"" + TABLE_NAME + "\""); } connection.close(); } @@ -72,7 +65,6 @@ void teardownTest() throws Exception { @Test void testWritingRecordToH2() throws Exception { - Configuration configuration = new Configuration(); Properties dbProps = new Properties(); dbProps.setProperty("url", JDBC_URL); dbProps.setProperty("user", "sa"); @@ -83,9 +75,6 @@ void testWritingRecordToH2() throws Exception { new String[] { "id", "name", "value" }, DataSetType.createDefault(Record.class)); - Job job = mock(Job.class); - when(job.getConfiguration()).thenReturn(configuration); - Record record1 = new Record(1, "Alice", 100.5); Record record2 = new Record(2, "Bob", 200.75); @@ -95,7 +84,7 @@ void testWritingRecordToH2() throws Exception { evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"" + TABLE_NAME + "\"")) { rs.next(); assertEquals(2, rs.getInt(1)); } @@ -103,7 +92,6 @@ void testWritingRecordToH2() throws Exception { @Test void testWritingPojoToH2() throws Exception { - Configuration configuration = new Configuration(); Properties dbProps = new Properties(); dbProps.setProperty("url", JDBC_URL); dbProps.setProperty("user", "sa"); @@ -111,12 +99,9 @@ void testWritingPojoToH2() throws Exception { dbProps.setProperty("driver", DRIVER); SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, - null, // schema detected via reflection + null, DataSetType.createDefault(TestPojo.class)); - Job job = mock(Job.class); - when(job.getConfiguration()).thenReturn(configuration); - TestPojo p1 = new TestPojo(1, "Alice"); TestPojo p2 = new TestPojo(2, "Bob"); @@ -126,7 +111,7 @@ void testWritingPojoToH2() throws Exception { evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " ORDER BY \"id\"")) { + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\" ORDER BY \"id\"")) { rs.next(); assertEquals(1, rs.getInt("id")); assertEquals("Alice", rs.getString("name")); @@ -138,14 +123,13 @@ void testWritingPojoToH2() throws Exception { @Test void testAppendMode() throws Exception { - Configuration configuration = new Configuration(); Properties dbProps = new Properties(); dbProps.setProperty("url", JDBC_URL); dbProps.setProperty("user", "sa"); dbProps.setProperty("password", ""); dbProps.setProperty("driver", DRIVER); - // 1. Initial write (overwrite) + // 1. Initial write SparkTableSink sink1 = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, new String[] { "id", "name" }, DataSetType.createDefault(Record.class)); @@ -162,7 +146,7 @@ void testAppendMode() throws Exception { evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"" + TABLE_NAME + "\"")) { rs.next(); assertEquals(2, rs.getInt(1)); } @@ -170,20 +154,19 @@ void testAppendMode() throws Exception { @Test void testOverwriteWithSchemaMismatch() throws Exception { - Configuration configuration = new Configuration(); Properties dbProps = new Properties(); dbProps.setProperty("url", JDBC_URL); dbProps.setProperty("user", "sa"); dbProps.setProperty("password", ""); dbProps.setProperty("driver", DRIVER); - // 1. Create table with old schema (id, name) + // 1. Create table with old schema try (Statement stmt = connection.createStatement()) { stmt.execute("CREATE TABLE " + TABLE_NAME + " (\"id\" INT, \"name\" VARCHAR(255))"); stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')"); } - // 2. Overwrite with new schema (id, age, city) + // 2. Overwrite with new schema SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, new String[] { "id", "age", "city" }, DataSetType.createDefault(Record.class)); @@ -192,7 +175,7 @@ void testOverwriteWithSchemaMismatch() throws Exception { evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME)) { + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\"")) { rs.next(); assertEquals(2, rs.getInt("id")); assertEquals(30, rs.getInt("age")); @@ -211,7 +194,6 @@ void testOverwriteWithSchemaMismatch() throws Exception { @Test void testNullValues() throws Exception { - Configuration configuration = new Configuration(); Properties dbProps = new Properties(); dbProps.setProperty("url", JDBC_URL); dbProps.setProperty("user", "sa"); @@ -226,7 +208,7 @@ void testNullValues() throws Exception { evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM \"" + TABLE_NAME + "\" WHERE \"id\" = 1")) { rs.next(); assertEquals(null, rs.getString(1)); assertTrue(rs.wasNull()); @@ -235,7 +217,6 @@ void testNullValues() throws Exception { @Test void testSupportedTypes() throws Exception { - Configuration configuration = new Configuration(); Properties dbProps = new Properties(); dbProps.setProperty("url", JDBC_URL); dbProps.setProperty("user", "sa"); @@ -250,7 +231,7 @@ void testSupportedTypes() throws Exception { evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\" WHERE \"id\" = 1")) { rs.next(); assertTrue(rs.getBoolean("is_active")); assertEquals(5000.50, rs.getDouble("salary"), 0.001);