diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java index 15aeabc53..dff444ae2 100644 --- a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java @@ -38,8 +38,6 @@ public class ProjectionDescriptor extends TransformationDescripto /** * Java implementation of a projection on POJOs via reflection. */ - // TODO: Revise implementation to support multiple field projection, by names - // and indexes. private static class PojoImplementation implements FunctionDescriptor.SerializableFunction { @@ -77,6 +75,50 @@ public Output apply(final Input input) { } } + /** + * Java implementation of a multi-field projection on POJOs via reflection. + * Extracts multiple fields by name and packs them into a {@link Record}. + * {@link Field} objects are resolved lazily on first invocation so the + * instance stays serializable (only {@code String[]} is captured). + */ + private static class MultiFieldPojoImplementation + implements FunctionDescriptor.SerializableFunction { + + private final String[] fieldNames; + + private transient Field[] fields; + + private MultiFieldPojoImplementation(final String[] fieldNames) { + this.fieldNames = fieldNames; + } + + @Override + public Record apply(final Input input) { + if (this.fields == null) { + this.fields = new Field[this.fieldNames.length]; + final Class typeClass = input.getClass(); + for (int i = 0; i < this.fieldNames.length; i++) { + try { + this.fields[i] = typeClass.getField(this.fieldNames[i]); + } catch (final NoSuchFieldException e) { + throw new IllegalStateException( + String.format("Could not find field '%s' on %s.", this.fieldNames[i], typeClass), e); + } + } + } + final Object[] values = new Object[this.fields.length]; + for (int i = 0; i < this.fields.length; i++) { + try { + values[i] = this.fields[i].get(input); + } catch (final IllegalAccessException e) { + throw new RuntimeException( + String.format("Could not access field '%s'.", this.fieldNames[i]), e); + } + } + return new Record(values); + } + } + /** * Java implementation of a projection on {@link Record}s. */ @@ -126,16 +168,71 @@ public static ProjectionDescriptor createForRecords(final Record new RecordType(fieldNames)); } + /** + * Creates a new instance that projects multiple POJO fields by name into a {@link Record}. + * + * @param inputTypeClass the input POJO class + * @param fieldNames names of the public fields to project + * @param the input type + * @return the new instance + */ + public static ProjectionDescriptor createForPojoByNames( + final Class inputTypeClass, final String... fieldNames) { + if (fieldNames.length == 0) { + throw new IllegalArgumentException("At least one field name must be provided."); + } + final FunctionDescriptor.SerializableFunction impl = new MultiFieldPojoImplementation<>(fieldNames); + return new ProjectionDescriptor<>( + impl, + Arrays.asList(fieldNames), + BasicDataUnitType.createBasic(inputTypeClass), + (BasicDataUnitType) BasicDataUnitType.createBasic(Record.class)); + } + + /** + * Creates a new instance that projects POJO fields by their declared index. + * + * @param inputTypeClass the input POJO class + * @param fieldIndexes indexes of the public fields to project (in declaration order) + * @param the input type + * @return the new instance + */ + public static ProjectionDescriptor createForPojoByIndexes( + final Class inputTypeClass, final int... fieldIndexes) { + if (fieldIndexes.length == 0) { + throw new IllegalArgumentException("At least one field index must be provided."); + } + final Field[] allFields = inputTypeClass.getFields(); + final String[] names = new String[fieldIndexes.length]; + for (int i = 0; i < fieldIndexes.length; i++) { + final int idx = fieldIndexes[i]; + if (idx < 0 || idx >= allFields.length) { + throw new IllegalArgumentException( + String.format("Field index %d is out of bounds (0..%d).", idx, allFields.length - 1)); + } + names[i] = allFields[idx].getName(); + } + final FunctionDescriptor.SerializableFunction impl = new MultiFieldPojoImplementation<>(names); + return new ProjectionDescriptor<>( + impl, + Arrays.asList(names), + BasicDataUnitType.createBasic(inputTypeClass), + (BasicDataUnitType) BasicDataUnitType.createBasic(Record.class)); + } + + @SuppressWarnings("unchecked") private static FunctionDescriptor.SerializableFunction createPojoJavaImplementation( final String[] fieldNames, final BasicDataUnitType inputType) { // Get the names of the fields to be projected. - if (fieldNames.length != 1) { - return t -> { - throw new IllegalStateException("The projection descriptor currently supports only a single field."); - }; + if (fieldNames.length == 0) { + throw new IllegalArgumentException("At least one field name must be provided."); + } + if (fieldNames.length == 1) { + return new PojoImplementation<>(fieldNames[0]); } - final String fieldName = fieldNames[0]; - return new PojoImplementation<>(fieldName); + return (FunctionDescriptor.SerializableFunction) + (FunctionDescriptor.SerializableFunction) + new MultiFieldPojoImplementation<>(fieldNames); } private static FunctionDescriptor.SerializableFunction createRecordJavaImplementation( diff --git a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/function/ProjectionDescriptorTest.java b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/function/ProjectionDescriptorTest.java index 72f60994d..60fd81f5f 100644 --- a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/function/ProjectionDescriptorTest.java +++ b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/function/ProjectionDescriptorTest.java @@ -22,10 +22,12 @@ import org.apache.wayang.basic.types.RecordType; import org.junit.jupiter.api.Test; +import java.util.Arrays; import java.util.function.Function; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; /** * Tests for the {@link ProjectionDescriptor}. @@ -65,6 +67,98 @@ void testRecordImplementation() { ); } + @Test + void testMultiFieldPojoProjectionByName() { + final ProjectionDescriptor descriptor = + ProjectionDescriptor.createForPojoByNames(Pojo.class, "string", "integer"); + final Function implementation = descriptor.getJavaImplementation(); + + assertEquals( + new Record("testValue", 1), + implementation.apply(new Pojo("testValue", 1)) + ); + } + + @Test + void testMultiFieldPojoProjectionReordersFields() { + final ProjectionDescriptor descriptor = + ProjectionDescriptor.createForPojoByNames(Pojo.class, "integer", "string"); + final Function implementation = descriptor.getJavaImplementation(); + + assertEquals( + new Record(42, "hello"), + implementation.apply(new Pojo("hello", 42)) + ); + } + + @Test + void testMultiFieldPojoProjectionWithNulls() { + final ProjectionDescriptor descriptor = + ProjectionDescriptor.createForPojoByNames(Pojo.class, "string", "integer"); + final Function implementation = descriptor.getJavaImplementation(); + + assertEquals( + new Record(null, 0), + implementation.apply(new Pojo(null, 0)) + ); + } + + @Test + void testPojoProjectionByIndex() { + // Pojo declares: string (index 0), integer (index 1) + final ProjectionDescriptor descriptor = + ProjectionDescriptor.createForPojoByIndexes(Pojo.class, 1, 0); + final Function implementation = descriptor.getJavaImplementation(); + + assertEquals( + new Record(99, "abc"), + implementation.apply(new Pojo("abc", 99)) + ); + assertEquals(Arrays.asList("integer", "string"), descriptor.getFieldNames()); + } + + @Test + void testPojoProjectionByIndexSingleField() { + // Pojo declares: string (index 0), integer (index 1) + final ProjectionDescriptor descriptor = + ProjectionDescriptor.createForPojoByIndexes(Pojo.class, 0); + final Function implementation = descriptor.getJavaImplementation(); + + assertEquals( + new Record("hello"), + implementation.apply(new Pojo("hello", 5)) + ); + assertEquals(Arrays.asList("string"), descriptor.getFieldNames()); + } + + @Test + void testPojoProjectionByIndexOutOfBounds() { + assertThrows(IllegalArgumentException.class, () -> + ProjectionDescriptor.createForPojoByIndexes(Pojo.class, 5)); + } + + @Test + void testPojoProjectionByNameNonexistentField() { + final ProjectionDescriptor descriptor = + ProjectionDescriptor.createForPojoByNames(Pojo.class, "string", "nonexistent"); + final Function implementation = descriptor.getJavaImplementation(); + + assertThrows(IllegalStateException.class, () -> + implementation.apply(new Pojo("test", 1))); + } + + @Test + void testPojoProjectionEmptyFieldNames() { + assertThrows(IllegalArgumentException.class, () -> + new ProjectionDescriptor<>(Pojo.class, Record.class)); + } + + @Test + void testPojoProjectionByIndexEmpty() { + assertThrows(IllegalArgumentException.class, () -> + ProjectionDescriptor.createForPojoByIndexes(Pojo.class)); + } + public static class Pojo { public String string;