Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ public class IoTDBArithmeticIT {
"CREATE TIMESERIES root.sg.d1.s6 WITH DATATYPE=TEXT, ENCODING=PLAIN",
"CREATE TIMESERIES root.sg.d1.s7 WITH DATATYPE=INT32, ENCODING=PLAIN",
"CREATE TIMESERIES root.sg.d1.s8 WITH DATATYPE=INT32, ENCODING=PLAIN",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s7) values (1, 1, 1, 1, 1, false, '1', 1)",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s8) values (2, 2, 2, 2, 2, false, '2', 2)",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s7) values (3, 3, 3, 3, 3, true, '3', 3)",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s8) values (4, 4, 4, 4, 4, true, '4', 4)",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s7, s8) values (5, 5, 5, 5, 5, true, '5', 5, 5)",
"CREATE TIMESERIES root.sg.d1.s9 WITH DATATYPE=DATE, ENCODING=PLAIN",
"CREATE TIMESERIES root.sg.d1.s10 WITH DATATYPE=TIMESTAMP, ENCODING=PLAIN",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s7, s9, s10) values (1, 1, 1, 1, 1, false, '1', 1, '2024-01-01', 10)",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s8, s9, s10) values (2, 2, 2, 2, 2, false, '2', 2, '2024-02-01', 20)",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s7, s9, s10) values (3, 3, 3, 3, 3, true, '3', 3, '2024-03-01', 30)",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s8, s9, s10) values (4, 4, 4, 4, 4, true, '4', 4, '2024-04-01', 40)",
"insert into root.sg.d1(time, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10) values (5, 5, 5, 5, 5, true, '5', 5, 5, '2024-05-01', 50)",
};

@BeforeClass
Expand Down Expand Up @@ -151,6 +153,44 @@ public void testArithmeticUnary() {
}
}

@Test
public void testTimestampNegation() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
ResultSet resultSet = statement.executeQuery("select -s10 from root.sg.d1");

String[] expected = {
"1969-12-31T23:59:59.990Z",
"1969-12-31T23:59:59.980Z",
"1969-12-31T23:59:59.970Z",
"1969-12-31T23:59:59.960Z",
"1969-12-31T23:59:59.950Z"
};

for (String expectedValue : expected) {
resultSet.next();
assertEquals(expectedValue, resultSet.getString(2));
}
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}

@Test
public void testUnaryWrongType() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
tsAssertTestFail(
statement, "select -s5 from root.sg.d1", "Invalid input expression data type");
tsAssertTestFail(
statement, "select -s6 from root.sg.d1", "Invalid input expression data type");
tsAssertTestFail(
statement, "select -s9 from root.sg.d1", "Invalid input expression data type");
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}

@Test
public void testHybridQuery() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Expand Down Expand Up @@ -241,4 +281,194 @@ public void testNot() {
fail();
}
}

@Test
public void testDateAndTimestampArithmetic() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {

String sql =
"select s1+s9,s1+s10,s2+s9,s2+s10,s9+s1,s9+s2,s10+s1,s10+s2,s9-s1,s9-s2,s10-s1,s10-s2 from root.sg.d1";
ResultSet resultSet = statement.executeQuery(sql);

int[][] expectedResults = {
{20240101, 11, 20240101, 11, 20240101, 20240101, 11, 11, 20231231, 20231231, 9, 9},
{20240201, 22, 20240201, 22, 20240201, 20240201, 22, 22, 20240131, 20240131, 18, 18},
{20240301, 33, 20240301, 33, 20240301, 20240301, 33, 33, 20240229, 20240229, 27, 27},
{20240401, 44, 20240401, 44, 20240401, 20240401, 44, 44, 20240331, 20240331, 36, 36},
{20240501, 55, 20240501, 55, 20240501, 20240501, 55, 55, 20240430, 20240430, 45, 45}
};

for (int[] expectedResult : expectedResults) {
resultSet.next();
for (int i = 0; i < expectedResult.length; i++) {
assertEquals(expectedResult[i], resultSet.getInt(i + 2));
}
}
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}

@Test
public void testDivisionByZero() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
tsAssertTestFail(statement, "select s1/0 from root.sg.d1", "Division by zero");
tsAssertTestFail(statement, "select s2/0 from root.sg.d1", "Division by zero");
tsAssertTestFail(statement, "select s1%0 from root.sg.d1", "Division by zero");
tsAssertTestFail(statement, "select s2%0 from root.sg.d1", "Division by zero");
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}

@Test
public void testFloatDivisionByZero() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
ResultSet resultSet =
statement.executeQuery("select s3/0.0,0.0/s3,0.0/-s3,-s3/0.0 from root.sg.d1");

