Skip to content

Commit 41cfa9e

Browse files
fix(optimizer)!: query schema directly when type annotation fails for processing UNNEST source
1 parent f7458a4 commit 41cfa9e

File tree

2 files changed

+116
-3
lines changed

2 files changed

+116
-3
lines changed

sqlglot/optimizer/resolver.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,24 @@ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequenc
144144
# in bigquery, unnest structs are automatically scoped as tables, so you can
145145
# directly select a struct field in a query.
146146
# this handles the case where the unnest is statically defined.
147-
if self.dialect.UNNEST_COLUMN_ONLY:
148-
if source.expression.is_type(exp.DataType.Type.STRUCT):
149-
for k in source.expression.type.expressions: # type: ignore
147+
if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
148+
unnest = source.expression
149+
150+
# if type is not annotated yet, try to get it from the schema
151+
if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
152+
unnest_expr = seq_get(unnest.expressions, 0)
153+
if isinstance(unnest_expr, exp.Column) and self.scope.parent:
154+
col_type = self._get_unnest_column_type(unnest_expr)
155+
# extract element type if it's an ARRAY
156+
if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
157+
element_types = col_type.expressions
158+
if element_types:
159+
unnest.type = element_types[0]
160+
else:
161+
unnest.type = col_type
162+
# check if the result type is a STRUCT - extract struct field names
163+
if unnest.is_type(exp.DataType.Type.STRUCT):
164+
for k in unnest.type.expressions: # type: ignore
150165
columns.append(k.name)
151166
elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
152167
columns = self.get_source_columns_from_set_op(source.expression)
@@ -299,3 +314,59 @@ def _get_unambiguous_columns(
299314
unambiguous_columns[column] = table
300315

301316
return unambiguous_columns
317+
318+
def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
319+
"""
320+
Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
321+
322+
Args:
323+
column: The column expression being unnested.
324+
325+
Returns:
326+
The DataType of the column, or None if not found.
327+
"""
328+
scope = self.scope.parent
329+
330+
# if column is qualified, use that table, otherwise disambiguate using the resolver
331+
if column.table:
332+
table_name = column.table
333+
else:
334+
# use the parent scope's resolver to disambiguate the column
335+
parent_resolver = Resolver(scope, self.schema, self._infer_schema)
336+
table_identifier = parent_resolver.get_table(column)
337+
if not table_identifier:
338+
return None
339+
table_name = table_identifier.name
340+
341+
source = scope.sources[table_name]
342+
return self._get_column_type_from_scope(source, column.name)
343+
344+
def _get_column_type_from_scope(
345+
self, source: t.Union[Scope, exp.Table], col_name: str
346+
) -> t.Optional[exp.DataType]:
347+
"""
348+
Get a column's type by tracing through scopes/tables to find the base table.
349+
350+
Args:
351+
source: The source to search - can be a Scope (to iterate its sources) or a Table.
352+
col_name: The column name to find.
353+
354+
Returns:
355+
The DataType of the column, or None if not found.
356+
"""
357+
if isinstance(source, exp.Table):
358+
# base table - get the column type from schema
359+
col_type: t.Optional[exp.DataType] = self.schema.get_column_type(
360+
source, exp.Column(this=exp.to_identifier(col_name))
361+
)
362+
if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
363+
return col_type
364+
elif isinstance(source, Scope):
365+
# iterate over all sources in the scope
366+
for source_name in source.sources:
367+
nested_source = source.sources[source_name]
368+
col_type = self._get_column_type_from_scope(nested_source, col_name)
369+
if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
370+
return col_type
371+
372+
return None

tests/test_optimizer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,48 @@ def test_qualify_columns(self, logger):
516516
"SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.d_id = d.d_id",
517517
)
518518

519+
self.assertEqual(
520+
optimizer.qualify.qualify(
521+
parse_one(
522+
"""
523+
SELECT
524+
(SELECT SUM(c.amount)
525+
FROM UNNEST(credits) AS c
526+
WHERE type != 'promotion') as total
527+
FROM billing
528+
""",
529+
read="bigquery",
530+
),
531+
schema={"billing": {"credits": "ARRAY<STRUCT<amount FLOAT64, type STRING>>"}},
532+
dialect="bigquery",
533+
).sql(dialect="bigquery"),
534+
"SELECT (SELECT SUM(`c`.`amount`) AS `_col_0` FROM UNNEST(`billing`.`credits`) AS `c` WHERE `type` <> 'promotion') AS `total` FROM `billing` AS `billing`",
535+
)
536+
537+
self.assertEqual(
538+
optimizer.qualify.qualify(
539+
parse_one(
540+
"""
541+
WITH cte AS (SELECT * FROM base_table)
542+
SELECT
543+
(SELECT SUM(item.price)
544+
FROM UNNEST(items) AS item
545+
WHERE category = 'electronics') as electronics_total
546+
FROM cte
547+
""",
548+
read="bigquery",
549+
),
550+
schema={
551+
"base_table": {
552+
"id": "INT64",
553+
"items": "ARRAY<STRUCT<price FLOAT64, category STRING>>",
554+
}
555+
},
556+
dialect="bigquery",
557+
).sql(dialect="bigquery"),
558+
"WITH `cte` AS (SELECT `base_table`.`id` AS `id`, `base_table`.`items` AS `items` FROM `base_table` AS `base_table`) SELECT (SELECT SUM(`item`.`price`) AS `_col_0` FROM UNNEST(`cte`.`items`) AS `item` WHERE `category` = 'electronics') AS `electronics_total` FROM `cte` AS `cte`",
559+
)
560+
519561
self.check_file(
520562
"qualify_columns",
521563
qualify_columns,

0 commit comments

Comments
 (0)