diff --git a/great_expectations/zep/interfaces.py b/great_expectations/zep/interfaces.py index feb52e7c272a..72f59f7b3ac0 100644 --- a/great_expectations/zep/interfaces.py +++ b/great_expectations/zep/interfaces.py @@ -11,16 +11,12 @@ from great_expectations.zep.metadatasource import MetaDatasource # BatchRequestOptions is a dict that is composed into a BatchRequest that specifies the -# Batches one wants returned. In the simple case the keys represent dimensions one can -# slice the data along and the values are the values. One can also namespace these key/value -# pairs, hence the Dict[str, BatchRequestValue], allowed values. For example: -# options = { -# "month": "3" -# "year_splitter": { -# "year": "2020" -# } -# } -# The month key is in the global namespace while the year key is in the year_splitter namespace. +# Batches one wants returned. The keys represent dimensions one can slice the data along +# and the values are the realized. If a value is None or unspecified, the batch_request +# will capture all data along this dimension. For example, if we have a year and month +# splitter and we want to query all months in the year 2020, the batch request options +# would look like: +# options = { "year": 2020 } BatchRequestOptions: TypeAlias = Dict[str, Any] diff --git a/great_expectations/zep/postgres_datasource.py b/great_expectations/zep/postgres_datasource.py index 168e0f76be14..5ba1c06937f5 100644 --- a/great_expectations/zep/postgres_datasource.py +++ b/great_expectations/zep/postgres_datasource.py @@ -1,9 +1,12 @@ from __future__ import annotations import dataclasses +import itertools +from datetime import datetime from pprint import pformat as pf -from typing import Any, Dict, List, Optional, Type +from typing import Dict, Iterable, List, Optional, Type, cast +import dateutil.tz from typing_extensions import ClassVar from great_expectations.core.batch_spec import SqlAlchemyDatasourceBatchSpec @@ -21,12 +24,28 @@ class PostgresDatasourceError(Exception): pass +class BatchRequestError(Exception): + pass + + +# For our year splitter we default the range to the last 2 year. +_CURRENT_YEAR = datetime.now(dateutil.tz.tzutc()).year +_DEFAULT_YEAR_RANGE = range(_CURRENT_YEAR - 1, _CURRENT_YEAR + 1) +_DEFAULT_MONTH_RANGE = range(1, 13) + + @dataclasses.dataclass(frozen=True) class ColumnSplitter: method_name: str column_name: str - name: str - template_params: List[str] + # param_defaults is a Dict where the keys are the parameters of the splitter and the values are the default + # values are the default values if a batch request using the splitter leaves the parameter unspecified. + # template_params: List[str] + param_defaults: Dict[str, Iterable] + + @property + def param_names(self) -> List[str]: + return list(self.param_defaults.keys()) class TableAsset(DataAsset): @@ -65,46 +84,76 @@ def get_batch_request( Args: options: A dict that can be used to limit the number of batches returned from the asset. The dict structure depends on the asset type. A template of the dict can be obtained by - calling batch_request_template. + calling batch_request_options_template. Returns: A BatchRequest object that can be used to obtain a batch list from a Datasource by calling the get_batch_list_from_batch_request method. """ + if options is not None and not self._valid_batch_request_options(options): + raise BatchRequestError( + "Batch request options should have a subset of keys:\n" + f"{list(self.batch_request_options_template().keys())}\n" + f"but actually has the form:\n{pf(options)}\n" + ) return BatchRequest( datasource_name=self.datasource.name, data_asset_name=self.name, options=options or {}, ) - def batch_request_template( + def _valid_batch_request_options(self, options: BatchRequestOptions) -> bool: + return set(options.keys()).issubset( + set(self.batch_request_options_template().keys()) + ) + + def validate_batch_request(self, batch_request: BatchRequest) -> None: + if not ( + batch_request.datasource_name == self.datasource.name + and batch_request.data_asset_name == self.name + and self._valid_batch_request_options(batch_request.options) + ): + expect_batch_request_form = BatchRequest( + datasource_name=self.datasource.name, + data_asset_name=self.name, + options=self.batch_request_options_template(), + ) + raise BatchRequestError( + "BatchRequest should have form:\n" + f"{pf(dataclasses.asdict(expect_batch_request_form))}\n" + f"but actually has form:\n{pf(dataclasses.asdict(batch_request))}\n" + ) + + def batch_request_options_template( self, ) -> BatchRequestOptions: - """A BatchRequestOptions template that can be used when calling get_batch_request. + """A BatchRequestOptions template for get_batch_request. Returns: A BatchRequestOptions dictionary with the correct shape that get_batch_request - will understand. All the option values will be filled in with the placeholder "value". + will understand. All the option values are defaulted to None. """ + template: BatchRequestOptions = {} if not self.column_splitter: - template: BatchRequestOptions = {} return template - params_dict: BatchRequestOptions - params_dict = {p: "" for p in self.column_splitter.template_params} - if self.column_splitter.name: - params_dict = {self.column_splitter.name: params_dict} - return params_dict + return {p: None for p in self.column_splitter.param_names} # This asset type will support a variety of splitters def add_year_and_month_splitter( - self, column_name: str, name: str = "" + self, + column_name: str, + default_year_range: Iterable[int] = _DEFAULT_YEAR_RANGE, + default_month_range: Iterable[int] = _DEFAULT_MONTH_RANGE, ) -> TableAsset: """Associates a year month splitter with this DataAsset Args: column_name: A column name of the date column where year and month will be parsed out. - name: A name for the splitter that will be used to namespace the batch request options. - Leaving this empty, "", will add the options to the global namespace. + default_year_range: When this splitter is used, say in a BatchRequest, if no value for + year is specified, we query over all years in this range. + will query over all the years in this default range. + default_month_range: When this splitter is used, say in a BatchRequest, if no value for + month is specified, we query over all months in this range. Returns: This TableAsset so we can use this method fluently. @@ -112,17 +161,81 @@ def add_year_and_month_splitter( self.column_splitter = ColumnSplitter( method_name="split_on_year_and_month", column_name=column_name, - name=name, - template_params=["year", "month"], + param_defaults={"year": default_year_range, "month": default_month_range}, ) return self + def fully_specified_batch_requests(self, batch_request) -> List[BatchRequest]: + """Populates a batch requests unspecified params producing a list of batch requests + + This method does NOT validate the batch_request. If necessary call + TableAsset.validate_batch_request before calling this method. + """ + if self.column_splitter is None: + # Currently batch_request.options is complete determined by the presence of a + # column splitter. If column_splitter is None, then there are no specifiable options + # so we return early. + # In the future, if there are options that are not determined by the column splitter + # this check will have to be generalized. + return [batch_request] + + # Make a list of the specified and unspecified params in batch_request + specified_options = [] + unspecified_options = [] + options_template = self.batch_request_options_template() + for option_name in options_template.keys(): + if ( + option_name in batch_request.options + and batch_request.options[option_name] is not None + ): + specified_options.append(option_name) + else: + unspecified_options.append(option_name) + + # Make a list of the all possible batch_request.options by expanding out the unspecified + # options + batch_requests: List[BatchRequest] = [] + + if not unspecified_options: + batch_requests.append(batch_request) + else: + # All options are defined by the splitter, so we look at its default values to fill + # in the option values. + default_option_values = [] + for option in unspecified_options: + default_option_values.append( + self.column_splitter.param_defaults[option] + ) + for option_values in itertools.product(*default_option_values): + # Add options from specified options + options = { + name: batch_request.options[name] for name in specified_options + } + # Add options from unspecified options + for i, option_value in enumerate(option_values): + options[unspecified_options[i]] = option_value + batch_requests.append( + BatchRequest( + datasource_name=batch_request.datasource_name, + data_asset_name=batch_request.data_asset_name, + options=options, + ) + ) + return batch_requests + class PostgresDatasource(Datasource): # class var definitions asset_types: ClassVar[List[Type[DataAsset]]] = [TableAsset] def __init__(self, name: str, connection_str: str) -> None: + """Initializes the PostgresDatasource. + + Args: + name: The name of this datasource. + connection_str: The SQLAlchemy connection string used to connect to the database. + For example: "postgresql+psycopg2://postgres:@localhost/test_database" + """ self.name = name self.execution_engine = SqlAlchemyExecutionEngine( connection_string=connection_str @@ -168,55 +281,37 @@ def get_batch_list_from_batch_request( A list of batches that match the options specified in the batch request. """ # We translate the batch_request into a BatchSpec to hook into GX core. - # NOTE: We only produce 1 batch right now data_asset = self.get_asset(batch_request.data_asset_name) - - # We look at the splitters on the data asset and verify that the passed in batch request provides the - # correct arguments to specify the batch - batch_identifiers: Dict[str, Any] = {} - batch_spec_kwargs: Dict[str, Any] = { - "type": "table", - "data_asset_name": data_asset.name, - "table_name": data_asset.table_name, - "batch_identifiers": batch_identifiers, - } - if data_asset.column_splitter: - column_splitter = data_asset.column_splitter - batch_spec_kwargs["splitter_method"] = column_splitter.method_name - batch_spec_kwargs["splitter_kwargs"] = { - "column_name": column_splitter.column_name + data_asset.validate_batch_request(batch_request) + + batch_list: List[Batch] = [] + column_splitter = data_asset.column_splitter + for request in data_asset.fully_specified_batch_requests(batch_request): + batch_spec_kwargs = { + "type": "table", + "data_asset_name": data_asset.name, + "table_name": data_asset.table_name, + "batch_identifiers": {}, } - try: - param_lookup = ( - batch_request.options[column_splitter.name] - if column_splitter.name - else batch_request.options - ) - except KeyError as e: - raise PostgresDatasourceError( - "One must specify the batch request options in this form: " - f"{pf(data_asset.batch_request_template())}. It was specified like {pf(batch_request.options)}" - ) from e - - column_splitter_kwargs = {} - for param_name in column_splitter.template_params: - column_splitter_kwargs[param_name] = ( - param_lookup[param_name] if param_name in param_lookup else None + if column_splitter: + batch_spec_kwargs["splitter_method"] = column_splitter.method_name + batch_spec_kwargs["splitter_kwargs"] = { + "column_name": column_splitter.column_name + } + # mypy infers that batch_spec_kwargs["batch_identifiers"] is a collection, but + # it is hardcoded to a dict above, so we cast it here. + cast(Dict, batch_spec_kwargs["batch_identifiers"]).update( + {column_splitter.column_name: request.options} ) - batch_spec_kwargs["batch_identifiers"].update( - {column_splitter.column_name: column_splitter_kwargs} + data, _ = self.execution_engine.get_batch_data_and_markers( + batch_spec=SqlAlchemyDatasourceBatchSpec(**batch_spec_kwargs) + ) + batch_list.append( + Batch( + datasource=self, + data_asset=data_asset, + batch_request=batch_request, + data=data, ) - - # Now, that we've verified the arguments, we can create the batch_spec and then the batch. - batch_spec = SqlAlchemyDatasourceBatchSpec(**batch_spec_kwargs) - data, _ = self.execution_engine.get_batch_data_and_markers( - batch_spec=batch_spec - ) - return [ - Batch( - datasource=self, - data_asset=data_asset, - batch_request=batch_request, - data=data, ) - ] + return batch_list diff --git a/tests/zep/test_postgres_datasource.py b/tests/zep/test_postgres_datasource.py index 2660e11cf4d0..c6ab1e598dd2 100644 --- a/tests/zep/test_postgres_datasource.py +++ b/tests/zep/test_postgres_datasource.py @@ -10,7 +10,7 @@ SqlAlchemyDatasourceBatchSpec, ) from great_expectations.execution_engine import SqlAlchemyExecutionEngine -from great_expectations.zep.interfaces import BatchRequestOptions +from great_expectations.zep.interfaces import BatchRequest, BatchRequestOptions @contextmanager @@ -42,8 +42,9 @@ def _source() -> postgres_datasource.PostgresDatasource: ) +@pytest.mark.unit def test_construct_postgres_datasource(): - with sqlachemy_execution_engine_mock(lambda x: None): + with sqlachemy_execution_engine_mock(lambda: None): source = _source() assert source.name == "my_datasource" assert isinstance(source.execution_engine, SqlAlchemyExecutionEngine) @@ -60,7 +61,7 @@ def assert_table_asset( assert asset.name == name assert asset.table_name == table_name assert asset.datasource == source - assert asset.batch_request_template() == batch_request_template + assert asset.batch_request_options_template() == batch_request_template def assert_batch_request( @@ -71,103 +72,75 @@ def assert_batch_request( assert batch_request.options == options -@pytest.mark.parametrize( - "config", - [ - # column_name, splitter_name, batch_request_template, batch_request_options - (None, None, {}, {}), - ( - "my_col", - None, - {"year": "", "month": ""}, - {"year": 2021, "month": 10}, - ), - ( - "my_col", - "", - {"year": "", "month": ""}, - {"year": 2021, "month": 10}, - ), - ( - "my_col", - "mysplitter", - {"mysplitter": {"year": "", "month": ""}}, - {"mysplitter": {"year": 2021, "month": 10}}, - ), - ], -) -def test_add_table_asset(config): - with sqlachemy_execution_engine_mock(lambda x: None): +@pytest.mark.unit +def test_add_table_asset_with_splitter(): + with sqlachemy_execution_engine_mock(lambda: None): source = _source() - ( - splitter_col, - splitter_name, - batch_request_template, - batch_request_options, - ) = config asset = source.add_table_asset(name="my_asset", table_name="my_table") - if batch_request_template: - kwargs = {"column_name": splitter_col} - if splitter_name is not None: - kwargs["name"] = splitter_name - asset.add_year_and_month_splitter(**kwargs) + asset.add_year_and_month_splitter("my_column") assert len(source.assets) == 1 - asset = list(source.assets.values())[0] + assert asset == list(source.assets.values())[0] assert_table_asset( - asset, "my_asset", "my_table", source, batch_request_template + asset=asset, + name="my_asset", + table_name="my_table", + source=source, + batch_request_template={"year": None, "month": None}, ) assert_batch_request( - asset.get_batch_request(batch_request_options), - "my_datasource", - "my_asset", - batch_request_options, + batch_request=asset.get_batch_request({"year": 2021, "month": 10}), + source_name="my_datasource", + asset_name="my_asset", + options={"year": 2021, "month": 10}, ) -def test_construct_table_asset_directly_with_no_splitter(): - with sqlachemy_execution_engine_mock(lambda x: None): +@pytest.mark.unit +def test_add_table_asset_with_no_splitter(): + with sqlachemy_execution_engine_mock(lambda: None): source = _source() - asset = postgres_datasource.TableAsset( - name="my_asset", table_name="my_table", datasource=source - ) - assert_batch_request(asset.get_batch_request(), "my_datasource", "my_asset", {}) - - -def test_construct_table_asset_directly_with_nameless_splitter(): - with sqlachemy_execution_engine_mock(lambda x: None): - source = _source() - splitter = postgres_datasource.ColumnSplitter( - method_name="splitter_method", - column_name="col", - template_params=["a", "b"], - name="", - ) - asset = postgres_datasource.TableAsset( + asset = source.add_table_asset(name="my_asset", table_name="my_table") + assert len(source.assets) == 1 + assert asset == list(source.assets.values())[0] + assert_table_asset( + asset=asset, name="my_asset", table_name="my_table", - datasource=source, - column_splitter=splitter, + source=source, + batch_request_template={}, ) - assert_table_asset( - asset, "my_asset", "my_table", source, {"a": "", "b": ""} + assert_batch_request( + batch_request=asset.get_batch_request(), + source_name="my_datasource", + asset_name="my_asset", + options={}, ) - batch_request_options = {"a": 1, "b": 2} assert_batch_request( - asset.get_batch_request(batch_request_options), - "my_datasource", - "my_asset", - batch_request_options, + batch_request=asset.get_batch_request({}), + source_name="my_datasource", + asset_name="my_asset", + options={}, + ) + + +@pytest.mark.unit +def test_construct_table_asset_directly_with_no_splitter(): + with sqlachemy_execution_engine_mock(lambda: None): + source = _source() + asset = postgres_datasource.TableAsset( + name="my_asset", table_name="my_table", datasource=source ) + assert_batch_request(asset.get_batch_request(), "my_datasource", "my_asset", {}) -def test_construct_table_asset_directly_with_named_splitter(): - with sqlachemy_execution_engine_mock(lambda x: None): +@pytest.mark.unit +def test_construct_table_asset_directly_with_splitter(): + with sqlachemy_execution_engine_mock(lambda: None): source = _source() splitter = postgres_datasource.ColumnSplitter( method_name="splitter_method", column_name="col", - template_params=["a", "b"], - name="splitter", + param_defaults={"a": [1, 2, 3], "b": range(1, 13)}, ) asset = postgres_datasource.TableAsset( name="my_asset", @@ -180,9 +153,9 @@ def test_construct_table_asset_directly_with_named_splitter(): "my_asset", "my_table", source, - {"splitter": {"a": "", "b": ""}}, + {"a": None, "b": None}, ) - batch_request_options = {"splitter": {"a": 1, "b": 2}} + batch_request_options = {"a": 1, "b": 2} assert_batch_request( asset.get_batch_request(batch_request_options), "my_datasource", @@ -191,6 +164,7 @@ def test_construct_table_asset_directly_with_named_splitter(): ) +@pytest.mark.unit def test_datasource_gets_batch_list_no_splitter(): def validate_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: assert spec == { @@ -206,45 +180,91 @@ def validate_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: source.get_batch_list_from_batch_request(asset.get_batch_request()) -def test_datasource_gets_batch_list_splitter_no_values(): - def validate_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: - assert spec == { - "batch_identifiers": {"my_col": {"month": None, "year": None}}, - "data_asset_name": "my_asset", - "splitter_kwargs": {"column_name": "my_col"}, - "splitter_method": "split_on_year_and_month", - "table_name": "my_table", - "type": "table", - } +def assert_batch_specs_correct_with_year_month_splitter_defaults(batch_specs): + # We should have 1 batch_spec per (year, month) pair + expected_batch_spec_num = len(list(postgres_datasource._DEFAULT_YEAR_RANGE)) * len( + list(postgres_datasource._DEFAULT_MONTH_RANGE) + ) + assert len(batch_specs) == expected_batch_spec_num + for year in postgres_datasource._DEFAULT_YEAR_RANGE: + for month in postgres_datasource._DEFAULT_MONTH_RANGE: + spec = { + "type": "table", + "data_asset_name": "my_asset", + "table_name": "my_table", + "batch_identifiers": {"my_col": {"year": year, "month": month}}, + "splitter_method": "split_on_year_and_month", + "splitter_kwargs": {"column_name": "my_col"}, + } + assert spec in batch_specs + + +@pytest.mark.unit +def test_datasource_gets_batch_list_splitter_with_unspecified_batch_request_options(): + batch_specs = [] + + def collect_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: + batch_specs.append(spec) + + with sqlachemy_execution_engine_mock(collect_batch_spec): + source = _source() + asset = source.add_table_asset(name="my_asset", table_name="my_table") + asset.add_year_and_month_splitter(column_name="my_col") + empty_batch_request = asset.get_batch_request() + assert empty_batch_request.options == {} + source.get_batch_list_from_batch_request(empty_batch_request) + assert_batch_specs_correct_with_year_month_splitter_defaults(batch_specs) - with sqlachemy_execution_engine_mock(validate_batch_spec): + +@pytest.mark.unit +def test_datasource_gets_batch_list_splitter_with_batch_request_options_set_to_none(): + batch_specs = [] + + def collect_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: + batch_specs.append(spec) + + with sqlachemy_execution_engine_mock(collect_batch_spec): source = _source() asset = source.add_table_asset(name="my_asset", table_name="my_table") asset.add_year_and_month_splitter(column_name="my_col") - source.get_batch_list_from_batch_request(asset.get_batch_request()) + batch_request_with_none = asset.get_batch_request( + asset.batch_request_options_template() + ) + assert batch_request_with_none.options == {"year": None, "month": None} + source.get_batch_list_from_batch_request(batch_request_with_none) + # We should have 1 batch_spec per (year, month) pair + assert_batch_specs_correct_with_year_month_splitter_defaults(batch_specs) -def test_datasource_gets_batch_list_splitter_some_values(): - def validate_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: - assert spec == { - "batch_identifiers": {"my_col": {"month": None, "year": 2022}}, - "data_asset_name": "my_asset", - "splitter_kwargs": {"column_name": "my_col"}, - "splitter_method": "split_on_year_and_month", - "table_name": "my_table", - "type": "table", - } +@pytest.mark.unit +def test_datasource_gets_batch_list_splitter_with_partially_specified_batch_request_options(): + batch_specs = [] - with sqlachemy_execution_engine_mock(validate_batch_spec): + def collect_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: + batch_specs.append(spec) + + with sqlachemy_execution_engine_mock(collect_batch_spec): source = _source() asset = source.add_table_asset(name="my_asset", table_name="my_table") asset.add_year_and_month_splitter(column_name="my_col") source.get_batch_list_from_batch_request( asset.get_batch_request({"year": 2022}) ) - - -def test_datasource_gets_batch_list_unnamed_splitter(): + assert len(batch_specs) == 12 + for month in postgres_datasource._DEFAULT_MONTH_RANGE: + spec = { + "type": "table", + "data_asset_name": "my_asset", + "table_name": "my_table", + "batch_identifiers": {"my_col": {"year": 2022, "month": month}}, + "splitter_method": "split_on_year_and_month", + "splitter_kwargs": {"column_name": "my_col"}, + } + assert spec in batch_specs + + +@pytest.mark.unit +def test_datasource_gets_batch_list_with_fully_specified_batch_request_options(): def validate_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: assert spec == { "batch_identifiers": {"my_col": {"month": 1, "year": 2022}}, @@ -264,49 +284,95 @@ def validate_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: ) -def test_datasource_gets_batch_list_named_splitter(): - def validate_batch_spec(spec: SqlAlchemyDatasourceBatchSpec) -> None: - assert spec == { - "batch_identifiers": {"my_col": {"month": 1, "year": 2022}}, - "data_asset_name": "my_asset", - "splitter_kwargs": {"column_name": "my_col"}, - "splitter_method": "split_on_year_and_month", - "table_name": "my_table", - "type": "table", - } +@pytest.mark.unit +def test_datasource_gets_nonexistent_asset(): + with sqlachemy_execution_engine_mock(lambda: None): + source = _source() + with pytest.raises(postgres_datasource.PostgresDatasourceError): + source.get_asset("my_asset") - with sqlachemy_execution_engine_mock(validate_batch_spec): + +@pytest.mark.unit +@pytest.mark.parametrize( + "batch_request_args", + [ + ("bad", None, None), + (None, "bad", None), + (None, None, {"bad": None}), + ("bad", "bad", None), + ], +) +def test_bad_batch_request_passed_into_get_batch_list_from_batch_request( + batch_request_args, +): + with sqlachemy_execution_engine_mock(lambda: None): source = _source() asset = source.add_table_asset(name="my_asset", table_name="my_table") - asset.add_year_and_month_splitter(column_name="my_col", name="my_splitter") - source.get_batch_list_from_batch_request( - asset.get_batch_request({"my_splitter": {"month": 1, "year": 2022}}) + asset.add_year_and_month_splitter(column_name="my_col") + + src, ast, op = batch_request_args + batch_request = BatchRequest( + datasource_name=src or source.name, + data_asset_name=ast or asset.name, + options=op or {}, ) + with pytest.raises( + ( + postgres_datasource.BatchRequestError, + postgres_datasource.PostgresDatasourceError, + ) + ): + source.get_batch_list_from_batch_request(batch_request) -def test_datasource_gets_batch_list_using_invalid_splitter_name(): - with sqlachemy_execution_engine_mock(lambda x: None): +@pytest.mark.unit +@pytest.mark.parametrize( + "batch_request_options", + [{}, {"year": 2021}, {"year": 2021, "month": 10}, {"year": None, "month": 10}], +) +def test_validate_good_batch_request(batch_request_options): + with sqlachemy_execution_engine_mock(lambda: None): source = _source() asset = source.add_table_asset(name="my_asset", table_name="my_table") - asset.add_year_and_month_splitter(column_name="my_col", name="splitter_name") - assert_table_asset( - asset, - "my_asset", - "my_table", - source, - {"splitter_name": {"month": "", "year": ""}}, + asset.add_year_and_month_splitter(column_name="my_col") + batch_request = BatchRequest( + datasource_name=source.name, + data_asset_name=asset.name, + options=batch_request_options, ) - with pytest.raises(postgres_datasource.PostgresDatasourceError): - # This raises because we've named the splitter but we didn't specify the name in the batch request options - # The batch_request_options should look like {"splitter": {"month": 1, "year": 2}} but instead looks like - # {"month": 1, "year": 2} - source.get_batch_list_from_batch_request( - asset.get_batch_request({"month": 1, "year": 2}) - ) + # No exception should get thrown + asset.validate_batch_request(batch_request) -def test_datasource_gets_nonexistent_asset(): - with sqlachemy_execution_engine_mock(lambda x: None): +@pytest.mark.unit +@pytest.mark.parametrize( + "batch_request_args", + [ + ("bad", None, None), + (None, "bad", None), + (None, None, {"bad": None}), + ("bad", "bad", None), + ], +) +def test_validate_malformed_batch_request(batch_request_args): + with sqlachemy_execution_engine_mock(lambda: None): source = _source() - with pytest.raises(postgres_datasource.PostgresDatasourceError): - source.get_asset("my_asset") + asset = source.add_table_asset(name="my_asset", table_name="my_table") + asset.add_year_and_month_splitter(column_name="my_col") + src, ast, op = batch_request_args + batch_request = BatchRequest( + datasource_name=src or source.name, + data_asset_name=ast or asset.name, + options=op or {}, + ) + with pytest.raises(postgres_datasource.BatchRequestError): + asset.validate_batch_request(batch_request) + + +def test_get_bad_batch_request(): + with sqlachemy_execution_engine_mock(lambda: None): + source = _source() + asset = source.add_table_asset(name="my_asset", table_name="my_table") + asset.add_year_and_month_splitter(column_name="my_col") + with pytest.raises(postgres_datasource.BatchRequestError): + asset.get_batch_request({"invalid_key": None})