Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export class SnowflakeQuery extends BaseQuery {
templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})';
templates.expressions.interval = 'INTERVAL \'{{ interval }}\'';
templates.expressions.timestamp_literal = '\'{{ value }}\'::timestamp_tz';
templates.expressions.timestamp_tz_named_timezone_cast = 'CONVERT_TIMEZONE(\'{{ timezone }}\', \'UTC\', \'{{ timestamp }}\'::timestamp_ntz)';
templates.expressions.like = '{{ expr }} {% if negated %}NOT {% endif %}LIKE {{ pattern }}{% if default_escape %} ESCAPE \'\\\\\'{% endif %}';
templates.expressions.ilike = '{{ expr }} {% if negated %}NOT {% endif %}ILIKE {{ pattern }}{% if default_escape %} ESCAPE \'\\\\\'{% endif %}';
templates.operators.is_not_distinct_from = 'IS NOT DISTINCT FROM';
Expand Down
40 changes: 40 additions & 0 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::{
AliasedColumn, DataSource, LoadRequestMeta, MetaContext, SpanId, SqlGenerator,
SqlTemplates, TransportLoadRequestQuery, TransportService,
},
utils::{parse_named_timezone_timestamp, TIMESTAMP_TZ_NAMED_TIMEZONE_CAST_TEMPLATE},
CubeError,
};
use chrono::{Days, NaiveDate, SecondsFormat, TimeZone, Utc};
Expand Down Expand Up @@ -1816,6 +1817,38 @@ impl WrappedSelectNode {
.map_err(|e| DataFusionError::Internal(format!("Can't generate SQL for cast: {}", e)))
}

/// Tries to render a cast of a named-timezone timestamp string literal through the
/// dedicated SQL template.
///
/// Returns `Ok(None)` when this special case does not apply, i.e. when any of the
/// following holds:
/// - `data_type` is not a `DataType::Timestamp`;
/// - the generator's templates do not contain
///   `TIMESTAMP_TZ_NAMED_TIMEZONE_CAST_TEMPLATE`;
/// - `expr` is not a `Utf8` string literal;
/// - `parse_named_timezone_timestamp` rejects the literal.
///
/// Returns `Ok(Some(sql))` with the rendered template otherwise.
///
/// # Errors
///
/// A template rendering failure is reported as `DataFusionError::Internal`.
fn generate_sql_for_named_timezone_timestamp_cast(
    sql_generator: Arc<dyn SqlGenerator>,
    expr: &Expr,
    data_type: &DataType,
) -> result::Result<Option<String>, DataFusionError> {
    // DataFusion erases WITH TIME ZONE for these literals.
    let DataType::Timestamp(_, _) = data_type else {
        return Ok(None);
    };

    let templates = sql_generator.get_sql_templates();
    if !templates.contains_template(TIMESTAMP_TZ_NAMED_TIMEZONE_CAST_TEMPLATE) {
        return Ok(None);
    }

    // Only plain string literals can carry the "timestamp + named timezone" form.
    let Expr::Literal(ScalarValue::Utf8(Some(literal))) = expr else {
        return Ok(None);
    };
    let Some((timestamp, timezone)) = parse_named_timezone_timestamp(literal) else {
        return Ok(None);
    };

    match templates.timestamp_tz_named_timezone_cast_expr(timestamp, timezone, literal.clone()) {
        Ok(sql) => Ok(Some(sql)),
        Err(e) => Err(DataFusionError::Internal(format!(
            "Can't generate SQL for timestamp with named timezone cast: {}",
            e
        ))),
    }
}