String[] expected = {"Infinity", "0.0", "-0.0", "-Infinity"};

for (int i = 0; i < 5; i++) {
resultSet.next();
for (int j = 0; j < expected.length; j++) {
assertEquals(expected[j], resultSet.getString(j + 2));
}
}
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}

@Test
public void testDoubleModuloByZero() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
ResultSet resultSet =
statement.executeQuery("select s4%0.0,0.0%s4,0.0%-s4,-s4%0.0 from root.sg.d1");

String[] expected = {"NaN", "0.0", "0.0", "NaN"};

for (int i = 0; i < 5; i++) {
resultSet.next();
for (int j = 0; j < expected.length; j++) {
assertEquals(expected[j], resultSet.getString(j + 2));
}
}
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}

@Test
public void testBinaryWrongType() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
tsAssertTestFail(statement, "select s9 * s1 from root.sg.d1", "Invalid");
tsAssertTestFail(statement, "select s9 / s1 from root.sg.d1", "Invalid");
tsAssertTestFail(statement, "select s9 % s1 from root.sg.d1", "Invalid");
tsAssertTestFail(statement, "select s10 * s1 from root.sg.d1", "Invalid");
tsAssertTestFail(statement, "select s10 / s1 from root.sg.d1", "Invalid");
tsAssertTestFail(statement, "select s10 % s1 from root.sg.d1", "Invalid");
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}

@Test
public void testOverflow() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {

String[][] timeseries = {
{"s1", "INT32"},
{"s2", "INT64"},
{"s3", "INT64"},
{"s7", "INT32"},
{"s8", "INT32"},
{"s10", "TIMESTAMP"}
};
for (String[] ts : timeseries) {
statement.execute(
String.format(
"CREATE TIMESERIES root.sg.d2.%s WITH DATATYPE=%s, ENCODING=PLAIN", ts[0], ts[1]));
}

statement.execute(
String.format(
"insert into root.sg.d2(time, s1, s2, s3, s7, s8, s10) values (1, %d, %d, %d, %d, %d, %d)",
Integer.MAX_VALUE / 2 + 1,
Long.MAX_VALUE / 2 + 1,
Long.MIN_VALUE / 2 - 1,
Integer.MAX_VALUE / 2 + 1,
Integer.MIN_VALUE / 2 - 1,
Long.MAX_VALUE / 2 + 1));
statement.execute(
String.format(
"insert into root.sg.d2(time, s1, s2, s7, s8) values (2, %d, %d, %d, %d)",
Integer.MIN_VALUE / 2 - 1,
Long.MIN_VALUE / 2 - 1,
Integer.MIN_VALUE / 2 - 1,
Integer.MAX_VALUE / 2 + 1));

tsAssertTestFail(
statement, "select s1+s7 from root.sg.d2 where time=1", "int Addition overflow");
tsAssertTestFail(
statement, "select s1-s8 from root.sg.d2 where time=2", "int Subtraction overflow");
tsAssertTestFail(
statement, "select s1*s7 from root.sg.d2 where time=1", "int Multiplication overflow");

tsAssertTestFail(
statement, "select s2+s2 from root.sg.d2 where time=1", "long Addition overflow");
tsAssertTestFail(
statement, "select s3-s2 from root.sg.d2 where time=1", "long Subtraction overflow");

tsAssertTestFail(
statement,
String.format("select s10+%d from root.sg.d2 where time=1", Long.MAX_VALUE),
"long Addition overflow");
tsAssertTestFail(
statement,
String.format("select s10-(%d) from root.sg.d2 where time=1", Long.MIN_VALUE),
"long Subtraction overflow");

} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}

@Test
public void testDateOutOfRange() {
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
statement.execute("CREATE TIMESERIES root.sg.d3.date WITH DATATYPE=DATE, ENCODING=PLAIN");
statement.execute("insert into root.sg.d3(time, date) values (1, '9999-12-31')");
statement.execute("insert into root.sg.d3(time, date) values (2, '1000-01-01')");

tsAssertTestFail(
statement,
"select date + 86400000 from root.sg.d3 where time = 1",
"Year must be between 1000 and 9999");
tsAssertTestFail(
statement,
"select date - 86400000 from root.sg.d3 where time = 2",
"Year must be between 1000 and 9999");
} catch (SQLException throwable) {
fail(throwable.getMessage());
}
}

