diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java
index 15aeabc53..dff444ae2 100644
--- a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java
+++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java
@@ -38,8 +38,6 @@ public class ProjectionDescriptor extends TransformationDescripto
/**
* Java implementation of a projection on POJOs via reflection.
*/
- // TODO: Revise implementation to support multiple field projection, by names
- // and indexes.
private static class PojoImplementation
implements FunctionDescriptor.SerializableFunction {
@@ -77,6 +75,50 @@ public Output apply(final Input input) {
}
}
+ /**
+ * Java implementation of a multi-field projection on POJOs via reflection.
+ * Extracts multiple fields by name and packs them into a {@link Record}.
+ * {@link Field} objects are resolved lazily on first invocation so the
+ * instance stays serializable (only {@code String[]} is captured).
+ */
+ private static class MultiFieldPojoImplementation
+ implements FunctionDescriptor.SerializableFunction {
+
+ private final String[] fieldNames;
+
+ private transient Field[] fields;
+
+ private MultiFieldPojoImplementation(final String[] fieldNames) {
+ this.fieldNames = fieldNames;
+ }
+
+ @Override
+ public Record apply(final Input input) {
+ if (this.fields == null) {
+ this.fields = new Field[this.fieldNames.length];
+ final Class> typeClass = input.getClass();
+ for (int i = 0; i < this.fieldNames.length; i++) {
+ try {
+ this.fields[i] = typeClass.getField(this.fieldNames[i]);
+ } catch (final NoSuchFieldException e) {
+ throw new IllegalStateException(
+ String.format("Could not find field '%s' on %s.", this.fieldNames[i], typeClass), e);
+ }
+ }
+ }
+ final Object[] values = new Object[this.fields.length];
+ for (int i = 0; i < this.fields.length; i++) {
+ try {
+ values[i] = this.fields[i].get(input);
+ } catch (final IllegalAccessException e) {
+ throw new RuntimeException(
+ String.format("Could not access field '%s'.", this.fieldNames[i]), e);
+ }
+ }
+ return new Record(values);
+ }
+ }
+
/**
* Java implementation of a projection on {@link Record}s.
*/
@@ -126,16 +168,71 @@ public static ProjectionDescriptor createForRecords(final Record
new RecordType(fieldNames));
}
+ /**
+ * Creates a new instance that projects multiple POJO fields by name into a {@link Record}.
+ *
+ * @param inputTypeClass the input POJO class
+ * @param fieldNames names of the public fields to project
+ * @param the input type
+ * @return the new instance
+ */
+ public static ProjectionDescriptor createForPojoByNames(
+ final Class inputTypeClass, final String... fieldNames) {
+ if (fieldNames.length == 0) {
+ throw new IllegalArgumentException("At least one field name must be provided.");
+ }
+ final FunctionDescriptor.SerializableFunction impl = new MultiFieldPojoImplementation<>(fieldNames);
+ return new ProjectionDescriptor<>(
+ impl,
+ Arrays.asList(fieldNames),
+ BasicDataUnitType.createBasic(inputTypeClass),
+ (BasicDataUnitType) BasicDataUnitType.createBasic(Record.class));
+ }
+
+ /**
+ * Creates a new instance that projects POJO fields by their declared index.
+ *
+ * @param inputTypeClass the input POJO class
+ * @param fieldIndexes indexes of the public fields to project (in declaration order)
+ * @param the input type
+ * @return the new instance
+ */
+ public static ProjectionDescriptor createForPojoByIndexes(
+ final Class inputTypeClass, final int... fieldIndexes) {
+ if (fieldIndexes.length == 0) {
+ throw new IllegalArgumentException("At least one field index must be provided.");
+ }
+ final Field[] allFields = inputTypeClass.getFields();
+ final String[] names = new String[fieldIndexes.length];
+ for (int i = 0; i < fieldIndexes.length; i++) {
+ final int idx = fieldIndexes[i];
+ if (idx < 0 || idx >= allFields.length) {
+ throw new IllegalArgumentException(
+ String.format("Field index %d is out of bounds (0..%d).", idx, allFields.length - 1));
+ }
+ names[i] = allFields[idx].getName();
+ }
+ final FunctionDescriptor.SerializableFunction impl = new MultiFieldPojoImplementation<>(names);
+ return new ProjectionDescriptor<>(
+ impl,
+ Arrays.asList(names),
+ BasicDataUnitType.createBasic(inputTypeClass),
+ (BasicDataUnitType) BasicDataUnitType.createBasic(Record.class));
+ }
+
+ @SuppressWarnings("unchecked")
private static FunctionDescriptor.SerializableFunction createPojoJavaImplementation(
final String[] fieldNames, final BasicDataUnitType inputType) {
// Get the names of the fields to be projected.
- if (fieldNames.length != 1) {
- return t -> {
- throw new IllegalStateException("The projection descriptor currently supports only a single field.");
- };
+ if (fieldNames.length == 0) {
+ throw new IllegalArgumentException("At least one field name must be provided.");
+ }
+ if (fieldNames.length == 1) {
+ return new PojoImplementation<>(fieldNames[0]);
}
- final String fieldName = fieldNames[0];
- return new PojoImplementation<>(fieldName);
+ return (FunctionDescriptor.SerializableFunction)
+ (FunctionDescriptor.SerializableFunction)
+ new MultiFieldPojoImplementation<>(fieldNames);
}
private static FunctionDescriptor.SerializableFunction createRecordJavaImplementation(
diff --git a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/function/ProjectionDescriptorTest.java b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/function/ProjectionDescriptorTest.java
index 72f60994d..60fd81f5f 100644
--- a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/function/ProjectionDescriptorTest.java
+++ b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/function/ProjectionDescriptorTest.java
@@ -22,10 +22,12 @@
import org.apache.wayang.basic.types.RecordType;
import org.junit.jupiter.api.Test;
+import java.util.Arrays;
import java.util.function.Function;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
/**
* Tests for the {@link ProjectionDescriptor}.
@@ -65,6 +67,98 @@ void testRecordImplementation() {
);
}
+ @Test
+ void testMultiFieldPojoProjectionByName() {
+ final ProjectionDescriptor descriptor =
+ ProjectionDescriptor.createForPojoByNames(Pojo.class, "string", "integer");
+ final Function implementation = descriptor.getJavaImplementation();
+
+ assertEquals(
+ new Record("testValue", 1),
+ implementation.apply(new Pojo("testValue", 1))
+ );
+ }
+
+ @Test
+ void testMultiFieldPojoProjectionReordersFields() {
+ final ProjectionDescriptor descriptor =
+ ProjectionDescriptor.createForPojoByNames(Pojo.class, "integer", "string");
+ final Function implementation = descriptor.getJavaImplementation();
+
+ assertEquals(
+ new Record(42, "hello"),
+ implementation.apply(new Pojo("hello", 42))
+ );
+ }
+
+ @Test
+ void testMultiFieldPojoProjectionWithNulls() {
+ final ProjectionDescriptor descriptor =
+ ProjectionDescriptor.createForPojoByNames(Pojo.class, "string", "integer");
+ final Function implementation = descriptor.getJavaImplementation();
+
+ assertEquals(
+ new Record(null, 0),
+ implementation.apply(new Pojo(null, 0))
+ );
+ }
+
+ @Test
+ void testPojoProjectionByIndex() {
+ // Pojo declares: string (index 0), integer (index 1)
+ final ProjectionDescriptor descriptor =
+ ProjectionDescriptor.createForPojoByIndexes(Pojo.class, 1, 0);
+ final Function implementation = descriptor.getJavaImplementation();
+
+ assertEquals(
+ new Record(99, "abc"),
+ implementation.apply(new Pojo("abc", 99))
+ );
+ assertEquals(Arrays.asList("integer", "string"), descriptor.getFieldNames());
+ }
+
+ @Test
+ void testPojoProjectionByIndexSingleField() {
+ // Pojo declares: string (index 0), integer (index 1)
+ final ProjectionDescriptor descriptor =
+ ProjectionDescriptor.createForPojoByIndexes(Pojo.class, 0);
+ final Function implementation = descriptor.getJavaImplementation();
+
+ assertEquals(
+ new Record("hello"),
+ implementation.apply(new Pojo("hello", 5))
+ );
+ assertEquals(Arrays.asList("string"), descriptor.getFieldNames());
+ }
+
+ @Test
+ void testPojoProjectionByIndexOutOfBounds() {
+ assertThrows(IllegalArgumentException.class, () ->
+ ProjectionDescriptor.createForPojoByIndexes(Pojo.class, 5));
+ }
+
+ @Test
+ void testPojoProjectionByNameNonexistentField() {
+ final ProjectionDescriptor descriptor =
+ ProjectionDescriptor.createForPojoByNames(Pojo.class, "string", "nonexistent");
+ final Function implementation = descriptor.getJavaImplementation();
+
+ assertThrows(IllegalStateException.class, () ->
+ implementation.apply(new Pojo("test", 1)));
+ }
+
+ @Test
+ void testPojoProjectionEmptyFieldNames() {
+ assertThrows(IllegalArgumentException.class, () ->
+ new ProjectionDescriptor<>(Pojo.class, Record.class));
+ }
+
+ @Test
+ void testPojoProjectionByIndexEmpty() {
+ assertThrows(IllegalArgumentException.class, () ->
+ ProjectionDescriptor.createForPojoByIndexes(Pojo.class));
+ }
+
public static class Pojo {
public String string;