Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ public class ProjectionDescriptor<Input, Output> extends TransformationDescripto
/**
* Java implementation of a projection on POJOs via reflection.
*/
// TODO: Revise implementation to support multiple field projection, by names
// and indexes.
private static class PojoImplementation<Input, Output>
implements FunctionDescriptor.SerializableFunction<Input, Output> {

Expand Down Expand Up @@ -77,6 +75,50 @@ public Output apply(final Input input) {
}
}

/**
* Java implementation of a multi-field projection on POJOs via reflection.
* Extracts multiple fields by name and packs them into a {@link Record}.
* {@link Field} objects are resolved lazily on first invocation so the
* instance stays serializable (only {@code String[]} is captured).
*/
private static class MultiFieldPojoImplementation<Input>
implements FunctionDescriptor.SerializableFunction<Input, Record> {

private final String[] fieldNames;

private transient Field[] fields;

private MultiFieldPojoImplementation(final String[] fieldNames) {
this.fieldNames = fieldNames;
}

@Override
public Record apply(final Input input) {
if (this.fields == null) {
this.fields = new Field[this.fieldNames.length];
final Class<?> typeClass = input.getClass();
for (int i = 0; i < this.fieldNames.length; i++) {
try {
this.fields[i] = typeClass.getField(this.fieldNames[i]);
} catch (final NoSuchFieldException e) {
throw new IllegalStateException(
String.format("Could not find field '%s' on %s.", this.fieldNames[i], typeClass), e);
}
}
}
final Object[] values = new Object[this.fields.length];
for (int i = 0; i < this.fields.length; i++) {
try {
values[i] = this.fields[i].get(input);
} catch (final IllegalAccessException e) {
throw new RuntimeException(
String.format("Could not access field '%s'.", this.fieldNames[i]), e);
}
}
return new Record(values);
}
}

/**
* Java implementation of a projection on {@link Record}s.
*/
Expand Down Expand Up @@ -126,16 +168,71 @@ public static ProjectionDescriptor<Record, Record> createForRecords(final Record
new RecordType(fieldNames));
}

/**
* Creates a new instance that projects multiple POJO fields by name into a {@link Record}.
*
* @param inputTypeClass the input POJO class
* @param fieldNames names of the public fields to project
* @param <Input> the input type
* @return the new instance
*/
public static <Input> ProjectionDescriptor<Input, Record> createForPojoByNames(
final Class<Input> inputTypeClass, final String... fieldNames) {
if (fieldNames.length == 0) {
throw new IllegalArgumentException("At least one field name must be provided.");
}
final FunctionDescriptor.SerializableFunction<Input, Record> impl = new MultiFieldPojoImplementation<>(fieldNames);
return new ProjectionDescriptor<>(
impl,
Arrays.asList(fieldNames),
BasicDataUnitType.createBasic(inputTypeClass),
(BasicDataUnitType<Record>) BasicDataUnitType.createBasic(Record.class));
}

/**
* Creates a new instance that projects POJO fields by their declared index.
*
* @param inputTypeClass the input POJO class
* @param fieldIndexes indexes of the public fields to project (in declaration order)
* @param <Input> the input type
* @return the new instance
*/
public static <Input> ProjectionDescriptor<Input, Record> createForPojoByIndexes(
final Class<Input> inputTypeClass, final int... fieldIndexes) {
if (fieldIndexes.length == 0) {
throw new IllegalArgumentException("At least one field index must be provided.");
}
final Field[] allFields = inputTypeClass.getFields();
final String[] names = new String[fieldIndexes.length];
for (int i = 0; i < fieldIndexes.length; i++) {
final int idx = fieldIndexes[i];
if (idx < 0 || idx >= allFields.length) {
throw new IllegalArgumentException(
String.format("Field index %d is out of bounds (0..%d).", idx, allFields.length - 1));
}
names[i] = allFields[idx].getName();
}
final FunctionDescriptor.SerializableFunction<Input, Record> impl = new MultiFieldPojoImplementation<>(names);
return new ProjectionDescriptor<>(
impl,
Arrays.asList(names),
BasicDataUnitType.createBasic(inputTypeClass),
(BasicDataUnitType<Record>) BasicDataUnitType.createBasic(Record.class));
}

