Skip to content

Commit

Permalink
Add sample mode tests for incremental models
Browse files Browse the repository at this point in the history
  • Loading branch information
QMalcolm committed Feb 2, 2025
1 parent 1811754 commit 9678427
Showing 1 changed file with 156 additions and 1 deletion.
157 changes: 156 additions & 1 deletion tests/functional/sample_mode/test_sample_mode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from datetime import datetime
from typing import Optional

import freezegun
import pytest
Expand All @@ -10,7 +11,7 @@
from dbt.event_time.sample_window import SampleWindow
from dbt.events.types import JinjaLogInfo
from dbt.materializations.incremental.microbatch import MicrobatchBuilder
from dbt.tests.util import read_file, relation_from_name, run_dbt
from dbt.tests.util import read_file, relation_from_name, run_dbt, write_file
from tests.utils import EventCatcher

input_model_sql = """
Expand All @@ -22,6 +23,21 @@
select 3 as id, TIMESTAMP '2025-01-02 12:32:00-0' as event_time
"""

# Replacement contents for `models/input_model.sql`, written by the incremental
# sample mode tests before their second (non full-refresh) run. It defines six
# rows (ids 1-6): id 1 is dated 2020-01-01 and ids 2-6 are dated one per day
# from 2025-01-02 through 2025-01-06.
later_input_model_sql = """
{{ config(materialized='table', event_time='event_time') }}
select 1 as id, TIMESTAMP '2020-01-01 01:25:00-0' as event_time
UNION ALL
select 2 as id, TIMESTAMP '2025-01-02 13:47:00-0' as event_time
UNION ALL
select 3 as id, TIMESTAMP '2025-01-03 12:32:00-0' as event_time
UNION ALL
select 4 as id, TIMESTAMP '2025-01-04 14:32:00-0' as event_time
UNION ALL
select 5 as id, TIMESTAMP '2025-01-05 20:32:00-0' as event_time
UNION ALL
select 6 as id, TIMESTAMP '2025-01-06 12:32:00-0' as event_time
"""

sample_mode_model_sql = """
{{ config(materialized='table', event_time='event_time') }}
Expand All @@ -44,6 +60,21 @@
SELECT * FROM {{ ref("input_model") }}
"""

# Incremental model (merge strategy, unique key `id`) that selects from
# `input_model`. At execution time it logs the value of `is_incremental()` and
# the `sample_window` invocation arg at info level so tests can observe them.
# On incremental runs it only selects rows whose `event_time` is at or after
# the max `event_time` already present in the existing relation (`this`).
sample_incremental_merge_sql = """
{{ config(materialized='incremental', incremental_strategy='merge', unique_key='id')}}
{% if execute %}
{{ log("is_incremental: " ~ is_incremental(), info=true) }}
{{ log("sample window: " ~ invocation_args_dict.get("sample_window"), info=true) }}
{% endif %}
SELECT * FROM {{ ref("input_model") }}
{% if is_incremental() %}
WHERE event_time >= (SELECT max(event_time) FROM {{ this }})
{% endif %}
"""


class BaseSampleMode:
# TODO This is now used in 3 test files, it might be worth turning into a full test utility method
Expand Down Expand Up @@ -190,3 +221,127 @@ def test_sample_mode(
relation_name="sample_microbatch_model",
expected_row_count=2,
)