private void tsAssertTestFail(Statement statement, String sql, String expectedErrorMsg) {
try {
statement.executeQuery(sql);
fail("Expected exception with message: " + expectedErrorMsg);
} catch (SQLException e) {
assertTrue(
"Expected error message '" + expectedErrorMsg + "' but got: " + e.getMessage(),
e.getMessage().contains(expectedErrorMsg));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
import org.apache.iotdb.db.queryengine.plan.expression.unary.NegationExpression;
import org.apache.iotdb.db.queryengine.plan.expression.unary.RegularExpression;
import org.apache.iotdb.db.queryengine.plan.expression.visitor.ExpressionVisitor;
import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.AdditionResolver;
import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.DivisionResolver;
import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.ModulusResolver;
import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.MultiplicationResolver;
import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.SubtractionResolver;
import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDAFInformationInferrer;
import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDTFInformationInferrer;
import org.apache.iotdb.db.utils.TypeInferenceUtils;
Expand All @@ -54,6 +59,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -217,46 +223,61 @@ public TSDataType visitLogicNotExpression(
@Override
public TSDataType visitNegationExpression(
NegationExpression negationExpression, Function<String, TSDataType> context) {
TSDataType inputExpressionType = process(negationExpression.getExpression(), context);
checkInputExpressionDataType(
negationExpression.getExpression().getExpressionString(),
inputExpressionType,
TSDataType.INT32,
TSDataType.INT64,
TSDataType.FLOAT,
TSDataType.DOUBLE);
return setExpressionType(negationExpression, inputExpressionType);
TSDataType operandType = process(negationExpression.getExpression(), context);

if (operandType != TSDataType.INT32
&& operandType != TSDataType.INT64
&& operandType != TSDataType.FLOAT
&& operandType != TSDataType.DOUBLE
&& operandType != TSDataType.TIMESTAMP) {
throw new SemanticException(
String.format(
"Invalid input expression data type. Do not support %s operation for %s.",
ExpressionType.NEGATION, operandType));
}

return setExpressionType(negationExpression, operandType);
}

@Override
public TSDataType visitArithmeticBinaryExpression(
ArithmeticBinaryExpression arithmeticBinaryExpression,
Function<String, TSDataType> context) {
checkInputExpressionDataType(
arithmeticBinaryExpression.getLeftExpression().getExpressionString(),
process(arithmeticBinaryExpression.getLeftExpression(), context),
TSDataType.INT32,
TSDataType.INT64,
TSDataType.FLOAT,
TSDataType.DOUBLE);
checkInputExpressionDataType(
arithmeticBinaryExpression.getRightExpression().getExpressionString(),
process(arithmeticBinaryExpression.getRightExpression(), context),
TSDataType.INT32,
TSDataType.INT64,
TSDataType.FLOAT,
TSDataType.DOUBLE);
if ((arithmeticBinaryExpression.getExpressionType() == ExpressionType.DIVISION
|| arithmeticBinaryExpression.getExpressionType() == ExpressionType.MODULO)
&& isExpressionDataTypeSatisfy(
arithmeticBinaryExpression.getLeftExpression(), TSDataType.INT64, TSDataType.INT32)
&& isExpressionDataTypeSatisfy(
arithmeticBinaryExpression.getRightExpression(),
TSDataType.INT64,
TSDataType.INT32)) {
return setExpressionType(arithmeticBinaryExpression, TSDataType.INT64);
TSDataType leftType = process(arithmeticBinaryExpression.getLeftExpression(), context);
TSDataType rightType = process(arithmeticBinaryExpression.getRightExpression(), context);

ExpressionType operatorType = arithmeticBinaryExpression.getExpressionType();

Optional<TSDataType> resultTypeOpt;
switch (operatorType) {
case ADDITION:
resultTypeOpt = AdditionResolver.inferType(leftType, rightType);
break;
case SUBTRACTION:
resultTypeOpt = SubtractionResolver.inferType(leftType, rightType);
break;
case MULTIPLICATION:
resultTypeOpt = MultiplicationResolver.inferType(leftType, rightType);
break;
case DIVISION:
resultTypeOpt = DivisionResolver.inferType(leftType, rightType);
break;
case MODULO:
resultTypeOpt = ModulusResolver.inferType(leftType, rightType);
break;
default:
resultTypeOpt = Optional.empty();
break;
}
return setExpressionType(arithmeticBinaryExpression, TSDataType.DOUBLE);

if (!resultTypeOpt.isPresent()) {
throw new SemanticException(
String.format(
"Invalid input expression data type. Do not support %s operation for %s and %s.",
operatorType, leftType, rightType));
}

return setExpressionType(arithmeticBinaryExpression, resultTypeOpt.get());
}

@Override
Expand Down
Loading