@SuppressWarnings("unchecked")
private static <Input, Output> FunctionDescriptor.SerializableFunction<Input, Output> createPojoJavaImplementation(
final String[] fieldNames, final BasicDataUnitType<Input> inputType) {
// Get the names of the fields to be projected.
if (fieldNames.length != 1) {
return t -> {
throw new IllegalStateException("The projection descriptor currently supports only a single field.");
};
if (fieldNames.length == 0) {
throw new IllegalArgumentException("At least one field name must be provided.");
}
if (fieldNames.length == 1) {
return new PojoImplementation<>(fieldNames[0]);
}
final String fieldName = fieldNames[0];
return new PojoImplementation<>(fieldName);
return (FunctionDescriptor.SerializableFunction<Input, Output>)
(FunctionDescriptor.SerializableFunction<Input, ?>)
new MultiFieldPojoImplementation<>(fieldNames);
}

private static FunctionDescriptor.SerializableFunction<Record, Record> createRecordJavaImplementation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import org.apache.wayang.basic.types.RecordType;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.function.Function;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
* Tests for the {@link ProjectionDescriptor}.
Expand Down Expand Up @@ -65,6 +67,98 @@ void testRecordImplementation() {
);
}

@Test
void testMultiFieldPojoProjectionByName() {
final ProjectionDescriptor<Pojo, Record> descriptor =
ProjectionDescriptor.createForPojoByNames(Pojo.class, "string", "integer");
final Function<Pojo, Record> implementation = descriptor.getJavaImplementation();

assertEquals(
new Record("testValue", 1),
implementation.apply(new Pojo("testValue", 1))
);
}

@Test
void testMultiFieldPojoProjectionReordersFields() {
final ProjectionDescriptor<Pojo, Record> descriptor =
ProjectionDescriptor.createForPojoByNames(Pojo.class, "integer", "string");
final Function<Pojo, Record> implementation = descriptor.getJavaImplementation();

assertEquals(
new Record(42, "hello"),
implementation.apply(new Pojo("hello", 42))
);
}

@Test
void testMultiFieldPojoProjectionWithNulls() {
final ProjectionDescriptor<Pojo, Record> descriptor =
ProjectionDescriptor.createForPojoByNames(Pojo.class, "string", "integer");
final Function<Pojo, Record> implementation = descriptor.getJavaImplementation();

assertEquals(
new Record(null, 0),
implementation.apply(new Pojo(null, 0))
);
}

@Test
void testPojoProjectionByIndex() {
// Pojo declares: string (index 0), integer (index 1)
final ProjectionDescriptor<Pojo, Record> descriptor =
ProjectionDescriptor.createForPojoByIndexes(Pojo.class, 1, 0);
final Function<Pojo, Record> implementation = descriptor.getJavaImplementation();

assertEquals(
new Record(99, "abc"),
implementation.apply(new Pojo("abc", 99))
);
assertEquals(Arrays.asList("integer", "string"), descriptor.getFieldNames());
}

@Test
void testPojoProjectionByIndexSingleField() {
// Pojo declares: string (index 0), integer (index 1)
final ProjectionDescriptor<Pojo, Record> descriptor =
ProjectionDescriptor.createForPojoByIndexes(Pojo.class, 0);
final Function<Pojo, Record> implementation = descriptor.getJavaImplementation();

assertEquals(
new Record("hello"),
implementation.apply(new Pojo("hello", 5))
);
assertEquals(Arrays.asList("string"), descriptor.getFieldNames());
}

@Test
void testPojoProjectionByIndexOutOfBounds() {
assertThrows(IllegalArgumentException.class, () ->
ProjectionDescriptor.createForPojoByIndexes(Pojo.class, 5));
}

@Test
void testPojoProjectionByNameNonexistentField() {
final ProjectionDescriptor<Pojo, Record> descriptor =
ProjectionDescriptor.createForPojoByNames(Pojo.class, "string", "nonexistent");
final Function<Pojo, Record> implementation = descriptor.getJavaImplementation();

assertThrows(IllegalStateException.class, () ->
implementation.apply(new Pojo("test", 1)));
}

@Test
void testPojoProjectionEmptyFieldNames() {
assertThrows(IllegalArgumentException.class, () ->
new ProjectionDescriptor<>(Pojo.class, Record.class));
}

@Test
void testPojoProjectionByIndexEmpty() {
assertThrows(IllegalArgumentException.class, () ->
ProjectionDescriptor.createForPojoByIndexes(Pojo.class));
}

public static class Pojo {

public String string;
Expand Down
Loading