/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.wayang.basic.operators;

import org.apache.wayang.basic.data.Record;
import org.apache.wayang.core.plan.wayangplan.UnarySink;
import org.apache.wayang.core.types.DataSetType;

import java.util.Arrays;
import java.util.Properties;

/**
 * {@link UnarySink} that writes data quanta (e.g. {@link Record}s) to a database table.
 *
 * @param <T> type of the data quanta to be written
 */
public class TableSink<T> extends UnarySink<T> {

    /** Name of the target table; fixed at construction time. */
    private final String tableName;

    /** Column names of the target table; may be {@code null} until derived from the data. */
    private String[] columnNames;

    /** JDBC connection properties; defensively copied on construction and on access. */
    private final Properties props;

    /** Write mode, e.g. {@code "overwrite"} or {@code "append"}. */
    private String mode;

    /**
     * Creates a new instance with the default {@link Record}-based {@link DataSetType}.
     *
     * @param props       database connection properties
     * @param mode        write mode
     * @param tableName   name of the table to be written
     * @param columnNames names of the columns in the table
     */
    @SuppressWarnings("unchecked")
    public TableSink(Properties props, String mode, String tableName, String... columnNames) {
        // The default data set type is Record-based; the cast is safe for T = Record.
        this(props, mode, tableName, columnNames, (DataSetType<T>) DataSetType.createDefault(Record.class));
    }

