Skip to content

Commit

Permalink
fix(backends): ensure that analytic functions do not receive a window frame
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jan 28, 2025
1 parent 2651fbd commit a19c3b5
Show file tree
Hide file tree
Showing 32 changed files with 126 additions and 106 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col, ())`
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `First(double_col, ())`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col, ())`
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `Last(double_col, ())`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT * FROM (SELECT `t1`.`col`, COUNT(*) OVER (ORDER BY NULL ASC) AS `analytic` FROM (SELECT `t0`.`col`, NULL AS `filter` FROM `x` AS `t0` WHERE NULL IS NULL) AS `t1`) AS `t2`
SELECT * FROM (SELECT `t1`.`col`, COUNT(*) OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `analytic` FROM (SELECT `t0`.`col`, NULL AS `filter` FROM `x` AS `t0` WHERE NULL IS NULL) AS `t1`) AS `t2`
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT `t0`.`one`, `t0`.`two`, `t0`.`three`, SUM(`t0`.`two`) OVER (PARTITION BY `t0`.`three` ORDER BY `t0`.`one` ASC) AS `four` FROM `my_data` AS `t0`
SELECT `t0`.`one`, `t0`.`two`, `t0`.`three`, SUM(`t0`.`two`) OVER (PARTITION BY `t0`.`three` ORDER BY `t0`.`one` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `four` FROM `my_data` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ SELECT
`t0`.`k`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) AS `lag`,
LEAD(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) - `t0`.`f` AS `fwd_diff`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) AS `last`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `last`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`d` ASC) AS `lag2`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ SELECT
`t0`.`i`,
`t0`.`j`,
`t0`.`k`,
`t0`.`f` / SUM(`t0`.`f`) OVER (ORDER BY NULL ASC) AS `normed_f`
`t0`.`f` / SUM(`t0`.`f`) OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `normed_f`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
MAX(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
MAX(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
MAX(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
MAX(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
AVG(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
AVG(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
AVG(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
AVG(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
MIN(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
MIN(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
MIN(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
MIN(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SELECT
`t0`.`g`,
SUM(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) - SUM(`t0`.`f`) OVER (ORDER BY NULL ASC) AS `result`
SUM(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - SUM(`t0`.`f`) OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `result`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SELECT
LAG(`t0`.`d`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` DESC NULLS LAST) AS `foo`,
MAX(`t0`.`a`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` DESC NULLS LAST) AS `Max(a)`
MAX(`t0`.`a`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` DESC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `Max(a)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
4 changes: 2 additions & 2 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,12 +1185,12 @@ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by)
)
order = sge.Order(expressions=order_by) if order_by else None

spec = self._minimize_spec(op.start, op.end, spec)
spec = self._minimize_spec(op, spec)

return sge.Window(this=func, partition_by=group_by, order=order, spec=spec)

@staticmethod
def _minimize_spec(start, end, spec):
def _minimize_spec(op, spec):
return spec

def visit_LagLead(self, op, *, arg, offset, default):
Expand Down
12 changes: 6 additions & 6 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from ibis.backends.sql.compilers.bigquery.udf.core import PythonToJavaScriptTranslator
from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
Expand Down Expand Up @@ -323,12 +325,10 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> sge.Create:
return func

@staticmethod
def _minimize_spec(start, end, spec):
if (
start is None
and isinstance(getattr(end, "value", None), ops.Literal)
and end.value.value == 0
and end.following
def _minimize_spec(op, spec):
# bigquery doesn't allow certain window functions to specify a window frame
if isinstance(func := op.func, ops.Analytic) and not isinstance(
func, (ops.First, ops.Last, FirstValue, LastValue, ops.NthValue)
):
return None
return spec
Expand Down
9 changes: 2 additions & 7 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,8 @@ class ClickHouseCompiler(SQLGlotCompiler):
}

@staticmethod
def _minimize_spec(start, end, spec):
if (
start is None
and isinstance(getattr(end, "value", None), ops.Literal)
and end.value.value == 0
and end.following
):
def _minimize_spec(op, spec):
if isinstance(op.func, ops.NTile):
return None
return spec

Expand Down
11 changes: 5 additions & 6 deletions ibis/backends/sql/compilers/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from ibis.backends.sql.datatypes import ExasolType
from ibis.backends.sql.dialects import Exasol
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
Expand Down Expand Up @@ -85,12 +87,9 @@ class ExasolCompiler(SQLGlotCompiler):
}

@staticmethod
def _minimize_spec(start, end, spec):
if (
start is None
and isinstance(getattr(end, "value", None), ops.Literal)
and end.value.value == 0
and end.following
def _minimize_spec(op, spec):
if isinstance(func := op.func, ops.Analytic) and not isinstance(
func, (ops.First, ops.Last, FirstValue, LastValue, ops.NthValue)
):
return None
return spec
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/sql/compilers/flink.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,16 @@ def _generate_groups(groups):
return groups

@staticmethod
def _minimize_spec(start, end, spec):
def _minimize_spec(op, spec):
if (
start is None
and isinstance(getattr(end, "value", None), ops.Literal)
op.start is None
and isinstance(getattr(end := op.end, "value", None), ops.Literal)
and end.value.value == 0
and end.following
):
return None
elif (
isinstance(getattr(end, "value", None), ops.Cast)
isinstance(getattr(end := op.end, "value", None), ops.Cast)
and end.value.arg.value == 0
and end.following
):
Expand Down
34 changes: 11 additions & 23 deletions ibis/backends/sql/compilers/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from ibis.backends.sql.compilers.base import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.datatypes import ImpalaType
from ibis.backends.sql.dialects import Impala
from ibis.backends.sql.rewrites import lower_sample, rewrite_empty_order_by_window
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
lower_sample,
rewrite_empty_order_by_window,
)


class ImpalaCompiler(SQLGlotCompiler):
Expand Down Expand Up @@ -73,28 +78,11 @@ class ImpalaCompiler(SQLGlotCompiler):
}

@staticmethod
def _minimize_spec(start, end, spec):
# start is None means unbounded preceding
if start is None:
# end is None: unbounded following
# end == 0 => current row
# these are treated the same because for the functions where these
# are not allowed they end up behaving the same
#
# I think we're not covering some cases here:
# These will be treated the same, even though they're not
# - window(order_by=x, rows=(None, None)) # should be equivalent to `over ()`
# - window(order_by=x, rows=(None, 0)) # equivalent to a cumulative aggregation
#
# TODO(cpcloud): we need to clean up the semantics of unbounded
# following vs current row at the API level.
#
if end is None or (
isinstance(getattr(end, "value", None), ops.Literal)
and end.value.value == 0
and end.following
):
return None
def _minimize_spec(op, spec):
if isinstance(func := op.func, ops.Analytic) and not isinstance(
func, (ops.First, ops.Last, FirstValue, LastValue, ops.NthValue)
):
return None
return spec

def visit_Log2(self, op, *, arg):
Expand Down
11 changes: 5 additions & 6 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from ibis.backends.sql.datatypes import MSSQLType
from ibis.backends.sql.dialects import MSSQL
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
Expand Down Expand Up @@ -147,12 +149,9 @@ def _generate_groups(groups):
return groups

@staticmethod
def _minimize_spec(start, end, spec):
if (
start is None
and isinstance(getattr(end, "value", None), ops.Literal)
and end.value.value == 0
and end.following
def _minimize_spec(op, spec):
if isinstance(func := op.func, ops.Analytic) and not isinstance(
func, (ops.First, ops.Last, FirstValue, LastValue, ops.NthValue)
):
return None
return spec
Expand Down
9 changes: 3 additions & 6 deletions ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,9 @@ def POS_INF(self):
}

@staticmethod
def _minimize_spec(start, end, spec):
if (
start is None
and isinstance(getattr(end, "value", None), ops.Literal)
and end.value.value == 0
and end.following
def _minimize_spec(op, spec):
if isinstance(
op.func, (ops.RankBase, ops.CumeDist, ops.NTile, ops.PercentRank)
):
return None
return spec
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by)

order = sge.Order(expressions=order_by) if order_by else None

spec = self._minimize_spec(op.start, op.end, spec)
spec = self._minimize_spec(op, spec)

return sge.Window(this=func, partition_by=group_by, order=order, spec=spec)

Expand Down
23 changes: 11 additions & 12 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from ibis.backends.sql.datatypes import SnowflakeType
from ibis.backends.sql.dialects import Snowflake
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
lower_log2,
Expand Down Expand Up @@ -237,17 +239,6 @@ def _compile_udf(self, udf_node: ops.ScalarUDF):
_compile_pandas_udf = _compile_udf
_compile_python_udf = _compile_udf

@staticmethod
def _minimize_spec(start, end, spec):
if (
start is None
and isinstance(getattr(end, "value", None), ops.Literal)
and end.value.value == 0
and end.following
):
return None
return spec

def visit_Literal(self, op, *, value, dtype):
if value is None:
return super().visit_Literal(op, value=value, dtype=dtype)
Expand Down Expand Up @@ -668,6 +659,14 @@ def visit_Xor(self, op, *, left, right):
# boolxor accepts numerics ... and returns a boolean? wtf?
return self.f.boolxor(self.cast(left, dt.int8), self.cast(right, dt.int8))

@staticmethod
def _minimize_spec(op, spec):
if isinstance(func := op.func, ops.Analytic) and not isinstance(
func, (ops.First, ops.Last, FirstValue, LastValue, ops.NthValue)
):
return None
return spec

def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by):
if start is None:
start = {}
Expand Down Expand Up @@ -698,7 +697,7 @@ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by)
order = sge.Order(expressions=order_by) if order_by else None

orig_spec = spec
spec = self._minimize_spec(op.start, op.end, orig_spec)
spec = self._minimize_spec(op, orig_spec)

# due to https://docs.snowflake.com/en/sql-reference/functions-analytic#window-frame-usage-notes
# we need to make the default window rows (since range isn't supported)
Expand Down
11 changes: 5 additions & 6 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from ibis.backends.sql.datatypes import TrinoType
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
Expand Down Expand Up @@ -104,12 +106,9 @@ class TrinoCompiler(SQLGlotCompiler):
}

@staticmethod
def _minimize_spec(start, end, spec):
if (
start is None
and isinstance(getattr(end, "value", None), ops.Literal)
and end.value.value == 0
and end.following
def _minimize_spec(op, spec):
if isinstance(func := op.func, ops.Analytic) and not isinstance(
func, (ops.First, ops.Last, FirstValue, LastValue, ops.NthValue)
):
return None
return spec
Expand Down
Loading

0 comments on commit a19c3b5

Please sign in to comment.