Commit 5d5e071

use get partition spec for partition filter
1 parent f2f99ab commit 5d5e071

4 files changed: +326 −9 lines

spark-connector/common/pom.xml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,6 @@
142142
<version>1.15.3</version>
143143
<scope>test</scope>
144144
</dependency>
145-
<dependency>
146-
<groupId>org.scalatest</groupId>
147-
<artifactId>scalatest-funsuite_${scala.binary.version}</artifactId>
148-
<version>3.2.19</version>
149-
<scope>test</scope>
150-
</dependency>
151145
</dependencies>
152146

153147
</project>

spark-connector/datasource/src/main/scala/org/apache/spark/sql/execution/datasources/v2/odps/OdpsTableCatalog.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ class OdpsTableCatalog extends TableCatalog with SupportsNamespaces with SQLConf
386386
val sdkTable = metaClient.getSdkTable(project, odpsSchema, table)
387387
val partitionSchema = getPartitionSchema(sdkTable)
388388
val partitionSpecs =
389-
sdkTable.getPartitions.asScala.map(p => convertToTablePartitionSpec(p.getPartitionSpec))
389+
sdkTable.getPartitionSpecs.asScala.map(p => convertToTablePartitionSpec(p))
390390

391391
val prunedPartitions = if (filters.nonEmpty) {
392392
val predicate = new PartitionFilters(filters, partitionSchema).toPredicate
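The motivation for this one-line swap is not stated in the commit message, but the diff suggests it: getPartitions materializes a full com.aliyun.odps.Partition object for every partition, and the old code only extracted the spec from each, whereas getPartitionSpecs returns the lightweight PartitionSpec values directly (presumably the reason this commit also bumps the ODPS SDK to 0.50.4-public). A minimal sketch of the two call shapes, using only the SDK names visible in the diff:

    import scala.collection.JavaConverters._
    import com.aliyun.odps.{PartitionSpec, Table}

    // Before: one Partition object per partition; its spec extracted afterwards.
    def specsViaPartitions(table: Table): Seq[PartitionSpec] =
      table.getPartitions.asScala.map(_.getPartitionSpec).toSeq

    // After: request the specs themselves, skipping the Partition objects.
    def specsDirect(table: Table): Seq[PartitionSpec] =
      table.getPartitionSpecs.asScala.toSeq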
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.v2.odps

import com.aliyun.odps.`type`.TypeInfoFactory
import com.aliyun.odps.account.AliyunAccount
import com.aliyun.odps.task.SQLTask
import com.aliyun.odps.{Column, Odps, OdpsType, TableSchema}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, sources}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsPartitionManagement}
import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference, Transform}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.unsafe.types.UTF8String
import org.scalatest.funsuite.AnyFunSuite

import java.util

class OdpsTableSuite extends AnyFunSuite with Logging {

  private val project: String = ""
  private val accessId: String = ""
  private val accessKey: String = ""
  private val endPoint: String = ""

  private val partTable: Identifier = Identifier.of(Array(project), "partTable")
  private val multiPartTable: Identifier = Identifier.of(Array(project), "multiPartTable")
  private val allTypePartTable: Identifier = Identifier.of(Array(project), "allTypePartTable")

  private def sparkSession: SparkSession = SparkSession.builder()
    .master("local[2]")
    .config("spark.hadoop.odps.access.id", accessId)
    .config("spark.hadoop.odps.access.key", accessKey)
    .config("spark.hadoop.odps.end.point", endPoint)
    .config("spark.hadoop.odps.project.name", project)
    .getOrCreate()

  private def odps: Odps = {
    val odps = new Odps(new AliyunAccount(accessId, accessKey))
    odps.setDefaultProject(project)
    odps.setEndpoint(endPoint)
    odps
  }

  private val catalog: OdpsTableCatalog = {
    SparkSession.setDefaultSession(sparkSession)

    val newCatalog = new OdpsTableCatalog
    newCatalog.initialize("odps", CaseInsensitiveStringMap.empty())

    newCatalog
  }

  test("listPartitionIdentifiers") {
    val partTable = createPartTable()

    assert(!hasPartitions(partTable))

    val partIdent = InternalRow.apply(UTF8String.fromString("3"))
    partTable.createPartition(partIdent, new util.HashMap[String, String]())
    assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 1)

    val partIdent1 = InternalRow.apply(UTF8String.fromString("4"))
    partTable.createPartition(partIdent1, new util.HashMap[String, String]())
    assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 2)
    assert(partTable.listPartitionIdentifiers(Array("dt"), partIdent1).length == 1)

    partTable.dropPartition(partIdent)
    assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 1)
    partTable.dropPartition(partIdent1)
    assert(!hasPartitions(partTable))
  }

  test("listPartitionsByFilter") {
    createMultiPartTable()

    Seq(
      Array[Filter]() -> Set(
        Map("part0" -> "0", "part1" -> "abc"),
        Map("part0" -> "0", "part1" -> "def"),
        Map("part0" -> "1", "part1" -> "abc")),

      Array[Filter](sources.EqualTo("part0", 0)) -> Set(
        Map("part0" -> "0", "part1" -> "abc"),
        Map("part0" -> "0", "part1" -> "def")),

      Array[Filter](sources.EqualTo("part1", "abc")) -> Set(
        Map("part0" -> "0", "part1" -> "abc"),
        Map("part0" -> "1", "part1" -> "abc")),

      Array[Filter](sources.And(
        sources.LessThanOrEqual("part0", 1),
        sources.GreaterThanOrEqual("part0", 0))) -> Set(
        Map("part0" -> "0", "part1" -> "abc"),
        Map("part0" -> "0", "part1" -> "def"),
        Map("part0" -> "1", "part1" -> "abc")),

      Array[Filter](sources.GreaterThanOrEqual("part0", 1)) -> Set(
        Map("part0" -> "1", "part1" -> "abc")),

      Array[Filter](sources.GreaterThanOrEqual("part0", 2)) -> Set()
    ).foreach { case (filters, expected) =>
      assert(catalog.listPartitionsByFilter(multiPartTable, filters).toSet === expected)
    }

    createPartTableAllType()
    Seq(
      Array[Filter]() -> Set(
        Map("p1" -> "11", "p2" -> "22", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "11 ", "p2" -> "22 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "111 ", "p2" -> "222 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6")),

      Array[Filter](sources.StringStartsWith("p1", "11")) -> Set(
        Map("p1" -> "11", "p2" -> "22", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "11 ", "p2" -> "22 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "111 ", "p2" -> "222 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6")),

      Array[Filter](sources.StringEndsWith("p1", " ")) -> Set(
        Map("p1" -> "11 ", "p2" -> "22 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"),
        Map("p1" -> "111 ", "p2" -> "222 ", "p3" -> "3", "p4" -> "4", "p5" -> "5", "p6" -> "6"))
    ).foreach { case (filters, expected) =>
      assert(catalog.listPartitionsByFilter(allTypePartTable, filters).toSet === expected)
    }
  }

  test("partitionExists") {
    val partTable = createMultiPartTable()

    assert(partTable.partitionExists(InternalRow(0, UTF8String.fromString("def"))))
    assert(!partTable.partitionExists(InternalRow(-1, UTF8String.fromString("def"))))

    val errMsg = intercept[ClassCastException] {
      partTable.partitionExists(InternalRow(UTF8String.fromString("abc"), UTF8String.fromString("def")))
    }.getMessage
    assert(errMsg.contains("org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer"))

    val errMsg2 = intercept[IllegalArgumentException] {
      partTable.partitionExists(InternalRow(0))
    }.getMessage
    assert(errMsg2.contains("The identifier might not refer to one partition"))
  }

  test("listPartitionByNames") {
    val partTable = createMultiPartTable()

    Seq(
      (Array("part0", "part1"), InternalRow(0, UTF8String.fromString("abc"))) ->
        Set(InternalRow(0, UTF8String.fromString("abc"))),

      (Array("part0"), InternalRow(0)) -> Set(
        InternalRow(0, UTF8String.fromString("abc")),
        InternalRow(0, UTF8String.fromString("def"))),

      (Array("part1"), InternalRow(UTF8String.fromString("abc"))) -> Set(
        InternalRow(0, UTF8String.fromString("abc")),
        InternalRow(1, UTF8String.fromString("abc"))),

      (Array.empty[String], InternalRow.empty) -> Set(
        InternalRow(0, UTF8String.fromString("abc")),
        InternalRow(0, UTF8String.fromString("def")),
        InternalRow(1, UTF8String.fromString("abc"))),

      (Array("part0", "part1"), InternalRow(3, UTF8String.fromString("xyz"))) -> Set(),

      (Array("part1"), InternalRow(UTF8String.fromString("xyz"))) -> Set()
    ).foreach { case ((names, idents), expected) =>
      assert(partTable.listPartitionIdentifiers(names, idents).toSet === expected)
    }

    // Check invalid parameters
    Seq(
      (Array("part0", "part1"), InternalRow(0))
    ).foreach { case (names, idents) =>
      intercept[ArrayIndexOutOfBoundsException](partTable.listPartitionIdentifiers(names, idents))
    }

    Seq(
      (Array("col0", "part1"), InternalRow(0, 1)),
      (Array("wrong"), InternalRow("invalid"))
    ).foreach { case (names, idents) =>
      intercept[IllegalArgumentException](partTable.listPartitionIdentifiers(names, idents))
    }
  }

  test("listPartitionForAllType") {
    val partTable = createPartTableAllType()

    Seq(
      (Array.empty[String], InternalRow.empty) -> Set(
        InternalRow(UTF8String.fromString("11 "), UTF8String.fromString("22 "), 3, 4, 5, 6),
        InternalRow(UTF8String.fromString("111 "), UTF8String.fromString("222 "), 3, 4, 5, 6),
        InternalRow(UTF8String.fromString("11"), UTF8String.fromString("22"), 3, 4, 5, 6)),

      (Array("p1"), InternalRow(UTF8String.fromString("11"))) -> Set(
        InternalRow(UTF8String.fromString("11"), UTF8String.fromString("22"), 3, 4, 5, 6)),

      (Array("p1"), InternalRow(UTF8String.fromString("11 "))) -> Set(),

      (Array("p1"), InternalRow(UTF8String.fromString("111 "))) -> Set(
        InternalRow(UTF8String.fromString("111 "), UTF8String.fromString("222 "), 3, 4, 5, 6)),

      (Array("p2"), InternalRow(UTF8String.fromString("22 "))) -> Set(
        InternalRow(UTF8String.fromString("11 "), UTF8String.fromString("22 "), 3, 4, 5, 6))
    ).foreach { case ((names, idents), expected) =>
      assert(partTable.listPartitionIdentifiers(names, idents).toSet === expected)
    }
  }

  private def createPartTable(): OdpsTable = {
    if (catalog.tableExists(partTable)) {
      catalog.dropTable(partTable)
    }

    val odpsTable = catalog.createTable(
      partTable,
      new StructType()
        .add("id", IntegerType)
        .add("data", StringType)
        .add("dt", StringType),
      Array[Transform](LogicalExpressions.identity(ref("dt"))),
      util.Collections.emptyMap[String, String]).asInstanceOf[OdpsTable]
    odpsTable
  }

  private def createMultiPartTable(): OdpsTable = {
    if (catalog.tableExists(multiPartTable)) {
      catalog.dropTable(multiPartTable)
    }

    val odpsTable = catalog.createTable(
      multiPartTable,
      new StructType()
        .add("col0", IntegerType)
        .add("part0", IntegerType)
        .add("part1", StringType),
      Array(LogicalExpressions.identity(ref("part0")), LogicalExpressions.identity(ref("part1"))),
      util.Collections.emptyMap[String, String]).asInstanceOf[OdpsTable]

    Seq(
      InternalRow(0, UTF8String.fromString("abc")),
      InternalRow(0, UTF8String.fromString("def")),
      InternalRow(1, UTF8String.fromString("abc"))).foreach { partIdent =>
      odpsTable.createPartition(partIdent, new util.HashMap[String, String]())
    }
    odpsTable
  }

  private def createPartTableAllType(): OdpsTable = {
    val tableSchema = new TableSchema
    val columns = new util.ArrayList[Column]
    columns.add(new Column("c0", OdpsType.STRING))
    columns.add(new Column("c1", OdpsType.BIGINT))
    tableSchema.setColumns(columns)

    val partitionColumns = new util.ArrayList[Column]
    partitionColumns.add(new Column("p1", TypeInfoFactory.getCharTypeInfo(5)))
    partitionColumns.add(new Column("p2", TypeInfoFactory.getVarcharTypeInfo(5)))
    partitionColumns.add(new Column("p3", OdpsType.SMALLINT))
    partitionColumns.add(new Column("p4", OdpsType.TINYINT))
    partitionColumns.add(new Column("p5", OdpsType.BIGINT))
    partitionColumns.add(new Column("p6", OdpsType.INT))
    tableSchema.setPartitionColumns(partitionColumns)

    if (odps.tables().exists(allTypePartTable.name)) {
      odps.tables().delete(allTypePartTable.name)
    }
    odps.tables().create(allTypePartTable.name, tableSchema)

    createPartitionUseOdpsSQL(allTypePartTable.name, "p1='11',p2='22',p3=3,p4=4,p5=5,p6=6")
    createPartitionUseOdpsSQL(allTypePartTable.name, "p1='11 ',p2='22 ',p3=3,p4=4,p5=5,p6=6")
    createPartitionUseOdpsSQL(allTypePartTable.name, "p1='111 ',p2='222 ',p3=3,p4=4,p5=5,p6=6")

    catalog.loadTable(allTypePartTable).asInstanceOf[OdpsTable]
  }

  private def ref(name: String): NamedReference = LogicalExpressions.parseReference(name)

  private def hasPartitions(table: SupportsPartitionManagement): Boolean = {
    !table.listPartitionIdentifiers(Array.empty, InternalRow.empty).isEmpty
  }

  private def createPartitionUseOdpsSQL(tableName: String, partition: String): Unit = {
    val i = SQLTask.run(odps, odps.getDefaultProject,
      s"ALTER TABLE $project.$tableName ADD PARTITION ($partition);",
      "SQLAddPartitionTask",
      new util.HashMap[String, String](),
      null)
    i.waitForSuccess()
  }
}
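The Map values asserted in listPartitionsByFilter above are the Spark TablePartitionSpecs produced by the catalog's convertToTablePartitionSpec helper (seen in the OdpsTableCatalog diff). As a hypothetical illustration of that shape only, not the connector's actual implementation, flattening an ODPS PartitionSpec into the expected Map looks roughly like:

    import scala.collection.JavaConverters._
    import com.aliyun.odps.PartitionSpec

    // Illustration only: the real conversion lives in OdpsTableCatalog.
    def toTablePartitionSpec(spec: PartitionSpec): Map[String, String] =
      spec.keys.asScala.map(k => k -> spec.get(k)).toMap

    // toTablePartitionSpec(new PartitionSpec("part0=0,part1=abc"))
    //   == Map("part0" -> "0", "part1" -> "abc")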

spark-connector/pom.xml

Lines changed: 17 additions & 2 deletions
@@ -21,8 +21,8 @@
   <packaging>pom</packaging>
   <properties>
     <arrow.version>4.0.0</arrow.version>
-    <odps.sdk.version>0.50.1-public</odps.sdk.version>
-    <odps.sdk.table.version>0.50.1-public</odps.sdk.table.version>
+    <odps.sdk.version>0.50.4-public</odps.sdk.version>
+    <odps.sdk.table.version>0.50.4-public</odps.sdk.table.version>
     <spark.version>3.3.1</spark.version>
     <scala.version>2.12.10</scala.version>
     <scala.binary.version>2.12</scala.binary.version>

@@ -46,6 +46,21 @@
     </dependencies>
   </dependencyManagement>
 
+  <dependencies>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest-funsuite_${scala.binary.version}</artifactId>
+      <version>3.2.19</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-all</artifactId>
+      <version>4.1.74.Final</version>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
   <build>
     <plugins>
       <plugin>