Skip to content

Commit

Permalink
fix: prevent get_removal_interval from returning invalid interval (#3879
Browse files Browse the repository at this point in the history
)
  • Loading branch information
tobymao authored and izeigerman committed Feb 26, 2025
1 parent 4107c0e commit 81181af
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 78 deletions.
69 changes: 38 additions & 31 deletions sqlmesh/core/plan/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import re
import sys
import typing as t
from collections import defaultdict
from functools import cached_property
Expand All @@ -27,7 +26,13 @@
from sqlmesh.core.snapshot.definition import Interval, SnapshotId
from sqlmesh.utils import columns_to_types_all_known, random_id
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday_ds, to_timestamp
from sqlmesh.utils.date import (
TimeLike,
now,
to_datetime,
yesterday_ds,
to_timestamp,
)
from sqlmesh.utils.errors import NoChangesPlanError, PlanError, SQLMeshError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -322,56 +327,58 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool:
if not restate_models:
return {}

start = self._start or earliest_interval_start
end = self._end or now()

# Add restate snapshots and their downstream snapshots
dummy_interval = (sys.maxsize, -sys.maxsize)
for model_fqn in restate_models:
snapshot = self._model_fqn_to_snapshot.get(model_fqn)
if not snapshot:
if model_fqn not in self._model_fqn_to_snapshot:
raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.")

# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands its
# restatement range, its downstream dependencies all expand their restatement ranges as well.
for s_id in dag:
snapshot = self._context_diff.snapshots[s_id]

if not forward_only_preview_needed:
if self._is_dev and not snapshot.is_paused:
self._console.log_warning(
f"Cannot restate model '{model_fqn}' because the current version is used in production. "
f"Cannot restate model '{snapshot.name}' because the current version is used in production. "
"Run the restatement against the production environment instead to restate this model."
)
continue
elif (not self._is_dev or not snapshot.is_paused) and snapshot.disable_restatement:
self._console.log_warning(
f"Cannot restate model '{model_fqn}'. "
f"Cannot restate model '{snapshot.name}'. "
"Restatement is disabled for this model to prevent possible data loss."
"If you want to restate this model, change the model's `disable_restatement` setting to `false`."
)
continue
elif snapshot.is_symbolic or snapshot.is_seed:
logger.info("Skipping restatement for model '%s'", model_fqn)
logger.info("Skipping restatement for model '%s'", snapshot.name)
continue

restatements[snapshot.snapshot_id] = dummy_interval
for downstream_s_id in dag.downstream(snapshot.snapshot_id):
if is_restateable_snapshot(self._context_diff.snapshots[downstream_s_id]):
restatements[downstream_s_id] = dummy_interval
# Since we are traversing the graph in topological order and the largest interval range is pushed down
# the graph we just have to check our immediate parents in the graph and not the whole upstream graph.
restating_parents = [
self._context_diff.snapshots[s] for s in snapshot.parents if s in restatements
]

# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands its
# restatement range, its downstream dependencies all expand their restatement ranges as well.
for s_id in dag:
if s_id not in restatements:
if not restating_parents and snapshot.name not in restate_models:
continue
snapshot = self._context_diff.snapshots[s_id]
interval = snapshot.get_removal_interval(
self._start or earliest_interval_start,
self._end or now(),
self._execution_time,
strict=False,
is_preview=is_preview,

possible_intervals = {
restatements[p.snapshot_id] for p in restating_parents if p.is_incremental
}
possible_intervals.add(
snapshot.get_removal_interval(
start,
end,
self._execution_time,
strict=False,
is_preview=is_preview,
)
)
# Since we are traversing the graph in topological order and the largest interval range is pushed down
# the graph we just have to check our immediate parents in the graph and not the whole upstream graph.
snapshot_dependencies = snapshot.parents
possible_intervals = [
restatements.get(s, dummy_interval)
for s in snapshot_dependencies
if self._context_diff.snapshots[s].is_incremental
] + [interval]
snapshot_start = min(i[0] for i in possible_intervals)
snapshot_end = max(i[1] for i in possible_intervals)

Expand Down
61 changes: 33 additions & 28 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,8 @@ def add_interval(self, start: TimeLike, end: TimeLike, is_dev: bool = False) ->
f"Attempted to add an Invalid interval ({start}, {end}) to snapshot {self.snapshot_id}"
)

start_ts, end_ts = self.inclusive_exclusive(start, end, strict=False)
start_ts, end_ts = self.inclusive_exclusive(start, end, strict=False, expand=False)