    /**
     * Creates a new instance.
     *
     * @param props       database connection properties (may be {@code null})
     * @param mode        write mode
     * @param tableName   name of the table to be written
     * @param columnNames names of the columns in the table (may be {@code null})
     * @param type        type of the incoming data quanta
     */
    public TableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType<T> type) {
        super(type);
        this.tableName = tableName;
        // Defensive copy: later mutation of the caller's array must not leak in.
        this.columnNames = columnNames == null ? null : Arrays.copyOf(columnNames, columnNames.length);
        this.props = new Properties();
        if (props != null) {
            this.props.putAll(props);
        }
        this.mode = mode;
    }

    /**
     * Copies an instance (exclusive of broadcasts).
     *
     * @param that that should be copied
     */
    public TableSink(TableSink<T> that) {
        super(that);
        // The accessors already return defensive copies, so no extra copying is needed here.
        this.tableName = that.getTableName();
        this.columnNames = that.getColumnNames();
        this.props = that.getProperties();
        this.mode = that.getMode();
    }

    /** @return the name of the target table */
    public String getTableName() {
        return this.tableName;
    }

    /**
     * Replaces the column names (stored as a defensive copy).
     *
     * @param columnNames the new column names, or {@code null}
     */
    protected void setColumnNames(String[] columnNames) {
        this.columnNames = columnNames == null ? null : Arrays.copyOf(columnNames, columnNames.length);
    }

    /** @return a defensive copy of the column names, or {@code null} if none are set */
    public String[] getColumnNames() {
        return this.columnNames == null ? null : Arrays.copyOf(this.columnNames, this.columnNames.length);
    }

    /** @return a defensive copy of the JDBC connection properties */
    public Properties getProperties() {
        Properties copy = new Properties();
        copy.putAll(this.props);
        return copy;
    }

    /** @return the write mode, e.g. {@code "overwrite"} or {@code "append"} */
    public String getMode() {
        return mode;
    }

    /** @param mode the new write mode */
    public void setMode(String mode) {
        this.mode = mode;
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * Internal representation of database products to avoid external dependencies
 * in wayang-basic.
 *
 * <p>Constant order is part of the contract (ordinals may be persisted), so new
 * products should be appended before {@link #UNKNOWN}.</p>
 */
public enum DatabaseProduct {
    /** PostgreSQL. */
    POSTGRESQL,
    /** MySQL. */
    MYSQL,
    /** Oracle Database. */
    ORACLE,
    /** SQLite. */
    SQLITE,
    /** H2 Database Engine. */
    H2,
    /** Apache Derby. */
    DERBY,
    /** Microsoft SQL Server. */
    MSSQL,
    /** Fallback when the product cannot be determined. */
    UNKNOWN
}
package org.apache.wayang.basic.util;

import org.apache.wayang.basic.data.Record;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.sql.Date;
import java.sql.Timestamp;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Utility for mapping Java types to SQL types across different dialects.
 *
 * <p>The mapping tables are initialized once and treated as read-only afterwards.</p>
 */
public class SqlTypeUtils {

    /** Per-dialect mapping from Java classes to SQL column type names. */
    private static final Map<DatabaseProduct, Map<Class<?>, String>> dialectTypeMaps = new HashMap<>();

    /** Utility class; not instantiable. */
    private SqlTypeUtils() {
    }

    static {
        // Default mappings (Standard SQL); also serves as the fallback for unknown dialects.
        Map<Class<?>, String> defaultMap = new HashMap<>();
        defaultMap.put(Integer.class, "INT");
        defaultMap.put(int.class, "INT");
        defaultMap.put(Long.class, "BIGINT");
        defaultMap.put(long.class, "BIGINT");
        defaultMap.put(Double.class, "DOUBLE");
        defaultMap.put(double.class, "DOUBLE");
        defaultMap.put(Float.class, "FLOAT");
        defaultMap.put(float.class, "FLOAT");
        defaultMap.put(Boolean.class, "BOOLEAN");
        defaultMap.put(boolean.class, "BOOLEAN");
        defaultMap.put(String.class, "VARCHAR(255)");
        defaultMap.put(Date.class, "DATE");
        defaultMap.put(LocalDate.class, "DATE");
        defaultMap.put(Timestamp.class, "TIMESTAMP");
        defaultMap.put(LocalDateTime.class, "TIMESTAMP");

        dialectTypeMaps.put(DatabaseProduct.UNKNOWN, defaultMap);

        // PostgreSQL overrides: PostgreSQL spells the 8-byte float "DOUBLE PRECISION".
        Map<Class<?>, String> pgMap = new HashMap<>(defaultMap);
        pgMap.put(Double.class, "DOUBLE PRECISION");
        pgMap.put(double.class, "DOUBLE PRECISION");
        dialectTypeMaps.put(DatabaseProduct.POSTGRESQL, pgMap);

        // Add more dialects here as needed (MySQL, Oracle, etc.)
    }

    /**
     * Detects the database product from a JDBC URL.
     *
     * @param url JDBC URL
     * @return detected {@link DatabaseProduct}, or {@link DatabaseProduct#UNKNOWN}
     */
    public static DatabaseProduct detectProduct(String url) {
        if (url == null) {
            return DatabaseProduct.UNKNOWN;
        }
        // Locale.ROOT guards against locale-dependent case rules (e.g. the Turkish dotless I).
        String lowerUrl = url.toLowerCase(Locale.ROOT);
        if (lowerUrl.contains("postgresql") || lowerUrl.contains("postgres")) {
            return DatabaseProduct.POSTGRESQL;
        }
        if (lowerUrl.contains("mysql")) {
            return DatabaseProduct.MYSQL;
        }
        if (lowerUrl.contains("oracle")) {
            return DatabaseProduct.ORACLE;
        }
        if (lowerUrl.contains("sqlite")) {
            return DatabaseProduct.SQLITE;
        }
        if (lowerUrl.contains("h2")) {
            return DatabaseProduct.H2;
        }
        if (lowerUrl.contains("derby")) {
            return DatabaseProduct.DERBY;
        }
        if (lowerUrl.contains("mssql") || lowerUrl.contains("sqlserver")) {
            return DatabaseProduct.MSSQL;
        }
        return DatabaseProduct.UNKNOWN;
    }

    /**
     * Returns the SQL type for a given Java class and database product.
     *
     * @param cls     Java class
     * @param product database product
     * @return SQL type string; defaults to {@code VARCHAR(255)} for unmapped classes
     */
    public static String getSqlType(Class<?> cls, DatabaseProduct product) {
        Map<Class<?>, String> typeMap = dialectTypeMaps.getOrDefault(product,
                dialectTypeMaps.get(DatabaseProduct.UNKNOWN));
        return typeMap.getOrDefault(cls, "VARCHAR(255)");
    }

    /**
     * Extracts schema information from a POJO class via its getter methods.
     *
     * <p>Property names are derived JavaBeans-style from {@code getXxx()}/{@code isXxx()}
     * accessors; the resulting fields are sorted by name for deterministic order.</p>
     *
     * @param cls     POJO class
     * @param product database product
     * @return a list of schema fields (empty for {@link Record}, whose schema cannot be
     *         derived without an instance — use {@link #getSchema(Record, DatabaseProduct, String[])})
     */
    public static List<SchemaField> getSchema(Class<?> cls, DatabaseProduct product) {
        List<SchemaField> schema = new ArrayList<>();
        if (cls == Record.class) {
            // For Record.class without an instance, we can't derive names/types easily.
            // Users should use the instance-based getSchema or provide columnNames.
            return schema;
        }

        for (Method method : cls.getMethods()) {
            // Skip anything that is not a zero-arg, value-returning instance accessor.
            if (Modifier.isStatic(method.getModifiers())
                    || method.getParameterCount() > 0
                    || method.getReturnType() == void.class
                    || method.getName().equals("getClass")) {
                continue;
            }

            String name = method.getName();
            String propertyName = null;
            if (name.startsWith("get") && name.length() > 3) {
                propertyName = Character.toLowerCase(name.charAt(3)) + name.substring(4);
            } else if (name.startsWith("is") && name.length() > 2
                    && (method.getReturnType() == boolean.class || method.getReturnType() == Boolean.class)) {
                propertyName = Character.toLowerCase(name.charAt(2)) + name.substring(3);
            }

            if (propertyName != null) {
                schema.add(new SchemaField(propertyName, method.getReturnType(),
                        getSqlType(method.getReturnType(), product)));
            }
        }
        schema.sort(Comparator.comparing(SchemaField::getName));
        return schema;
    }

    /**
     * Extracts schema information from a {@link Record} instance by inspecting its fields.
     *
     * @param record    representative record
     * @param product   database product
     * @param userNames optional user-provided column names; positions beyond its length
     *                  (or all positions, when {@code null}) are named {@code c_<i>}
     * @return a list of schema fields in field order
     */
    public static List<SchemaField> getSchema(Record record, DatabaseProduct product, String[] userNames) {
        List<SchemaField> schema = new ArrayList<>();
        if (record == null) {
            return schema;
        }

        int size = record.size();
        for (int i = 0; i < size; i++) {
            String name = (userNames != null && i < userNames.length) ? userNames[i] : "c_" + i;
            Object val = record.getField(i);
            // A null field carries no type information; fall back to String/VARCHAR.
            Class<?> typeClass = val != null ? val.getClass() : String.class;
            String type = getSqlType(typeClass, product);
            schema.add(new SchemaField(name, typeClass, type));
        }
        return schema;
    }

    /** Immutable (name, Java class, SQL type) triple describing one table column. */
    public static class SchemaField {
        private final String name;
        private final Class<?> javaClass;
        private final String sqlType;

        public SchemaField(String name, Class<?> javaClass, String sqlType) {
            this.name = name;
            this.javaClass = javaClass;
            this.sqlType = sqlType;
        }

        /** @return the column name */
        public String getName() {
            return name;
        }

        /** @return the Java class backing the column */
        public Class<?> getJavaClass() {
            return javaClass;
        }

        /** @return the SQL column type, e.g. {@code "INT"} */
        public String getSqlType() {
            return sqlType;
        }
    }
}
package org.apache.wayang.basic.util;

import org.apache.wayang.basic.data.Record;
import org.junit.jupiter.api.Test;

import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Test suite for {@link SqlTypeUtils}.
 */
public class SqlTypeUtilsTest {

    /** Asserts that {@link SqlTypeUtils#detectProduct(String)} maps {@code url} to {@code expected}. */
    private static void assertProduct(DatabaseProduct expected, String url) {
        assertEquals(expected, SqlTypeUtils.detectProduct(url));
    }

    /** Asserts that {@code cls} maps to the SQL type {@code expected} under {@code product}. */
    private static void assertSqlType(String expected, Class<?> cls, DatabaseProduct product) {
        assertEquals(expected, SqlTypeUtils.getSqlType(cls, product));
    }

    /** Asserts name and SQL type of a single schema field. */
    private static void assertField(SqlTypeUtils.SchemaField field, String name, String sqlType) {
        assertEquals(name, field.getName());
        assertEquals(sqlType, field.getSqlType());
    }

    @Test
    public void testDetectProduct() {
        assertProduct(DatabaseProduct.POSTGRESQL, "jdbc:postgresql://localhost:5432/db");
        assertProduct(DatabaseProduct.MYSQL, "jdbc:mysql://localhost:3306/db");
        assertProduct(DatabaseProduct.ORACLE, "jdbc:oracle:thin:@localhost:1521:xe");
        assertProduct(DatabaseProduct.SQLITE, "jdbc:sqlite:test.db");
        assertProduct(DatabaseProduct.H2, "jdbc:h2:mem:test");
        assertProduct(DatabaseProduct.DERBY, "jdbc:derby:memory:test;create=true");
        assertProduct(DatabaseProduct.MSSQL, "jdbc:sqlserver://localhost:1433;databaseName=db");
        assertProduct(DatabaseProduct.UNKNOWN, "jdbc:unknown:db");
    }

    @Test
    public void testGetSqlTypeDefault() {
        assertSqlType("INT", Integer.class, DatabaseProduct.UNKNOWN);
        assertSqlType("INT", int.class, DatabaseProduct.UNKNOWN);
        assertSqlType("BIGINT", Long.class, DatabaseProduct.UNKNOWN);
        assertSqlType("DOUBLE", Double.class, DatabaseProduct.UNKNOWN);
        assertSqlType("VARCHAR(255)", String.class, DatabaseProduct.UNKNOWN);
        assertSqlType("DATE", java.sql.Date.class, DatabaseProduct.UNKNOWN);
        assertSqlType("TIMESTAMP", java.sql.Timestamp.class, DatabaseProduct.UNKNOWN);
    }

    @Test
    public void testPostgresqlOverrides() {
        // Only the double mappings differ from the standard dialect.
        assertSqlType("INT", Integer.class, DatabaseProduct.POSTGRESQL);
        assertSqlType("DOUBLE PRECISION", Double.class, DatabaseProduct.POSTGRESQL);
        assertSqlType("DOUBLE PRECISION", double.class, DatabaseProduct.POSTGRESQL);
        assertSqlType("VARCHAR(255)", String.class, DatabaseProduct.POSTGRESQL);
    }

    @Test
    public void testGetSchema() {
        List<SqlTypeUtils.SchemaField> fields = SqlTypeUtils.getSchema(TestPojo.class,
                DatabaseProduct.POSTGRESQL);
        // id, name, value, active (from getter/is); 'hidden' has no accessor and is excluded.
        assertEquals(4, fields.size());

        fields.sort((f1, f2) -> f1.getName().compareTo(f2.getName()));

        assertField(fields.get(0), "active", "BOOLEAN");
        assertField(fields.get(1), "id", "INT");
        assertField(fields.get(2), "name", "VARCHAR(255)");
        assertField(fields.get(3), "value", "DOUBLE PRECISION");
    }

    @Test
    public void testGetSchemaRecord() {
        Record record = new Record(1, "test", 1.5);
        List<SqlTypeUtils.SchemaField> fields = SqlTypeUtils.getSchema(record, DatabaseProduct.POSTGRESQL,
                null);

        assertEquals(3, fields.size());
        assertField(fields.get(0), "c_0", "INT");
        assertEquals(Integer.class, fields.get(0).getJavaClass());

        assertField(fields.get(1), "c_1", "VARCHAR(255)");
        assertEquals(String.class, fields.get(1).getJavaClass());

        assertField(fields.get(2), "c_2", "DOUBLE PRECISION");
        assertEquals(Double.class, fields.get(2).getJavaClass());
    }

    @Test
    public void testGetSchemaRecordWithNames() {
        Record record = new Record(1, "test");
        String[] names = { "id", "description" };
        List<SqlTypeUtils.SchemaField> fields = SqlTypeUtils.getSchema(record, DatabaseProduct.POSTGRESQL,
                names);

        assertEquals(2, fields.size());
        assertEquals("id", fields.get(0).getName());
        assertEquals("description", fields.get(1).getName());
    }

    /** Fixture POJO; 'hidden' intentionally has no getter to verify it is excluded from the schema. */
    public static class TestPojo {
        private int id;
        private String name;
        private Double value;
        private boolean active;
        private String hidden;

        public int getId() {
            return id;
        }

        public String getName() {
            return name;
        }

        public Double getValue() {
            return value;
        }

        public boolean isActive() {
            return active;
        }
    }
}
package org.apache.wayang.java.operators;

import org.apache.wayang.basic.data.Record;
import org.apache.wayang.basic.operators.TableSink;
import org.apache.wayang.basic.util.DatabaseProduct;
import org.apache.wayang.basic.util.SqlTypeUtils;
import org.apache.wayang.core.api.exception.WayangException;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
import org.apache.wayang.core.platform.ChannelDescriptor;
import org.apache.wayang.core.platform.ChannelInstance;
import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
import org.apache.wayang.core.types.DataSetType;
import org.apache.wayang.core.util.ReflectionUtils;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.java.channels.CollectionChannel;
import org.apache.wayang.java.channels.JavaChannelInstance;
import org.apache.wayang.java.channels.StreamChannel;
import org.apache.wayang.java.execution.JavaExecutor;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;

/**
 * Java platform implementation of {@link TableSink}: writes the incoming data quanta into a
 * JDBC-accessible database table using a single batched, transactional {@code INSERT}.
 *
 * <p>The table schema is derived from the first data quantum (for {@link Record} inputs) or from
 * the POJO's getters, unless column names were provided explicitly.</p>
 *
 * @param <T> type of the incoming data quanta ({@link Record} or a POJO with getters)
 */
public class JavaTableSink<T> extends TableSink<T> implements JavaExecutionOperator {

    /**
     * Binds {@code value} to parameter {@code index} of {@code ps}, dispatching on the
     * runtime type; unmapped types fall back to their {@code toString()} representation.
     */
    private void setRecordValue(PreparedStatement ps, int index, Object value) throws SQLException {
        if (value == null) {
            ps.setNull(index, java.sql.Types.NULL);
        } else if (value instanceof Integer) {
            ps.setInt(index, (Integer) value);
        } else if (value instanceof Long) {
            ps.setLong(index, (Long) value);
        } else if (value instanceof Double) {
            ps.setDouble(index, (Double) value);
        } else if (value instanceof Float) {
            ps.setFloat(index, (Float) value);
        } else if (value instanceof Boolean) {
            ps.setBoolean(index, (Boolean) value);
        } else if (value instanceof java.sql.Date) {
            ps.setDate(index, (java.sql.Date) value);
        } else if (value instanceof java.sql.Timestamp) {
            ps.setTimestamp(index, (java.sql.Timestamp) value);
        } else {
            ps.setString(index, value.toString());
        }
    }

    /**
     * Creates a new instance whose column names will be derived from the data.
     *
     * @param props     database connection properties
     * @param mode      write mode
     * @param tableName name of the table to be written
     */
    public JavaTableSink(Properties props, String mode, String tableName) {
        this(props, mode, tableName, (String[]) null);
    }

    /**
     * Creates a new instance with a default {@link Record} data set type.
     *
     * @param props       database connection properties
     * @param mode        write mode
     * @param tableName   name of the table to be written
     * @param columnNames names of the columns in the table
     */
    public JavaTableSink(Properties props, String mode, String tableName, String... columnNames) {
        super(props, mode, tableName, columnNames);
    }

    /**
     * Creates a new instance.
     *
     * @param props       database connection properties
     * @param mode        write mode
     * @param tableName   name of the table to be written
     * @param columnNames names of the columns in the table (may be {@code null})
     * @param type        type of the incoming data quanta
     */
    public JavaTableSink(Properties props, String mode, String tableName, String[] columnNames,
            DataSetType<T> type) {
        super(props, mode, tableName, columnNames, type);
    }

    /**
     * Copies an instance (exclusive of broadcasts).
     *
     * @param that that should be copied
     */
    public JavaTableSink(TableSink<T> that) {
        super(that);
    }

    @Override
    public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
            ChannelInstance[] inputs,
            ChannelInstance[] outputs,
            JavaExecutor javaExecutor,
            OptimizationContext.OperatorContext operatorContext) {
        assert inputs.length == 1;
        assert outputs.length == 0;
        JavaChannelInstance input = (JavaChannelInstance) inputs[0];

        Iterator<T> recordIterator = input.<T>provideStream().iterator();

        // No data: skip all JDBC work (do not create or drop any table).
        if (!recordIterator.hasNext()) {
            return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext);
        }

        T firstElement = recordIterator.next();
        Class<?> typeClass = this.getType().getDataUnitType().getTypeClass();

        // Fail fast when the first record cannot fill the user-declared columns.
        if (typeClass == Record.class && this.getColumnNames() != null) {
            Record r = (Record) firstElement;
            if (r.size() < this.getColumnNames().length) {
                throw new WayangException(String.format(
                        "Record length (%d) is less than expected column count (%d)",
                        r.size(), this.getColumnNames().length));
            }
        }

        String url = this.getProperties().getProperty("url");
        DatabaseProduct product = SqlTypeUtils.detectProduct(url);

        // Derive the schema from the POJO's getters, or from the first record otherwise.
        List<SqlTypeUtils.SchemaField> schemaFields;
        if (typeClass != Record.class) {
            schemaFields = SqlTypeUtils.getSchema(typeClass, product);
        } else {
            schemaFields = SqlTypeUtils.getSchema((Record) firstElement, product, this.getColumnNames());
        }

        // Without user-provided column names, fall back to the derived schema's names.
        String[] currentColumnNames = this.getColumnNames();
        if (currentColumnNames == null || currentColumnNames.length == 0) {
            currentColumnNames = new String[schemaFields.size()];
            for (int i = 0; i < schemaFields.size(); i++) {
                currentColumnNames[i] = schemaFields.get(i).getName();
            }
            this.setColumnNames(currentColumnNames);
        }

        // Resolve each column's SQL type; columns absent from the schema default to VARCHAR(255).
        String[] sqlTypes = new String[currentColumnNames.length];
        for (int i = 0; i < currentColumnNames.length; i++) {
            sqlTypes[i] = "VARCHAR(255)";
            for (SqlTypeUtils.SchemaField field : schemaFields) {
                if (field.getName().equals(currentColumnNames[i])) {
                    sqlTypes[i] = field.getSqlType();
                    break;
                }
            }
        }

        final String[] finalColumnNames = currentColumnNames;
        final String[] finalSqlTypes = sqlTypes;

        // getProperties() returns a defensive copy, so this flag does not leak back into the sink.
        Properties writeProps = this.getProperties();
        writeProps.setProperty("streamingBatchInsert", "True");

        try {
            // JDBC 4+ drivers are auto-discovered; only load explicitly when one was supplied,
            // which also avoids an NPE for configurations without a "driver" property.
            String driverClass = writeProps.getProperty("driver");
            if (driverClass != null) {
                Class.forName(driverClass);
            }

            // Identifier quoting differs per dialect: backticks (MySQL), brackets (MSSQL),
            // double quotes (standard SQL).
            String quote = (product == DatabaseProduct.MYSQL) ? "`"
                    : (product == DatabaseProduct.MSSQL) ? "[" : "\"";
            String closingQuote = (product == DatabaseProduct.MSSQL) ? "]" : quote;

            try (Connection conn = DriverManager.getConnection(writeProps.getProperty("url"), writeProps)) {
                conn.setAutoCommit(false);
                try (Statement stmt = conn.createStatement()) {
                    // Null-safe comparison: a missing mode behaves like "append".
                    if ("overwrite".equals(this.getMode())) {
                        stmt.execute("DROP TABLE IF EXISTS " + quote + this.getTableName() + closingQuote);
                    }
                    stmt.execute(this.buildCreateTableSql(quote, closingQuote, finalColumnNames, finalSqlTypes));

                    try (PreparedStatement ps = conn
                            .prepareStatement(this.buildInsertSql(quote, closingQuote, finalColumnNames))) {
                        try {
                            this.pushToStatement(ps, firstElement, typeClass, finalColumnNames);
                            ps.addBatch();

                            recordIterator.forEachRemaining(r -> {
                                try {
                                    this.pushToStatement(ps, r, typeClass, finalColumnNames);
                                    ps.addBatch();
                                } catch (SQLException e) {
                                    throw new RuntimeException("Failed to process record for batch insert", e);
                                }
                            });

                            ps.executeBatch();
                            conn.commit();
                        } catch (Exception e) {
                            // Keep the table consistent: undo the partial batch before rethrowing.
                            conn.rollback();
                            throw e;
                        }
                    }
                }
            }
        } catch (ClassNotFoundException e) {
            throw new WayangException("Could not find database driver", e);
        } catch (SQLException e) {
            throw new WayangException("Database operation failed", e);
        } catch (Exception e) {
            throw new WayangException("Failed to evaluate JavaTableSink", e);
        }

        return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext);
    }

    /** Builds the dialect-quoted {@code CREATE TABLE IF NOT EXISTS ...} statement. */
    private String buildCreateTableSql(String quote, String closingQuote, String[] columnNames,
            String[] sqlTypes) {
        StringBuilder sb = new StringBuilder();
        sb.append("CREATE TABLE IF NOT EXISTS ").append(quote).append(this.getTableName())
                .append(closingQuote).append(" (");
        String separator = "";
        for (int i = 0; i < columnNames.length; i++) {
            sb.append(separator).append(quote).append(columnNames[i]).append(closingQuote).append(" ")
                    .append(sqlTypes[i]);
            separator = ", ";
        }
        return sb.append(")").toString();
    }

    /** Builds the parameterized {@code INSERT INTO ... VALUES (?, ...)} statement. */
    private String buildInsertSql(String quote, String closingQuote, String[] columnNames) {
        StringBuilder sb = new StringBuilder();
        sb.append("INSERT INTO ").append(quote).append(this.getTableName()).append(closingQuote)
                .append(" (");
        String separator = "";
        for (String columnName : columnNames) {
            sb.append(separator).append(quote).append(columnName).append(closingQuote);
            separator = ", ";
        }
        sb.append(") VALUES (");
        separator = "";
        for (int i = 0; i < columnNames.length; i++) {
            sb.append(separator).append("?");
            separator = ", ";
        }
        return sb.append(")").toString();
    }

    /**
     * Binds one element's values to the prepared statement, either positionally (for
     * {@link Record}s) or by reflective property lookup (for POJOs).
     *
     * @throws WayangException if a record has fewer fields than there are columns
     */
    private void pushToStatement(PreparedStatement ps, T element, Class<?> typeClass, String[] columnNames)
            throws SQLException {
        if (typeClass == Record.class) {
            Record r = (Record) element;
            if (r.size() < columnNames.length) {
                throw new WayangException(String.format(
                        "Record length (%d) is less than expected column count (%d)",
                        r.size(), columnNames.length));
            }
            for (int i = 0; i < columnNames.length; i++) {
                setRecordValue(ps, i + 1, r.getField(i));
            }
        } else {
            for (int i = 0; i < columnNames.length; i++) {
                Object val = ReflectionUtils.getProperty(element, columnNames[i]);
                setRecordValue(ps, i + 1, val);
            }
        }
    }

    @Override
    public String getLoadProfileEstimatorConfigurationKey() {
        return "wayang.java.tablesink.load";
    }

    @Override
    public List<ChannelDescriptor> getSupportedInputChannels(int index) {
        return Arrays.asList(CollectionChannel.DESCRIPTOR, StreamChannel.DESCRIPTOR);
    }

    @Override
    public List<ChannelDescriptor> getSupportedOutputChannels(int index) {
        throw new UnsupportedOperationException("This operator has no outputs.");
    }

}
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.java.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.java.channels.StreamChannel; +import org.apache.wayang.java.execution.JavaExecutor; +import org.apache.wayang.java.platform.JavaPlatform; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Properties; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Test suite for {@link JavaTableSink}. 
+ */ +class JavaTableSinkTest extends JavaExecutionOperatorTestBase { + + private static final String JDBC_URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; + private static final String DRIVER = "org.h2.Driver"; + private static final String TABLE_NAME = "test_table"; + + private Connection connection; + + @BeforeEach + void setupTest() throws Exception { + Class.forName(DRIVER); + connection = DriverManager.getConnection(JDBC_URL, "sa", ""); + } + + @AfterEach + void teardownTest() throws Exception { + if (connection != null && !connection.isClosed()) { + try (Statement stmt = connection.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + } + connection.close(); + } + } + + @Test + void testWritingRecordToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name", "value" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + Record record1 = new Record(1, "Alice", 100.5); + Record record2 = new Record(2, "Bob", 200.75); + + inputChannelInstance.accept(Stream.of(record1, record2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"" + TABLE_NAME + 
"\"")) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testWritingPojoToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + null, + DataSetType.createDefault(TestPojo.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + TestPojo p1 = new TestPojo(1, "Alice"); + TestPojo p2 = new TestPojo(2, "Bob"); + + inputChannelInstance.accept(Stream.of(p1, p2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\" ORDER BY \"id\"")) { + rs.next(); + assertEquals(1, rs.getInt("id")); + assertEquals("Alice", rs.getString("name")); + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals("Bob", rs.getString("name")); + } + } + + @Test + void testAppendMode() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. 
Initial write + JavaTableSink sink1 = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input1 = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + input1.accept(Stream.of(new Record(1, "Alice"))); + evaluate(sink1, new ChannelInstance[] { input1 }, new ChannelInstance[0]); + + // 2. Append write + JavaTableSink sink2 = new JavaTableSink<>(dbProps, "append", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + Job job2 = mock(Job.class); + when(job2.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor2 = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job2); + + StreamChannel.Instance input2 = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor2, mock(OptimizationContext.OperatorContext.class), 0); + input2.accept(Stream.of(new Record(2, "Bob"))); + evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"" + TABLE_NAME + "\"")) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testOverwriteWithSchemaMismatch() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. 
Create table with old schema + try (Statement stmt = connection.createStatement()) { + stmt.execute("CREATE TABLE " + TABLE_NAME + " (id INT, name VARCHAR(255))"); + stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')"); + } + + // 2. Overwrite with new schema + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "age", "city" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + input.accept(Stream.of(new Record(2, 30, "Berlin"))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\"")) { + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals(30, rs.getInt("age")); + assertEquals("Berlin", rs.getString("city")); + + // Verify 'name' column is gone + boolean hasName = false; + for (int i = 1; i <= rs.getMetaData().getColumnCount(); i++) { + if ("name".equalsIgnoreCase(rs.getMetaData().getColumnName(i))) { + hasName = true; + } + } + assertFalse(hasName, "Column 'name' should have been dropped"); + } + } + + @Test + void testNullValues() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + 
DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + input.accept(Stream.of(new Record(1, null))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM \"" + TABLE_NAME + "\" WHERE \"id\" = 1")) { + rs.next(); + assertEquals(null, rs.getString(1)); + assertTrue(rs.wasNull()); + } + } + + @Test + void testSupportedTypes() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "is_active", "salary", "score" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + input.accept(Stream.of(new Record(1, true, 5000.50, 95.5f))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\" 
WHERE \"id\" = 1")) { + rs.next(); + assertTrue(rs.getBoolean("is_active")); + assertEquals(5000.50, rs.getDouble("salary"), 0.001); + assertEquals(95.5f, rs.getFloat("score"), 0.001f); + } + } + + public static class TestPojo { + private int id; + private String name; + + public TestPojo() { + } + + public TestPojo(int id, String name) { + this.id = id; + this.name = name; + } + + public int getId() { + return id; + } + + public String getName() { + return name; + } + } +} \ No newline at end of file diff --git a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTextFileSourceTest.java b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTextFileSourceTest.java index f354819fa..d461927e8 100644 --- a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTextFileSourceTest.java +++ b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTextFileSourceTest.java @@ -30,6 +30,7 @@ import java.net.URL; import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertEquals; /** @@ -71,7 +72,7 @@ void testReadLocalFile() { evaluate(source, inputs, outputs); // Verify the outcome. - final List result = outputs[0].provideStream().toList(); + final List result = outputs[0].provideStream().collect(Collectors.toList()); assertEquals(63, result.size()); } @@ -100,7 +101,7 @@ void testReadRemoteFileHTTP() throws Exception { evaluate(source, inputs, outputs); // Verify the outcome. - final List result = outputs[0].provideStream().toList(); + final List result = outputs[0].provideStream().collect(Collectors.toList()); assertEquals(225, result.size()); } finally { if (javaExecutor != null) @@ -123,7 +124,7 @@ void testReadRemoteFileHTTPS() throws Exception { evaluate(source, inputs, outputs); // Verify the outcome. 
- final List result = outputs[0].provideStream().toList(); + final List result = outputs[0].provideStream().collect(Collectors.toList()); assertEquals(64, result.size()); } catch (final Exception e) { Assumptions.assumeTrue(false, "Skipping test due to possible network error: " + e.getMessage()); diff --git a/wayang-platforms/wayang-spark/pom.xml b/wayang-platforms/wayang-spark/pom.xml index 1e89fd15e..02055479d 100644 --- a/wayang-platforms/wayang-spark/pom.xml +++ b/wayang-platforms/wayang-spark/pom.xml @@ -121,5 +121,11 @@ 4.8 + + com.h2database + h2 + 2.2.224 + test + diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java new file mode 100644 index 000000000..c2bcc143a --- /dev/null +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.wayang.spark.operators;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.wayang.basic.data.Record;
import org.apache.wayang.basic.operators.TableSink;
import org.apache.wayang.basic.util.DatabaseProduct;
import org.apache.wayang.basic.util.SqlTypeUtils;
import org.apache.wayang.core.api.exception.WayangException;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
import org.apache.wayang.core.platform.ChannelDescriptor;
import org.apache.wayang.core.platform.ChannelInstance;
import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
import org.apache.wayang.core.types.DataSetType;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.spark.channels.RddChannel;
import org.apache.wayang.spark.execution.SparkExecutor;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Properties;

/**
 * Spark implementation of {@link TableSink} that writes the incoming dataset to a
 * relational table through Spark's JDBC {@code DataFrameWriter}.
 *
 * <p>{@link Record} inputs have their schema inferred from a sampled element; other
 * (bean-style POJO) inputs are mapped via Spark's bean encoder and must not specify
 * custom column names.</p>
 *
 * @param <T> type of the incoming data units
 */
