diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 0f58da6af..d820031cd 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -12,6 +12,7 @@ import io.substrait.expression.Expression.SwitchClause; import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; +import io.substrait.expression.FunctionOption; import io.substrait.expression.WindowBound; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; @@ -39,6 +40,7 @@ import io.substrait.type.NamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.LinkedList; @@ -73,6 +75,17 @@ public class SubstraitBuilder { private final SimpleExtension.ExtensionCollection extensions; + /** + * Constructs a new SubstraitBuilder with the default extension collection. + * + *

The builder is initialized with {@link DefaultExtensionCatalog#DEFAULT_COLLECTION}, which + * includes standard Substrait functions for strings, arithmetic, comparison, datetime, and other + * operations. + */ + public SubstraitBuilder() { + this(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + /** * Constructs a new SubstraitBuilder with the specified extension collection. * @@ -83,6 +96,18 @@ public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) { this.extensions = extensions; } + /** + * Gets the extension collection used by this builder. + * + *

This is the collection supplied at construction. For the no-argument constructor that is + * {@link DefaultExtensionCatalog#DEFAULT_COLLECTION} with its standard Substrait functions. + * + * @return the ExtensionCollection this builder was constructed with + */ + public SimpleExtension.ExtensionCollection getExtensions() { + return extensions; + } + // Relations /** @@ -142,13 +167,32 @@ public Aggregate aggregate( return aggregate(groupingsFn, measuresFn, Optional.of(remap), input); } - private Aggregate aggregate( - Function> groupingsFn, - Function> measuresFn, - Optional remap, - Rel input) { - List groupings = groupingsFn.apply(input); - List measures = measuresFn.apply(input); + /** + * Creates an aggregate relation that groups and aggregates data from an input relation. + * + *

This method constructs a Substrait aggregate operation by applying grouping and measure + * functions to the input relation. The grouping function defines how rows are grouped together, + * while the measure function defines the aggregate computations (e.g., SUM, COUNT, AVG) to + * perform on each group. + * + *

The optional remap parameter allows reordering or filtering of output columns, which is + * useful for controlling the final schema of the aggregate result. + * + * @param groupingsFn a function that takes the input relation and returns a list of grouping + * expressions defining how to partition the data + * @param measuresFn a function that takes the input relation and returns a list of aggregate + * measures to compute for each group + * @param remap an optional remapping specification to reorder or filter output columns + * @param input the input relation to aggregate + * @return an Aggregate relation representing the grouping and aggregation operation + */ + public Aggregate aggregate( + final Function> groupingsFn, + final Function> measuresFn, + final Optional remap, + final Rel input) { + final List groupings = groupingsFn.apply(input); + final List measures = measuresFn.apply(input); return Aggregate.builder() .groupings(groupings) .measures(measures) @@ -853,6 +897,26 @@ public Expression.BoolLiteral bool(boolean v) { return Expression.BoolLiteral.builder().value(v).build(); } + /** + * Create i8 literal. + * + * @param value value to create + * @return i8 instance + */ + public Expression.I8Literal i8(final int value) { + return Expression.I8Literal.builder().value(value).build(); + } + + /** + * Create i16 literal. + * + * @param value value to create + * @return i16 instance + */ + public Expression.I16Literal i16(final int value) { + return Expression.I16Literal.builder().value(value).build(); + } + /** * Creates a 32-bit integer literal expression. * @@ -863,6 +927,26 @@ public Expression.I32Literal i32(int v) { return Expression.I32Literal.builder().value(v).build(); } + /** + * Creates a 64-bit integer literal expression. 
+ * + * @param value value to create + * @return i64 instance + */ + public Expression.I64Literal i64(final long value) { + return Expression.I64Literal.builder().value(value).build(); + } + + /** + * Creates a 32-bit floating point literal expression. + * + * @param value the float value + * @return a new {@link Expression.FP32Literal} + */ + public Expression.FP32Literal fp32(final float value) { + return Expression.FP32Literal.builder().value(value).build(); + } + /** * Creates a 64-bit floating point literal expression. * @@ -1439,6 +1523,79 @@ public Expression.ScalarFunctionInvocation or(Expression... args) { return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args); } + /** + * Creates a logical NOT expression that negates a boolean expression. + * + *

This is a convenience method that wraps the boolean NOT function from the Substrait standard + * library. The result is nullable to handle NULL input values according to three-valued logic. + * + * @param expression the boolean expression to negate + * @return a scalar function invocation representing the logical NOT of the input expression + */ + public Expression not(final Expression expression) { + return this.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, + "not:bool", + TypeCreator.NULLABLE.BOOLEAN, + expression); + } + + /** + * Creates a null-check expression that tests whether an expression is null. + * + *

This is a convenience method that wraps the is_null function from the Substrait comparison + * function library. The function evaluates the input expression and returns true if it is null, + * false otherwise. This is commonly used in conditional logic and filtering operations. + * + *

The return type is always a required (non-nullable) boolean, as the null check itself always + * produces a definite true/false result. + * + * @param expression the expression to test for null + * @return a scalar function invocation that returns true if the expression is null, false + * otherwise + */ + public Expression isNull(final Expression expression) { + + final List args = new ArrayList<>(); + args.add(expression); + + return this.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_COMPARISON, + "is_null:any", + TypeCreator.REQUIRED.BOOLEAN, + args, + new ArrayList()); + } + + /** + * Creates a scalar function invocation with function options. + * + *

This overload accepts function options in addition to the argument list; the options + * control function behavior (e.g., rounding modes, overflow handling). + * + * @param urn the extension URN (e.g., {@link DefaultExtensionCatalog#FUNCTIONS_STRING}) + * @param key the function signature (e.g., "substring:str_i32_i32") + * @param returnType the return type of the function + * @param args the function arguments + * @param optionsList the function options controlling behavior + * @return a scalar function invocation expression + */ + public Expression scalarFn( + final String urn, + final String key, + final Type returnType, + final List args, + final List optionsList) { + final SimpleExtension.ScalarFunctionVariant declaration = + extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(urn, key)); + return Expression.ScalarFunctionInvocation.builder() + .declaration(declaration) + .options(optionsList) + .outputType(returnType) + .arguments(args) + .build(); + } + /** * Creates a scalar function invocation with specified arguments. * @@ -1459,6 +1616,29 @@ public Expression.ScalarFunctionInvocation scalarFn( .build(); } + /** + * Creates a scalar function invocation from a list of arguments. 
+ * + * @param urn the extension URN (e.g., {@link DefaultExtensionCatalog#FUNCTIONS_STRING}) + * @param key the function signature (e.g., "substring:str_i32_i32") + * @param returnType the return type of the function + * @param args the function arguments + * @return a scalar function invocation expression + */ + public Expression scalarFn( + final String urn, + final String key, + final Type returnType, + final List args) { + final SimpleExtension.ScalarFunctionVariant declaration = + extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(urn, key)); + return Expression.ScalarFunctionInvocation.builder() + .declaration(declaration) + .outputType(returnType) + .arguments(args) + .build(); + } + /** * Creates a window function invocation with specified arguments and window bounds. * @@ -1532,6 +1712,22 @@ public Plan plan(Plan.Root root) { return Plan.builder().addRoots(root).build(); } + /** + * Creates a Plan.Root, which is the top-level container for a Substrait query plan. + * + *

The {@link Plan.Root} wraps a relational expression tree and associates output column names + * with it. Passing the root to {@link #plan(Plan.Root)} is the final step in building a + * complete Substrait plan that can be serialized and executed by a Substrait consumer. + * + * @param input the root relational expression of the query plan + * @param names the ordered list of output column names corresponding to the input relation's + * output schema + * @return a new {@link Plan.Root} + */ + public Plan.Root root(final Rel input, final List names) { + return Plan.Root.builder().input(input).names(names).build(); + } + /** * Creates a field remapping specification from field indexes. * diff --git a/core/src/test/java/io/substrait/dsl/SubstraitBuilderTest.java b/core/src/test/java/io/substrait/dsl/SubstraitBuilderTest.java new file mode 100644 index 000000000..2a199de03 --- /dev/null +++ b/core/src/test/java/io/substrait/dsl/SubstraitBuilderTest.java @@ -0,0 +1,265 @@ +package io.substrait.dsl; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; +import io.substrait.extension.SimpleExtension; +import io.substrait.plan.Plan; +import io.substrait.relation.AbstractWriteRel; +import io.substrait.relation.Aggregate; +import io.substrait.relation.Cross; +import io.substrait.relation.Fetch; +import io.substrait.relation.Filter; +import io.substrait.relation.Join; +import io.substrait.relation.NamedScan; +import io.substrait.relation.NamedWrite; +import io.substrait.relation.Project; +import io.substrait.relation.Rel.Remap; +import io.substrait.relation.Sort; +import io.substrait.type.Type; +import java.util.Arrays; +import 
java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +class SubstraitBuilderTest extends TestBase { + + private SubstraitBuilder builder; + + @BeforeEach + void newbuilder() { + this.builder = new SubstraitBuilder(); + } + + @Test + void basicCreation() { + assertNotNull(this.builder); + assertInstanceOf(SimpleExtension.ExtensionCollection.class, this.builder.getExtensions()); + } + + @Nested + @DisplayName("Literal and Expression Tests") + class ExpressionTests { + + @Test + void testLiterals() { + assertEquals(true, builder.bool(true).value()); + assertEquals(10, builder.i8(10).value()); + assertEquals(100, builder.i16(100).value()); + assertEquals(1000, builder.i32(1000).value()); + assertEquals(10_000L, builder.i64(10_000L).value()); + assertEquals(1.5f, builder.fp32(1.5f).value()); + assertEquals(2.5, builder.fp64(2.5).value()); + assertEquals("foo", builder.str("foo").value()); + } + + @Test + void testFieldReferences() { + final NamedScan scan = createSimpleScan(); + final FieldReference ref = builder.fieldReference(scan, 0); + assertNotNull(ref); + + final List refs = builder.fieldReferences(scan, 0, 1); + assertEquals(2, refs.size()); + } + + @Test + void testCast() { + final Expression.I32Literal input = builder.i32(1); + final Expression cast = builder.cast(input, R.I64); + assertNotNull(cast); + assertTrue(cast instanceof Expression.Cast); + } + } + + @Nested + @DisplayName("Relation Building Tests") + class RelationTests { + + @Test + void testNamedScan() { + final NamedScan scan = + builder.namedScan(List.of("table"), List.of("c1", "c2"), List.of(R.I32, R.STRING)); + assertNotNull(scan); + assertEquals(List.of("table"), scan.getNames()); + } + + @Test + void testProject() { + final NamedScan scan = createSimpleScan(); + final Project project = + builder.project(rel -> List.of(builder.i32(1), builder.fieldReference(rel, 0)), 
scan); + assertNotNull(project); + assertEquals(scan, project.getInput()); + } + + @Test + void testFilter() { + final NamedScan scan = createSimpleScan(); + final Filter filter = + builder.filter( + rel -> builder.equal(builder.fieldReference(rel, 0), builder.i32(10)), scan); + assertNotNull(filter); + assertNotNull(filter.getCondition()); + } + + @Test + void testInnerJoin() { + final NamedScan left = createSimpleScan(); + final NamedScan right = createSimpleScan(); + final Join join = + builder.innerJoin( + inputs -> + builder.equal( + builder.fieldReference(inputs.left(), 0), + builder.fieldReference(inputs.right(), 0)), + left, + right); + assertNotNull(join); + assertEquals(Join.JoinType.INNER, join.getJoinType()); + } + + @Test + void testFetchAndLimit() { + final NamedScan scan = createSimpleScan(); + Fetch limit = builder.limit(10, scan); + assertEquals(10, limit.getCount().getAsLong()); + assertEquals(0, limit.getOffset()); + limit = builder.limit(10, Remap.of(Arrays.asList(new Integer[] {0, 1})), scan); + assertNotNull(limit); + + Fetch offset = builder.offset(5, scan); + assertEquals(5, offset.getOffset()); + + offset = builder.offset(5, Remap.of(Arrays.asList(new Integer[] {0, 1})), scan); + assertEquals(5, offset.getOffset()); + } + + @Test + void testSort() { + final NamedScan scan = createSimpleScan(); + Sort sort = builder.sort(rel -> builder.sortFields(rel, 0), scan); + assertNotNull(sort); + sort = + builder.sort( + rel -> builder.sortFields(rel, 0), + Remap.of(Arrays.asList(new Integer[] {0, 1})), + scan); + assertNotNull(sort); + } + + @Test + void testCross() { + final NamedScan left = createSimpleScan(); + final NamedScan right = createSimpleScan(); + + Cross cross = builder.cross(left, right); + assertNotNull(cross); + + cross = builder.cross(left, right, Remap.of(Arrays.asList(new Integer[] {0, 1}))); + assertNotNull(cross); + } + + @Test + void testFetch() { + final NamedScan left = createSimpleScan(); + Fetch fetch = builder.fetch(0, 1, 
left); + assertNotNull(fetch); + + fetch = builder.fetch(0, 1, Remap.of(Arrays.asList(new Integer[] {0, 1})), left); + assertNotNull(fetch); + } + } + + @Nested + @DisplayName("Aggregate and Scalar Function Tests") + class FunctionTests { + + @Test + void testAggregateMeasures() { + final NamedScan scan = createSimpleScan(); + final Aggregate.Measure count = builder.count(scan, 0); + final Aggregate.Measure sum = builder.sum(builder.fieldReference(scan, 0)); + + assertNotNull(count); + assertNotNull(sum); + + final Aggregate.Grouping grouping = builder.grouping(scan, 1); + assertNotNull(grouping); + assertEquals(1, grouping.getExpressions().size()); + + final AggregateFunctionInvocation afi1 = + builder.aggregateFn( + "extension:io.substrait:functions_aggregate_generic", + "count:any", + Type.I32.builder().nullable(false).build(), + builder.fieldReference(scan, 0)); + assertNotNull(afi1); + } + + @Test + void testArithmeticFunctions() { + final Expression left = builder.i32(10); + final Expression right = builder.i32(20); + + assertNotNull(builder.add(left, right)); + assertNotNull(builder.subtract(left, right)); + assertNotNull(builder.multiply(left, right)); + assertNotNull(builder.divide(left, right)); + assertNotNull(builder.negate(left)); + } + + @Test + void testBooleanLogic() { + final Expression b1 = builder.bool(true); + final Expression b2 = builder.bool(false); + + assertNotNull(builder.and(b1, b2)); + assertNotNull(builder.or(b1, b2)); + assertNotNull(builder.not(b1)); + assertNotNull(builder.isNull(b1)); + } + } + + @Nested + @DisplayName("Write and Plan Tests") + class PlanTests { + + @Test + void testNamedWrite() { + final NamedScan scan = createSimpleScan(); + final NamedWrite write = + builder.namedWrite( + List.of("target_table"), + List.of("c1", "c2"), + AbstractWriteRel.WriteOp.INSERT, + AbstractWriteRel.CreateMode.APPEND_IF_EXISTS, + AbstractWriteRel.OutputMode.MODIFIED_RECORDS, + scan); + assertNotNull(write); + } + + @Test + void 
testPlanCreation() { + final NamedScan scan = createSimpleScan(); + final Plan.Root root = builder.root(scan, List.of("out1", "out2")); + final Plan plan = builder.plan(root); + + assertNotNull(plan); + assertEquals(1, plan.getRoots().size()); + } + } + + // Helper method to create a base relation for testing + private NamedScan createSimpleScan() { + return builder.namedScan(List.of("t"), List.of("a", "b"), List.of(R.I32, R.STRING)); + } +}