class TestIncrementalModelSampleModeRelative(BaseSampleMode):
    """Incremental merge model run with relative sample windows (e.g. "3 days").

    The clock is frozen for the test so that the run happens at a fixed point
    in time (presumably so relative windows resolve deterministically).
    """

    @pytest.fixture(scope="class")
    def models(self):
        """Project models: the input table and the incremental merge model."""
        project_models = {}
        project_models["input_model.sql"] = input_model_sql
        project_models["sample_incremental_merge.sql"] = sample_incremental_merge_sql
        return project_models

    @pytest.fixture
    def event_catcher(self) -> EventCatcher:
        """Catcher for Jinja info logs reporting that a run was incremental."""

        def reports_incremental_run(event) -> bool:
            return "is_incremental: True" in event.info.msg

        return EventCatcher(event_to_catch=JinjaLogInfo, predicate=reports_incremental_run)  # type: ignore

    @pytest.mark.parametrize(
        "sample_mode_available,run_sample_mode,sample_window,expected_rows",
        [
            (True, False, None, 6),
            (True, True, "3 days", 6),
            (True, True, "2 days", 5),
        ],
    )
    @freezegun.freeze_time("2025-01-06T18:03:0Z")
    def test_incremental_model_sample(
        self,
        project,
        mocker: MockerFixture,
        event_catcher: EventCatcher,
        sample_mode_available: bool,
        run_sample_mode: bool,
        sample_window: Optional[str],
        expected_rows: int,
    ):
        """Full-refresh on the original input, then run incrementally on the later input.

        The first run must not log an incremental run and must produce 3 rows;
        the second run must log exactly one incremental run and produce
        ``expected_rows`` rows.
        """
        model_name = "sample_incremental_merge"
        catch_callbacks = [event_catcher.catch]

        # Start every parametrized case from the original input model contents.
        write_file(input_model_sql, "models", "input_model.sql")
        if sample_mode_available:
            mocker.patch.dict(os.environ, {"DBT_EXPERIMENTAL_SAMPLE_MODE": "True"})

        # First pass: rebuild from scratch, so nothing should report as incremental.
        run_dbt(["run", "--full-refresh"], callbacks=catch_callbacks)

        assert len(event_catcher.caught_events) == 0
        self.assert_row_count(
            project=project,
            relation_name=model_name,
            expected_row_count=3,
        )

        # Second pass: swap in the later input data and run incrementally,
        # optionally in sample mode with the parametrized window.
        write_file(later_input_model_sql, "models", "input_model.sql")

        sample_flags = (
            ["--sample", f"--sample-window={sample_window}"] if run_sample_mode else []
        )
        run_dbt(["run"] + sample_flags, callbacks=catch_callbacks)

        assert len(event_catcher.caught_events) == 1
        self.assert_row_count(
            project=project,
            relation_name=model_name,
            expected_row_count=expected_rows,
        )


class TestIncrementalModelSampleModeSpecific(BaseSampleMode):
    """Incremental merge model run with explicit start/end sample windows."""

    # These cases live apart from the "relative" window tests because
    # `freezegun.freeze_time` breaks how timestamps get created.

    @pytest.fixture(scope="class")
    def models(self):
        """Project models: the input table and the incremental merge model."""
        project_models = {}
        project_models["input_model.sql"] = input_model_sql
        project_models["sample_incremental_merge.sql"] = sample_incremental_merge_sql
        return project_models

    @pytest.fixture
    def event_catcher(self) -> EventCatcher:
        """Catcher for Jinja info logs reporting that a run was incremental."""

        def reports_incremental_run(event) -> bool:
            return "is_incremental: True" in event.info.msg

        return EventCatcher(event_to_catch=JinjaLogInfo, predicate=reports_incremental_run)  # type: ignore

    @pytest.mark.parametrize(
        "sample_mode_available,run_sample_mode,sample_window,expected_rows",
        [
            (True, False, None, 6),
            (True, True, "{'start': '2025-01-03', 'end': '2025-01-07'}", 6),
            (True, True, "{'start': '2025-01-04', 'end': '2025-01-06'}", 5),
            (True, True, "{'start': '2025-01-05', 'end': '2025-01-07'}", 5),
            (True, True, "{'start': '2024-12-31', 'end': '2025-01-03'}", 3),
        ],
    )
    def test_incremental_model_sample(
        self,
        project,
        mocker: MockerFixture,
        event_catcher: EventCatcher,
        sample_mode_available: bool,
        run_sample_mode: bool,
        sample_window: Optional[str],
        expected_rows: int,
    ):
        """Full-refresh on the original input, then run incrementally on the later input.

        The first run must not log an incremental run and must produce 3 rows;
        the second run must log exactly one incremental run and produce
        ``expected_rows`` rows.
        """
        model_name = "sample_incremental_merge"
        catch_callbacks = [event_catcher.catch]

        # Start every parametrized case from the original input model contents.
        write_file(input_model_sql, "models", "input_model.sql")
        if sample_mode_available:
            mocker.patch.dict(os.environ, {"DBT_EXPERIMENTAL_SAMPLE_MODE": "True"})

        # First pass: rebuild from scratch, so nothing should report as incremental.
        run_dbt(["run", "--full-refresh"], callbacks=catch_callbacks)

        assert len(event_catcher.caught_events) == 0
        self.assert_row_count(
            project=project,
            relation_name=model_name,
            expected_row_count=3,
        )

        # Second pass: swap in the later input data and run incrementally,
        # optionally in sample mode with the parametrized window.
        write_file(later_input_model_sql, "models", "input_model.sql")

        sample_flags = (
            ["--sample", f"--sample-window={sample_window}"] if run_sample_mode else []
        )
        run_dbt(["run"] + sample_flags, callbacks=catch_callbacks)

        assert len(event_catcher.caught_events) == 1
        self.assert_row_count(
            project=project,
            relation_name=model_name,
            expected_row_count=expected_rows,
        )

0 comments on commit 9678427

Please sign in to comment.