public class SparkTableSink<T> extends TableSink<T> implements SparkExecutionOperator {

    /** Spark-native write mode, kept in sync with the textual mode of {@link TableSink}. */
    private SaveMode saveMode;

    public SparkTableSink(Properties props, String mode, String tableName, String... columnNames) {
        super(props, mode, tableName, columnNames);
        // Resolve directly instead of calling the overridable setMode(String) from a constructor.
        this.saveMode = resolveSaveMode(mode);
    }

    public SparkTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType<T> type) {
        super(props, mode, tableName, columnNames, type);
        this.saveMode = resolveSaveMode(mode);
    }

    /**
     * Copies an instance (exclusive of broadcasts).
     *
     * @param that that should be copied
     */
    public SparkTableSink(TableSink<T> that) {
        super(that);
        this.saveMode = resolveSaveMode(that.getMode());
    }

    @Override
    public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
            ChannelInstance[] inputs,
            ChannelInstance[] outputs,
            SparkExecutor sparkExecutor,
            OptimizationContext.OperatorContext operatorContext) {
        assert inputs.length == 1;
        assert outputs.length == 0;

        JavaRDD<T> inputRdd = ((RddChannel.Instance) inputs[0]).provideRdd();
        @SuppressWarnings("unchecked") // DataSetType carries the data unit class.
        Class<T> typeClass = (Class<T>) this.getType().getDataUnitType().getTypeClass();
        SparkSession sparkSession = SparkSession.builder().sparkContext(sparkExecutor.sc.sc()).getOrCreate();
        SQLContext sqlContext = sparkSession.sqlContext();

        Dataset<Row> df;
        if (typeClass == Record.class) {
            // Sample a single element to infer the SQL schema of generic Records.
            List<T> sample = inputRdd.take(1);
            if (sample.isEmpty()) {
                // NOTE(review): with an empty input nothing is written, so "overwrite"
                // does not truncate a pre-existing table here — confirm this is intended.
                return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext);
            }
            Record first = (Record) sample.get(0);

            List<SqlTypeUtils.SchemaField> schemaFields = SqlTypeUtils.getSchema(first,
                    SqlTypeUtils.detectProduct(this.getProperties().getProperty("url")),
                    this.getColumnNames());

            JavaRDD<Row> rowRdd = inputRdd.map(rec -> RowFactory.create(((Record) rec).getValues()));

            StructField[] fields = new StructField[schemaFields.size()];
            for (int i = 0; i < schemaFields.size(); i++) {
                SqlTypeUtils.SchemaField schemaField = schemaFields.get(i);
                fields[i] = new StructField(schemaField.getName(),
                        getSparkDataType(schemaField.getJavaClass()), true, Metadata.empty());
            }

            // The inferred names are used only locally for DataFrame creation; the
            // operator's own columnNames are deliberately not mutated (shared state).
            df = sqlContext.createDataFrame(rowRdd, new StructType(fields));
        } else {
            // For POJOs, custom columnNames would create ambiguous or misleading
            // mappings — fail fast before doing any work.
            String[] columnNames = this.getColumnNames();
            if (columnNames != null && columnNames.length > 0) {
                throw new WayangException(
                        "columnNames are not supported for POJO inputs in SparkTableSink. "
                                + "Either omit columnNames or use Record inputs if you need custom column mapping.");
            }
            df = sqlContext.createDataFrame(inputRdd, typeClass);
        }

        Properties writeProps = new Properties();
        writeProps.putAll(this.getProperties());
        if (!writeProps.containsKey("batchSize")) {
            // Large default batch size to keep JDBC round-trips low.
            writeProps.setProperty("batchSize", "250000");
        }

        // Quote the table name unless the caller already supplied a quoted or
        // subquery-style target. The quoting characters depend on the DB product.
        String targetTable = this.getTableName();
        if (targetTable != null && !targetTable.startsWith("(") && !targetTable.startsWith("\"")
                && !targetTable.startsWith("`")) {
            DatabaseProduct product = SqlTypeUtils.detectProduct(this.getProperties().getProperty("url"));
            String openingQuote = (product == DatabaseProduct.MYSQL) ? "`"
                    : (product == DatabaseProduct.MSSQL) ? "[" : "\"";
            String closingQuote = (product == DatabaseProduct.MSSQL) ? "]" : openingQuote;
            targetTable = openingQuote + targetTable + closingQuote;
        }

        df.write()
                .mode(this.saveMode)
                .jdbc(this.getProperties().getProperty("url"), targetTable, writeProps);

        return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext);
    }

    /** Maps a Java class to the corresponding Spark SQL data type; defaults to StringType. */
    private static org.apache.spark.sql.types.DataType getSparkDataType(Class<?> cls) {
        if (cls == Integer.class || cls == int.class)
            return DataTypes.IntegerType;
        if (cls == Long.class || cls == long.class)
            return DataTypes.LongType;
        if (cls == Double.class || cls == double.class)
            return DataTypes.DoubleType;
        if (cls == Float.class || cls == float.class)
            return DataTypes.FloatType;
        if (cls == Boolean.class || cls == boolean.class)
            return DataTypes.BooleanType;
        if (cls == java.sql.Date.class || cls == java.time.LocalDate.class)
            return DataTypes.DateType;
        if (cls == java.sql.Timestamp.class || cls == java.time.LocalDateTime.class)
            return DataTypes.TimestampType;
        return DataTypes.StringType;
    }

    /**
     * Sets the write mode and keeps the Spark {@link SaveMode} in sync with the textual
     * mode stored in the superclass, so {@code getMode()} cannot diverge from the actual
     * write behavior.
     *
     * @param mode one of {@code append}, {@code overwrite}, {@code errorIfExists}, {@code ignore}
     * @throws WayangException if {@code mode} is null or unknown
     */
    @Override
    public void setMode(String mode) {
        // Resolve first so an invalid mode leaves this operator unchanged.
        this.saveMode = resolveSaveMode(mode);
        super.setMode(mode);
    }

    /** Translates the textual write mode into Spark's {@link SaveMode}. */
    private static SaveMode resolveSaveMode(String mode) {
        if (mode == null) {
            throw new WayangException("Unspecified write mode for SparkTableSink.");
        }
        switch (mode) {
            case "append":
                return SaveMode.Append;
            case "overwrite":
                return SaveMode.Overwrite;
            case "errorIfExists":
                return SaveMode.ErrorIfExists;
            case "ignore":
                return SaveMode.Ignore;
            default:
                throw new WayangException(
                        String.format("Specified write mode for SparkTableSink does not exist: %s", mode));
        }
    }

    @Override
    public List<ChannelDescriptor> getSupportedInputChannels(int index) {
        return Arrays.asList(RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR);
    }

    @Override
    public List<ChannelDescriptor> getSupportedOutputChannels(int index) {
        throw new UnsupportedOperationException("This operator has no outputs.");
    }

    @Override
    public boolean containsAction() {
        // Writing via JDBC triggers Spark execution eagerly.
        return true;
    }

    @Override
    public String getLoadProfileEstimatorConfigurationKey() {
        return "wayang.spark.tablesink.load";
    }
}
package org.apache.wayang.spark.operators;