if start_ts >= end_ts:
# Skipping partial interval.
return
Expand Down Expand Up @@ -744,12 +745,17 @@ def get_removal_interval(

return removal_interval

@property
def allow_partials(self) -> bool:
return self.is_model and self.model.allow_partials

def inclusive_exclusive(
self,
start: TimeLike,
end: TimeLike,
strict: bool = True,
allow_partial: t.Optional[bool] = None,
expand: bool = True,
) -> Interval:
"""Transform the inclusive start and end into a [start, end) pair.
Expand All @@ -758,19 +764,18 @@ def inclusive_exclusive(
end: The end date/time of the interval (inclusive)
strict: Whether to fail when the inclusive start is the same as the exclusive end.
allow_partial: Whether the interval can be partial or not.
expand: Whether or not partial intervals are expanded outwards.
Returns:
A [start, end) pair.
"""
if allow_partial is None:
allow_partial = self.is_model and self.model.allow_partials
return inclusive_exclusive(
start,
end,
self.node.interval_unit,
model_allow_partials=self.is_model and self.model.allow_partials,
strict=strict,
allow_partial=allow_partial,
allow_partial=self.allow_partials if allow_partial is None else allow_partial,
expand=expand,
)

def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None:
Expand Down Expand Up @@ -847,9 +852,10 @@ def missing_intervals(
# If the amount of time being checked is less than the size of a single interval then we
# know that there can't being missing intervals within that range and return
validate_date_range(start, end)

if (
not is_date(end)
and not (self.is_model and self.model.allow_partials)
and not self.allow_partials
and to_timestamp(end) - to_timestamp(start) < self.node.interval_unit.milliseconds
):
return []
Expand All @@ -862,16 +868,7 @@ def missing_intervals(
if not self.evaluatable or (self.is_seed and intervals):
return []

allow_partials = self.is_model and self.model.allow_partials
start_ts, end_ts = (
to_timestamp(ts)
for ts in self.inclusive_exclusive(
start,
end,
strict=False,
allow_partial=allow_partials,
)
)
start_ts, end_ts = (to_timestamp(ts) for ts in self.inclusive_exclusive(start, end))

interval_unit = self.node.interval_unit
execution_time_ts = to_timestamp(execution_time) if execution_time else now_timestamp()
Expand All @@ -882,7 +879,7 @@ def missing_intervals(
)
if end_bounded:
upper_bound_ts = min(upper_bound_ts, end_ts)
if not allow_partials:
if not self.allow_partials:
upper_bound_ts = to_timestamp(interval_unit.cron_floor(upper_bound_ts))

end_ts = min(end_ts, upper_bound_ts)
Expand Down Expand Up @@ -1865,36 +1862,44 @@ def inclusive_exclusive(
start: TimeLike,
end: TimeLike,
interval_unit: IntervalUnit,
model_allow_partials: bool,
strict: bool = True,
allow_partial: bool = False,
expand: bool = True,
) -> Interval:
"""Transform the inclusive start and end into a [start, end) pair.
Args:
start: The start date/time of the interval (inclusive)
end: The end date/time of the interval (inclusive)
interval_unit: The interval unit.
model_allow_partials: Whether or not the model allows partials.
strict: Whether to fail when the inclusive start is the same as the exclusive end.
allow_partial: Whether the interval can be partial or not.
expand: Whether or not partial intervals are expanded outwards.
Returns:
A [start, end) pair.
"""
start_ts = to_timestamp(interval_unit.cron_floor(start))
if start_ts < to_timestamp(start) and not model_allow_partials:
start_ts = to_timestamp(interval_unit.cron_next(start_ts))
start_dt = interval_unit.cron_floor(start)

if not expand and not allow_partial and start_dt < to_datetime(start):
start_dt = interval_unit.cron_next(start_dt)

start_ts = to_timestamp(start_dt)

if is_date(end):
end = to_datetime(end) + timedelta(days=1)
end_ts = to_timestamp(interval_unit.cron_floor(end) if not allow_partial else end)
if end_ts < start_ts and to_timestamp(end) > to_timestamp(start) and not strict:
# This can happen when the interval unit is coarser than the size of the input interval.
# For example, if the interval unit is monthly, but the input interval is only 1 hour long.
return (start_ts, end_ts)

if (strict and start_ts >= end_ts) or (start_ts > end_ts):
if allow_partial:
end_dt = end
else:
end_dt = interval_unit.cron_floor(end)

if expand and end_dt != to_datetime(end):
end_dt = interval_unit.cron_next(end_dt)

end_ts = to_timestamp(end_dt)

if strict and start_ts >= end_ts:
raise ValueError(
f"`end` ({to_datetime(end_ts)}) must be greater than `start` ({to_datetime(start_ts)})"
)
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def add_interval(
end: The end of the interval to add.
is_dev: Indicates whether the given interval is being added while in development mode
"""
start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False)
start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False, expand=False)
if not snapshot.version:
raise SQLMeshError("Snapshot version must be set to add an interval.")
intervals = [(start_ts, end_ts)]
Expand Down
8 changes: 3 additions & 5 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2153,7 +2153,7 @@ def test_restatement_plan_ignores_changes(init_and_plan_context: t.Callable):
assert not plan.new_snapshots
assert plan.requires_backfill
assert plan.restatements == {
restated_snapshot.snapshot_id: (to_timestamp("2023-01-01"), to_timestamp("2023-01-08"))
restated_snapshot.snapshot_id: (to_timestamp("2023-01-01"), to_timestamp("2023-01-09"))
}
assert plan.missing_intervals == [
SnapshotIntervals(
Expand Down Expand Up @@ -4562,16 +4562,14 @@ def test_restatement_of_full_model_with_start(init_and_plan_context: t.Callable)
no_prompts=True,
)

restatement_end = to_timestamp("2023-01-08")

sushi_customer_interval = restatement_plan.restatements[
context.get_snapshot("sushi.customers").snapshot_id
]
assert sushi_customer_interval == (to_timestamp("2023-01-01"), restatement_end)
assert sushi_customer_interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-09"))
waiter_by_day_interval = restatement_plan.restatements[
context.get_snapshot("sushi.waiter_as_customer_by_day").snapshot_id
]
assert waiter_by_day_interval == (to_timestamp("2023-01-07"), restatement_end)
assert waiter_by_day_interval == (to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))


def initial_add(context: Context, environment: str):
Expand Down
Loading

0 comments on commit 81181af

Please sign in to comment.