Commit d2a1f10

fix(Predicate): avoid double-quoting already-quoted identifiers
1 parent a8e7364 commit d2a1f10
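
In short: Spark hands attribute names to the connector already wrapped by quoteIfNeeded, so quoting them again in the predicate builder produced doubly quoted identifiers. A minimal sketch of the observable effect, mirroring the new QuoteChinese test added below (it assumes the caller sits in the same package as the test suite, since the Seq overload of convertToOdpsPredicate may not be public):

  import org.apache.spark.sql.sources.GreaterThan

  // The attribute arrives pre-quoted by Spark's quoteIfNeeded.
  val filter = GreaterThan("`你好`", 2)
  val predicate = ExecutionUtils.convertToOdpsPredicate(Seq(filter))
  // After this fix the name passes through untouched; before, it was quoted a second time.
  assert(predicate.toString == "`你好` > 2")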

File tree

4 files changed: +160 −8 lines

spark-connector/common/pom.xml

Lines changed: 26 additions & 0 deletions

@@ -17,6 +17,20 @@
             <groupId>org.apache.arrow</groupId>
             <artifactId>arrow-vector</artifactId>
             <version>${arrow.version}</version>
+            <exclusions>
+                <exclusion>
+                    <groupId>com.fasterxml.jackson.core</groupId>
+                    <artifactId>jackson-core</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>com.fasterxml.jackson.core</groupId>
+                    <artifactId>jackson-databind</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>com.fasterxml.jackson.core</groupId>
+                    <artifactId>jackson-annotations</artifactId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <dependency>
             <groupId>org.apache.arrow</groupId>
@@ -109,6 +123,18 @@
                    <groupId>commons-lang</groupId>
                    <artifactId>commons-lang</artifactId>
                </exclusion>
+                <exclusion>
+                    <groupId>com.fasterxml.jackson.core</groupId>
+                    <artifactId>jackson-core</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>com.fasterxml.jackson.core</groupId>
+                    <artifactId>jackson-databind</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>com.fasterxml.jackson.core</groupId>
+                    <artifactId>jackson-annotations</artifactId>
+                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
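
(A note on the new exclusions: they presumably keep a single Jackson version on the connector's classpath, since the two dependencies touched here pull in jackson-core, jackson-databind and jackson-annotations transitively.)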

spark-connector/common/src/main/scala/org/apache/spark/sql/odps/ExecutionUtils.scala

Lines changed: 21 additions & 8 deletions

@@ -124,18 +124,31 @@ object ExecutionUtils {
   }

   private def convertToOdpsPredicate(filter: Filter): Predicate = filter match {
-    case EqualTo(attribute, value) => BinaryPredicate.equals(Attribute.of(attribute), Constant.of(value))
-    case GreaterThan(attribute, value) => BinaryPredicate.greaterThan(Attribute.of(attribute), Constant.of(value))
-    case GreaterThanOrEqual(attribute, value) => BinaryPredicate.greaterThanOrEqual(Attribute.of(attribute), Constant.of(value))
-    case LessThan(attribute, value) => BinaryPredicate.lessThan(Attribute.of(attribute), Constant.of(value))
-    case LessThanOrEqual(attribute, value) => BinaryPredicate.lessThanOrEqual(Attribute.of(attribute), Constant.of(value))
-    case In(attribute, values) => InPredicate.in(Attribute.of(attribute), values.map(Constant.of).toList.asJava.asInstanceOf[java.util.List[java.io.Serializable]])
-    case IsNull(attribute) => UnaryPredicate.isNull(Attribute.of(attribute))
-    case IsNotNull(attribute) => UnaryPredicate.notNull(Attribute.of(attribute))
+    case EqualTo(attribute, value) => BinaryPredicate.equals(quoteAttribute(attribute), Constant.of(value))
+    case GreaterThan(attribute, value) => BinaryPredicate.greaterThan(quoteAttribute(attribute), Constant.of(value))
+    case GreaterThanOrEqual(attribute, value) => BinaryPredicate.greaterThanOrEqual(quoteAttribute(attribute), Constant.of(value))
+    case LessThan(attribute, value) => BinaryPredicate.lessThan(quoteAttribute(attribute), Constant.of(value))
+    case LessThanOrEqual(attribute, value) => BinaryPredicate.lessThanOrEqual(quoteAttribute(attribute), Constant.of(value))
+    case In(attribute, values) => InPredicate.in(quoteAttribute(attribute), values.map(Constant.of).toList.asJava.asInstanceOf[java.util.List[java.io.Serializable]])
+    case IsNull(attribute) => UnaryPredicate.isNull(quoteAttribute(attribute))
+    case IsNotNull(attribute) => UnaryPredicate.notNull(quoteAttribute(attribute))
     case And(left, right) => CompoundPredicate.and(convertToOdpsPredicate(left), convertToOdpsPredicate(right))
     case Or(left, right) => CompoundPredicate.or(convertToOdpsPredicate(left), convertToOdpsPredicate(right))
     case Not(child) => CompoundPredicate.not(convertToOdpsPredicate(child))
     case _ =>
       throw new UnsupportedOperationException(s"Unsupported filter: $filter")
   }
+
+  /**
+   * Attribute names arrive already quoted by [[org.apache.spark.sql.catalyst.util#quoteIfNeeded]],
+   * so we first check whether the attribute is already backtick-quoted before quoting it.
+   */
+  def quoteAttribute(value: String): String = {
+    if (value.startsWith("`") && value.endsWith("`") && value.length() > 1) {
+      value
+    } else {
+      s"`${value.replace("`", "``")}`"
+    }
+  }
 }
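
For illustration, the check makes quoting idempotent: a name already wrapped by quoteIfNeeded passes through unchanged, while a raw name has its embedded backticks escaped before being wrapped. A sketch matching the unit tests added below:

  ExecutionUtils.quoteAttribute("`a b`") // "`a b`"   – already quoted, returned as-is
  ExecutionUtils.quoteAttribute("你`好")  // "`你``好`" – backtick doubled, then wrapped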

spark-connector/common/src/test/scala/org/apache/spark/sql/odps/ExecutionUtilsSuite.scala

Lines changed: 35 additions & 0 deletions

@@ -1,5 +1,6 @@
 package org.apache.spark.sql.odps

+import org.apache.spark.sql.catalyst.util.quoteIfNeeded
 import org.apache.spark.sql.sources.{And, StringStartsWith, _}
 import org.scalatest.funsuite.AnyFunSuite

@@ -136,5 +137,39 @@ class ExecutionUtilsSuite extends AnyFunSuite {

     assert(result.toString === "(`column3` > 10 or `column4` is not null)")
   }
+
+  test("quoteAttribute with quoteIfNeeded") {
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("")) === "``")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("``")) === "``````")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("`")) === "````")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("ab")) === "`ab`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("a b")) === "`a b`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("a*b")) === "`a*b`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("123")) === "`123`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("1a")) === "`1a`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("`1a`")) === "```1a```")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("`_")) === "```_`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("`_`")) === "```_```")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("你好")) === "`你好`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("你`好")) === "`你``好`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("你``好")) === "`你````好`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("`你好")) === "```你好`")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("你好`")) === "`你好```")
+    assert(ExecutionUtils.quoteAttribute(quoteIfNeeded("`你好`")) === "```你好```")
+  }
+
+  test("QuoteChinese") {
+    val greaterThan = GreaterThan("`你好`", 2)
+    val result = ExecutionUtils.convertToOdpsPredicate(Seq(greaterThan))
+    println(result.toString)
+    assert(result.toString == "`你好` > 2")
+  }
+
+  test("QuoteSpecialCharacter") {
+    val greaterThan = GreaterThan("你`好", 2)
+    val result = ExecutionUtils.convertToOdpsPredicate(Seq(greaterThan))
+    println(result.toString)
+    assert(result.toString == "`你``好` > 2")
+  }
 }
SQLQuerySuite.scala (new file)

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+package org.apache.spark.sql.execution.datasources.v2.odps
+
+import com.aliyun.odps.{Column, Odps, OdpsType, TableSchema}
+import com.aliyun.odps.account.AliyunAccount
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.util
+
+class SQLQuerySuite extends AnyFunSuite with Logging {
+
+  private val project: String = ""
+  private val accessId: String = ""
+  private val accessKey: String = ""
+  private val endPoint: String = ""
+
+  private val table: Identifier = Identifier.of(Array(project), "testTable")
+
+  private def sparkSession: SparkSession = SparkSession.builder()
+    .master("local[2]")
+    .config("spark.hadoop.odps.access.id", accessId)
+    .config("spark.hadoop.odps.access.key", accessKey)
+    .config("spark.hadoop.odps.end.point", endPoint)
+    .config("spark.hadoop.odps.project.name", project)
+    .config("spark.sql.catalog.odps", "org.apache.spark.sql.execution.datasources.v2.odps.OdpsTableCatalog")
+    .config("spark.sql.extensions", "org.apache.spark.sql.execution.datasources.v2.odps.extension.OdpsExtensions")
+    .config("spark.sql.defaultCatalog", "odps")
+    .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
+    .getOrCreate()
+
+  private def odps: Odps = {
+    val odps = new Odps(new AliyunAccount(accessId, accessKey))
+    odps.setDefaultProject(project)
+    odps.setEndpoint(endPoint)
+    odps
+  }
+
+  test("filterPushDownColumnNames") {
+    sparkSession.conf.set("spark.sql.catalog.odps.enableFilterPushDown", true)
+
+    val tableSchema = new TableSchema
+    val columns = new util.ArrayList[Column]
+    columns.add(new Column("c0", OdpsType.BIGINT))
+    columns.add(new Column("c1", OdpsType.BIGINT))
+    columns.add(new Column("列2", OdpsType.BIGINT))
+    columns.add(new Column("列3", OdpsType.BIGINT))
+    columns.add(new Column("44", OdpsType.BIGINT))
+    columns.add(new Column("5列", OdpsType.BIGINT))
+    columns.add(new Column("列六", OdpsType.BIGINT))
+    columns.add(new Column("'列七'", OdpsType.BIGINT))
+
+    tableSchema.setColumns(columns)
+    createTable(table.name(), tableSchema)
+
+    sparkSession.sql(s"insert overwrite table ${table.name()} values (0,1,2,3,4,5,6,7), (1,2,3,4,5,6,7,8)")
+    val result = sparkSession.sql("select * from testTable where c0 = 0 " +
+      "and `列2` = 2 and `44` = 4 and `列六` = 6 and `'列七'` = 7").collect()
+
+    assert(result.length == 1)
+    assert(result(0).get(0) == 0)
+    assert(result(0).get(1) == 1)
+    assert(result(0).get(2) == 2)
+    assert(result(0).get(3) == 3)
+    assert(result(0).get(4) == 4)
+    assert(result(0).get(5) == 5)
+    assert(result(0).get(6) == 6)
+    assert(result(0).get(7) == 7)
+  }
+
+  private def createTable(tableName: String, tableSchema: TableSchema): Unit = {
+    if (odps.tables().exists(tableName)) {
+      odps.tables().delete(tableName)
+    }
+    odps.tables().create(tableName, tableSchema)
+  }
+}