import org.apache.wayang.basic.data.Record;
import org.apache.wayang.core.platform.ChannelInstance;
import org.apache.wayang.core.types.DataSetType;
import org.apache.wayang.spark.channels.RddChannel;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Properties;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Test suite for {@link SparkTableSink}, writing data to an in-memory H2 database.
 */
class SparkTableSinkTest extends SparkOperatorTestBase {

    private static final String JDBC_URL = "jdbc:h2:mem:sparktestdb;DB_CLOSE_DELAY=-1";
    private static final String DRIVER = "org.h2.Driver";
    private static final String TABLE_NAME = "spark_test_table";

    private Connection connection;

    @BeforeEach
    void setupTest() throws Exception {
        Class.forName(DRIVER);
        connection = DriverManager.getConnection(JDBC_URL, "sa", "");
    }

    @AfterEach
    void teardownTest() throws Exception {
        if (connection != null && !connection.isClosed()) {
            try (Statement stmt = connection.createStatement()) {
                // The table is created quoted (case-sensitive), so it must be dropped quoted.
                stmt.execute("DROP TABLE IF EXISTS \"" + TABLE_NAME + "\"");
            }
            connection.close();
        }
    }

    /** @return fresh JDBC connection properties shared by all test cases */
    private Properties createDbProps() {
        Properties dbProps = new Properties();
        dbProps.setProperty("url", JDBC_URL);
        dbProps.setProperty("user", "sa");
        dbProps.setProperty("password", "");
        dbProps.setProperty("driver", DRIVER);
        return dbProps;
    }