fn generate_sql_type(
sql_generator: Arc<dyn SqlGenerator>,
data_type: DataType,
Expand Down Expand Up @@ -2127,6 +2160,13 @@ impl WrappedSelectNode {
subqueries,
),
Expr::Cast { expr, data_type } => {
if let Some(resulting_sql) = Self::generate_sql_for_named_timezone_timestamp_cast(
sql_generator.clone(),
expr.as_ref(),
&data_type,
)? {
return Ok((resulting_sql, sql_query));
}
let (expr, sql_query) = Self::generate_sql_for_expr(
sql_query,
sql_generator.clone(),
Expand Down
98 changes: 92 additions & 6 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/cast.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,50 @@
use crate::compile::rewrite::{
cast_expr, rewrite, rewriter::CubeRewrite, rules::wrapper::WrapperRules,
wrapper_pullup_replacer, wrapper_pushdown_replacer,
use crate::{
compile::rewrite::{
cast_expr, rewrite, rewriter::CubeEGraph, rewriter::CubeRewrite,
rules::wrapper::WrapperRules, transforming_rewrite, wrapper_pullup_replacer,
wrapper_pushdown_replacer, wrapper_replacer_context, CastExprDataType, LiteralExprValue,
LogicalPlanLanguage,
},
transport::DataSource,
utils::{parse_named_timezone_timestamp, TIMESTAMP_TZ_NAMED_TIMEZONE_CAST_TEMPLATE},
};
use crate::{var, var_iter};
use datafusion::{arrow::datatypes::DataType, scalar::ScalarValue};
use egg::{Id, Subst};

impl WrapperRules {
pub fn cast_rules(&self, rules: &mut Vec<CubeRewrite>) {
rules.extend(vec![
rewrite(
transforming_rewrite(
"wrapper-push-down-cast",
wrapper_pushdown_replacer(cast_expr("?expr", "?data_type"), "?context"),
cast_expr(wrapper_pushdown_replacer("?expr", "?context"), "?data_type"),
wrapper_pushdown_replacer(
cast_expr("?expr", "?data_type"),
wrapper_replacer_context(
"?alias_to_cube",
"?push_to_cube",
"?in_projection",
"?cube_members",
"?grouped_subqueries",
"?ungrouped_scan",
"?input_data_source",
),
),
cast_expr(
wrapper_pushdown_replacer(
"?expr",
wrapper_replacer_context(
"?alias_to_cube",
"?push_to_cube",
"?in_projection",
"?cube_members",
"?grouped_subqueries",
"?ungrouped_scan",
"?input_data_source",
),
),
"?data_type",
),
self.transform_cast_pushdown("?input_data_source", "?expr", "?data_type"),
),
rewrite(
"wrapper-pull-up-cast",
Expand All @@ -18,4 +53,55 @@ impl WrapperRules {
),
]);
}

/// Builds the predicate closure that decides whether a cast may be pushed down.
///
/// The three `*_var` arguments are pattern variable names; they are resolved with
/// `var!` once, up front, and captured by the returned closure together with a
/// clone of `self.meta_context`.
///
/// The closure returns:
/// - `true` when the cast is not a `Timestamp` cast of a string literal accepted by
///   `parse_named_timezone_timestamp`;
/// - `false` when the data source cannot be resolved or is
///   `DataSource::Unrestricted`;
/// - otherwise whatever `can_rewrite_template` reports for
///   `TIMESTAMP_TZ_NAMED_TIMEZONE_CAST_TEMPLATE` on the specific data source.
fn transform_cast_pushdown(
    &self,
    input_data_source_var: &str,
    expr_var: &str,
    data_type_var: &str,
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
    let input_data_source_var = var!(input_data_source_var);
    let expr_var = var!(expr_var);
    let data_type_var = var!(data_type_var);
    let meta = self.meta_context.clone();
    move |egraph, subst| {
        let targets_named_timezone_literal =
            var_iter!(egraph[subst[data_type_var]], CastExprDataType).any(|cast_type| {
                // DataFusion erases WITH TIME ZONE for these literals.
                matches!(cast_type, DataType::Timestamp(_, _))
                    && Self::expr_is_named_timezone_string_literal(egraph, subst[expr_var])
            });

        // Every other cast needs no dedicated template.
        if !targets_named_timezone_literal {
            return true;
        }

        match Self::get_data_source(egraph, subst, input_data_source_var) {
            Ok(data_source @ DataSource::Specific(_)) => Self::can_rewrite_template(
                &data_source,
                &meta,
                TIMESTAMP_TZ_NAMED_TIMEZONE_CAST_TEMPLATE,
            ),
            // This template is target-specific.
            Ok(DataSource::Unrestricted) => false,
            // Without a resolvable data source the rewrite is rejected.
            Err(_) => false,
        }
    }
}

/// Reports whether the e-class `expr_id` contains a `LiteralExpr` node holding a
/// `Utf8` string that `parse_named_timezone_timestamp` accepts.
///
/// Non-literal nodes and literals of other scalar kinds are ignored; the first
/// accepted literal makes the result `true`.
fn expr_is_named_timezone_string_literal(egraph: &CubeEGraph, expr_id: Id) -> bool {
    for node in egraph[expr_id].nodes.iter() {
        let LogicalPlanLanguage::LiteralExpr([value_id]) = node else {
            continue;
        };

        let has_named_timezone_string =
            var_iter!(egraph[*value_id], LiteralExprValue).any(|literal| match literal {
                ScalarValue::Utf8(Some(text)) => parse_named_timezone_timestamp(text).is_some(),
                _ => false,
            });
        if has_named_timezone_string {
            return true;
        }
    }

    false
}
}
7 changes: 4 additions & 3 deletions rust/cubesql/cubesql/src/compile/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ use crate::{
auth_service::SqlAuthServiceAuthenticateRequest,
dataframe,
statement::{
ApproximateCountDistinctVisitor, CastReplacer, RedshiftDatePartReplacer,
SensitiveDataSanitizer, SqlParser062Normalizer, ToTimestampReplacer,
UdfWildcardArgReplacer,
ApproximateCountDistinctVisitor, CastReplacer, PlainTimestampTimezoneSuffixReplacer,
RedshiftDatePartReplacer, SensitiveDataSanitizer, SqlParser062Normalizer,
ToTimestampReplacer, UdfWildcardArgReplacer,
},
ColumnFlags, ColumnType, Session, SessionManager, SessionState,
},
Expand Down Expand Up @@ -744,6 +744,7 @@ impl QueryRouter {

pub fn rewrite_statement(stmt: ast::Statement) -> ast::Statement {
let stmt = SqlParser062Normalizer::new().replace(stmt);
let stmt = PlainTimestampTimezoneSuffixReplacer::new().replace(stmt);
let stmt = CastReplacer::new().replace(stmt);
let stmt = ToTimestampReplacer::new().replace(stmt);
let stmt = UdfWildcardArgReplacer::new().replace(stmt);
Expand Down
Loading
Loading