    @Test
    void testWritingRecordToH2() throws Exception {
        SparkTableSink<Record> sink = new SparkTableSink<>(createDbProps(), "overwrite", TABLE_NAME,
                new String[] { "id", "name", "value" },
                DataSetType.createDefault(Record.class));

        Record record1 = new Record(1, "Alice", 100.5);
        Record record2 = new Record(2, "Bob", 200.75);

        RddChannel.Instance inputChannelInstance = this.createRddChannelInstance(
                Arrays.asList(record1, record2));

        evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"" + TABLE_NAME + "\"")) {
            rs.next();
            assertEquals(2, rs.getInt(1));
        }
    }

    @Test
    void testWritingPojoToH2() throws Exception {
        // No explicit columnNames: for POJOs the columns come from the bean properties.
        SparkTableSink<TestPojo> sink = new SparkTableSink<>(createDbProps(), "overwrite", TABLE_NAME,
                null,
                DataSetType.createDefault(TestPojo.class));

        RddChannel.Instance inputChannelInstance = this.createRddChannelInstance(
                Arrays.asList(new TestPojo(1, "Alice"), new TestPojo(2, "Bob")));

        evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\" ORDER BY \"id\"")) {
            rs.next();
            assertEquals(1, rs.getInt("id"));
            assertEquals("Alice", rs.getString("name"));
            rs.next();
            assertEquals(2, rs.getInt("id"));
            assertEquals("Bob", rs.getString("name"));
        }
    }

    @Test
    void testAppendMode() throws Exception {
        // 1. Initial write.
        SparkTableSink<Record> sink1 = new SparkTableSink<>(createDbProps(), "overwrite", TABLE_NAME,
                new String[] { "id", "name" },
                DataSetType.createDefault(Record.class));
        RddChannel.Instance input1 = this.createRddChannelInstance(Arrays.asList(new Record(1, "Alice")));
        evaluate(sink1, new ChannelInstance[] { input1 }, new ChannelInstance[0]);

        // 2. Append write.
        SparkTableSink<Record> sink2 = new SparkTableSink<>(createDbProps(), "append", TABLE_NAME,
                new String[] { "id", "name" },
                DataSetType.createDefault(Record.class));
        RddChannel.Instance input2 = this.createRddChannelInstance(Arrays.asList(new Record(2, "Bob")));
        evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]);

        // Both rows must survive the append.
        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"" + TABLE_NAME + "\"")) {
            rs.next();
            assertEquals(2, rs.getInt(1));
        }
    }

    @Test
    void testOverwriteWithSchemaMismatch() throws Exception {
        // 1. Create the table with the old schema (quoted, to match the sink's target).
        try (Statement stmt = connection.createStatement()) {
            stmt.execute("CREATE TABLE " + TABLE_NAME + " (\"id\" INT, \"name\" VARCHAR(255))");
            stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')");
        }

        // 2. Overwrite with a new schema.
        SparkTableSink<Record> sink = new SparkTableSink<>(createDbProps(), "overwrite", TABLE_NAME,
                new String[] { "id", "age", "city" },
                DataSetType.createDefault(Record.class));
        RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(2, 30, "Berlin")));
        evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\"")) {
            rs.next();
            assertEquals(2, rs.getInt("id"));
            assertEquals(30, rs.getInt("age"));
            assertEquals("Berlin", rs.getString("city"));

            // Verify the 'name' column from the old schema is gone.
            boolean hasName = false;
            for (int i = 1; i <= rs.getMetaData().getColumnCount(); i++) {
                if ("name".equalsIgnoreCase(rs.getMetaData().getColumnName(i))) {
                    hasName = true;
                }
            }
            assertFalse(hasName, "Column 'name' should have been dropped");
        }
    }

    @Test
    void testNullValues() throws Exception {
        SparkTableSink<Record> sink = new SparkTableSink<>(createDbProps(), "overwrite", TABLE_NAME,
                new String[] { "id", "name" },
                DataSetType.createDefault(Record.class));

        RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(1, null)));
        evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM \"" + TABLE_NAME + "\" WHERE \"id\" = 1")) {
            rs.next();
            assertEquals(null, rs.getString(1));
            // wasNull() distinguishes SQL NULL from an empty string.
            assertTrue(rs.wasNull());
        }
    }

    @Test
    void testSupportedTypes() throws Exception {
        SparkTableSink<Record> sink = new SparkTableSink<>(createDbProps(), "overwrite", TABLE_NAME,
                new String[] { "id", "is_active", "salary", "score" },
                DataSetType.createDefault(Record.class));

        RddChannel.Instance input = this.createRddChannelInstance(
                Arrays.asList(new Record(1, true, 5000.50, 95.5f)));
        evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\" WHERE \"id\" = 1")) {
            rs.next();
            assertTrue(rs.getBoolean("is_active"));
            assertEquals(5000.50, rs.getDouble("salary"), 0.001);
            assertEquals(95.5f, rs.getFloat("score"), 0.001f);
        }
    }

    /** Simple serializable bean used to exercise the POJO write path on Spark. */
    public static class TestPojo implements java.io.Serializable {
        private int id;
        private String name;

        public TestPojo() {
        }

        public TestPojo(int id, String name) {
            this.id = id;
            this.name = name;
        }

        public int getId() {
            return id;
        }

        public String getName() {
            return name;
        }
    }
}