diff --git a/CHANGELOG.md b/CHANGELOG.md index ddac57e28..fef5e1d4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Conflict handler option to the `dict_deep_merge` function in `utils.py` +- Ability to add module-specific pytest fixtures +- Added fixtures specifically for testing status functionality +- Added tests for reading and writing status files, and status conflict handling +- Added tests for the `dict_deep_merge` function +- Pytest-mock as a dependency for the test suite (necessary for using mocks and fixtures in the same test) ### Changed diff --git a/merlin/common/tasks.py b/merlin/common/tasks.py index 33afb3316..980493ae8 100644 --- a/merlin/common/tasks.py +++ b/merlin/common/tasks.py @@ -49,7 +49,7 @@ from merlin.exceptions import HardFailException, InvalidChainException, RestartException, RetryException from merlin.router import stop_workers from merlin.spec.expansion import parameter_substitutions_for_cmd, parameter_substitutions_for_sample -from merlin.study.status import read_status +from merlin.study.status import read_status, status_conflict_handler from merlin.utils import dict_deep_merge @@ -484,7 +484,7 @@ def gather_statuses( # Make sure the status for this sample workspace is in a finished state (not initialized or running) if status[step_name][f"{condensed_workspace}/{path}"]["status"] not in ("INITIALIZED", "RUNNING"): # Add the status data to the statuses we'll write to the condensed file and remove this status file - dict_deep_merge(condensed_statuses, status) + dict_deep_merge(condensed_statuses, status, conflict_handler=status_conflict_handler) files_to_remove.append(status_filepath) files_to_remove.append(lock_filepath) # Remove the lock files as well as the status files except KeyError: @@ -556,7 +556,7 @@ def condense_status_files(self, *args: Any, **kwargs: Any) -> ReturnCode: # pyl existing_condensed_statuses = json.load(condensed_status_file) # Merging the statuses we're condensing into the already existing statuses # because it's faster at scale than vice versa - dict_deep_merge(existing_condensed_statuses, condensed_statuses) + dict_deep_merge(existing_condensed_statuses, condensed_statuses, conflict_handler=status_conflict_handler) condensed_statuses = existing_condensed_statuses # Write the condensed statuses to the condensed status file diff --git a/merlin/exceptions/__init__.py b/merlin/exceptions/__init__.py index 89fe89a13..71ef8ba93 100644 --- a/merlin/exceptions/__init__.py +++ b/merlin/exceptions/__init__.py @@ -42,7 +42,6 @@ "HardFailException", "InvalidChainException", "RestartException", - "DeepMergeException", "NoWorkersException", ) @@ -96,16 +95,6 @@ def __init__(self): super().__init__() -class DeepMergeException(Exception): - """ - Exception to signal that there's a conflict when trying - to merge two dicts together - """ - - def __init__(self, message): - super().__init__(message) - - class NoWorkersException(Exception): """ Exception to signal that no workers were started diff --git a/merlin/study/status.py b/merlin/study/status.py index d11a403e2..507e810c2 100644 --- a/merlin/study/status.py +++ b/merlin/study/status.py @@ -37,10 +37,11 @@ from datetime import datetime from glob import glob from traceback import print_exception -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from filelock import FileLock, Timeout +from 
maestrowf.utils import get_duration from tabulate import tabulate from merlin.common.dumper import dump_handler @@ -363,7 +364,7 @@ def get_step_statuses(self, step_workspace: str, started_step_name: str) -> Dict statuses_read = read_status(status_filepath, f"{root}/status.lock") # Merge the statuses we read with the dict tracking all statuses for this step - dict_deep_merge(step_statuses, statuses_read) + dict_deep_merge(step_statuses, statuses_read, conflict_handler=status_conflict_handler) # Add full step name to the tracker and count number of statuses we just read in for full_step_name, status_info in statuses_read.items(): @@ -391,7 +392,7 @@ def load_requested_statuses(self): for sstep in self.step_tracker["started_steps"]: step_workspace = f"{self.workspace}/{sstep}" step_statuses = self.get_step_statuses(step_workspace, sstep) - dict_deep_merge(self.requested_statuses, step_statuses) + dict_deep_merge(self.requested_statuses, step_statuses, conflict_handler=status_conflict_handler) # Calculate run time average and standard deviation for this step self.get_runtime_avg_std_dev(step_statuses, sstep) @@ -531,8 +532,9 @@ def format_status_for_csv(self) -> Dict: # Loop through information for each step for step_info_key, step_info_value in overall_step_info.items(): - # Skip the workers entry at the top level; this will be added in the else statement below on a task-by-task basis - if step_info_key == "workers": + # Skip the workers entry at the top level; + # this will be added in the else statement below on a task-by-task basis + if step_info_key in ("workers", "worker_name"): continue # Format task queue entry if step_info_key == "task_queue": @@ -833,13 +835,15 @@ def apply_filters(self): filtered_statuses = {} for step_name, overall_step_info in self.requested_statuses.items(): filtered_statuses[step_name] = {} - # Add the non-workspace keys to the filtered_status dict so we don't accidentally miss any of this information while filtering + # Add the non-workspace keys to the filtered_status dict so we + # don't accidentally miss any of this information while filtering for non_ws_key in NON_WORKSPACE_KEYS: try: filtered_statuses[step_name][non_ws_key] = overall_step_info[non_ws_key] except KeyError: LOG.debug( - f"Tried to add {non_ws_key} to filtered_statuses dict but it was not found in requested_statuses[{step_name}]" + f"Tried to add {non_ws_key} to filtered_statuses dict " + f"but it was not found in requested_statuses[{step_name}]" ) # Go through the actual statuses and filter them as necessary @@ -916,7 +920,7 @@ def apply_max_tasks_limit(self): self.args.max_tasks -= len(sub_step_workspaces) # Merge in the task statuses that we're allowing - dict_deep_merge(new_status_dict[step_name], overall_step_info) + dict_deep_merge(new_status_dict[step_name], overall_step_info, conflict_handler=status_conflict_handler) LOG.info(f"Limited the number of tasks to display to {max_tasks} tasks.") @@ -1099,6 +1103,95 @@ def display(self, test_mode: Optional[bool] = False): LOG.warning("No statuses to display.") +# Pylint complains that args is unused but we can ignore that +def status_conflict_handler(*args, **kwargs) -> Any: # pylint: disable=W0613 + """ + The conflict handler function to apply to any status entries that have conflicting + values while merging two status files together. 
+ + kwargs should include: + - dict_a_val: The conflicting value from the dictionary that we're merging into + - dict_b_val: The conflicting value from the dictionary that we're pulling from + - key: The key into each dictionary that has a conflict + - path: The path down the dictionary tree that `dict_deep_merge` is currently at + + When we're reading in status files, we're merging all of the statuses into one dictionary. + This function defines the merge rules in case there is a merge conflict. We ignore the list + and dictionary entries since `dict_deep_merge` from `utils.py` handles these scenarios already. + + There are currently 4 rules: + - string-concatenate: take the two conflicting values and concatenate them in a string + - use-initial-and-log-warning: use the value from dict_a and log a warning message + - use-longest-time: use the longest time between the two conflicting values + - use-max: use the larger integer between the two conflicting values + + :returns: The value to merge into dict_a at `key` + """ + # Grab the arguments passed into this function + dict_a_val = kwargs.get("dict_a_val", None) + dict_b_val = kwargs.get("dict_b_val", None) + key = kwargs.get("key", None) + path = kwargs.get("path", None) + + merge_rules = { + "task_queue": "string-concatenate", + "worker_name": "string-concatenate", + "status": "use-initial-and-log-warning", + "return_code": "use-initial-and-log-warning", + "elapsed_time": "use-longest-time", + "run_time": "use-longest-time", + "restarts": "use-max", + } + + # TODO + # - make status tracking more modular (see https://lc.llnl.gov/gitlab/weave/merlin/-/issues/58) + # - once it's more modular, move the below code and the above merge_rules dict to a property in + # one of the new status classes (the one that has condensing maybe? or upstream from that?) + + # params = self.spec.get_parameters() + # for token in params.parameters: + # merge_rules[token] = "use-initial-and-log-warning" + + # Set parameter token key rules (commented for loop would be better but it's + # only possible if this conflict handler is contained within Status object; however, + # since this function needs to be imported outside of this file we can't do that) + if path is not None and "parameters" in path: + merge_rules[key] = "use-initial-and-log-warning" + + try: + merge_rule = merge_rules[key] + except KeyError: + LOG.warning(f"The key '{key}' does not have a merge rule defined. Setting this merge to None.") + return None + + merge_val = None + + if merge_rule == "string-concatenate": + merge_val = f"{dict_a_val}, {dict_b_val}" + elif merge_rule == "use-initial-and-log-warning": + LOG.warning( + f"Conflict at key '{key}' while merging status files. Defaulting to initial value. " + "This could lead to incorrect status information, you may want to re-run in debug mode and " + "check the files in the output directory for this task." 
+ ) + merge_val = dict_a_val + elif merge_rule == "use-longest-time": + if dict_a_val == "--:--:--": + merge_val = dict_b_val + elif dict_b_val == "--:--:--": + merge_val = dict_a_val + else: + dict_a_time = convert_to_timedelta(dict_a_val) + dict_b_time = convert_to_timedelta(dict_b_val) + merge_val = get_duration(max(dict_a_time, dict_b_time)) + elif merge_rule == "use-max": + merge_val = max(dict_a_val, dict_b_val) + else: + LOG.warning(f"The merge_rule '{merge_rule}' was provided but it has no implementation.") + + return merge_val + + def read_status( status_filepath: str, lock_file: str, display_fnf_message: bool = True, raise_errors: bool = False, timeout: int = 10 ) -> Dict: @@ -1112,6 +1205,8 @@ def read_status( :param timeout: An integer representing how long to hold a lock for before timing out. :returns: A dict of the contents in the status file """ + statuses_read = {} + # Pylint complains that we're instantiating an abstract class but this is correct usage lock = FileLock(lock_file) # pylint: disable=abstract-class-instantiated try: @@ -1122,25 +1217,24 @@ def read_status( # Handle timeouts except Timeout as to_exc: LOG.warning(f"Timed out when trying to read status from '{status_filepath}'") - statuses_read = {} if raise_errors: - raise Timeout from to_exc + raise to_exc # Handle FNF errors except FileNotFoundError as fnf_exc: if display_fnf_message: LOG.warning(f"Could not find '{status_filepath}'") - statuses_read = {} if raise_errors: - raise FileNotFoundError from fnf_exc + raise fnf_exc # Handle JSONDecode errors (this is likely due to an empty status file) except json.decoder.JSONDecodeError as json_exc: LOG.warning(f"JSONDecodeError raised when trying to read status from '{status_filepath}'") if raise_errors: - raise json.decoder.JSONDecodeError from json_exc + raise json_exc # Catch all exceptions so that we don't crash the workers except Exception as exc: # pylint: disable=broad-except LOG.warning( - f"An exception was raised while trying to read status from '{status_filepath}'!\n{print_exception(type(exc), exc, exc.__traceback__)}" + f"An exception was raised while trying to read status from '{status_filepath}'!\n" + f"{print_exception(type(exc), exc, exc.__traceback__)}" ) if raise_errors: raise exc @@ -1167,5 +1261,6 @@ def write_status(status_to_write: Dict, status_filepath: str, lock_file: str, ti # Catch all exceptions so that we don't crash the workers except Exception as exc: # pylint: disable=broad-except LOG.warning( - f"An exception was raised while trying to write status to '{status_filepath}'!\n{print_exception(type(exc), exc, exc.__traceback__)}" + f"An exception was raised while trying to write status to '{status_filepath}'!\n" + f"{print_exception(type(exc), exc, exc.__traceback__)}" ) diff --git a/merlin/utils.py b/merlin/utils.py index 070638c38..91b631263 100644 --- a/merlin/utils.py +++ b/merlin/utils.py @@ -41,14 +41,12 @@ from copy import deepcopy from datetime import datetime, timedelta from types import SimpleNamespace -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np import psutil import yaml -from merlin.exceptions import DeepMergeException - try: import cPickle as pickle @@ -559,33 +557,51 @@ def needs_merlin_expansion( return False -def dict_deep_merge(dict_a, dict_b, path=None): +def dict_deep_merge(dict_a: dict, dict_b: dict, path: str = None, conflict_handler: Callable = None): """ This function recursively merges dict_b into dict_a. 
The built-in merge of dictionaries in python (dict(dict_a) | dict(dict_b)) does not do a deep merge so this function is necessary. This will only merge in new keys, - it will NOT update existing ones. + it will NOT update existing ones, unless you specify a conflict handler function. Credit to this stack overflow post: https://stackoverflow.com/a/7205107. :param `dict_a`: A dict that we'll merge dict_b into :param `dict_b`: A dict that we want to merge into dict_a :param `path`: The path down the dictionary tree that we're currently at + :param `conflict_handler`: An optional function to handle conflicts between values at the same key. + The function should return the value to be used in the merged dictionary. + The default behavior without this argument is to log a warning. """ + + # Check to make sure we have valid dict_a and dict_b input + msgs = [ + f"{name} '{actual_dict}' is not a dict" + for name, actual_dict in [("dict_a", dict_a), ("dict_b", dict_b)] + if not isinstance(actual_dict, dict) + ] + if len(msgs) > 0: + LOG.warning(f"Problem with dict_deep_merge: {', '.join(msgs)}. Ignoring this merge call.") + return + if path is None: path = [] for key in dict_b: if key in dict_a: if isinstance(dict_a[key], dict) and isinstance(dict_b[key], dict): - dict_deep_merge(dict_a[key], dict_b[key], path + [str(key)]) - elif key == "workers": # specifically for status merging - all_workers = [dict_a[key], dict_b[key]] - dict_a[key] = list(set().union(*all_workers)) + dict_deep_merge(dict_a[key], dict_b[key], path=path + [str(key)], conflict_handler=conflict_handler) elif isinstance(dict_a[key], list) and isinstance(dict_a[key], list): dict_a[key] += dict_b[key] elif dict_a[key] == dict_b[key]: pass # same leaf value else: - raise DeepMergeException(f"Conflict at {'.'.join(path + [str(key)])}") + if conflict_handler is not None: + merged_val = conflict_handler( + dict_a_val=dict_a[key], dict_b_val=dict_b[key], key=key, path=path + [str(key)] + ) + dict_a[key] = merged_val + else: + # Want to just output a warning instead of raising an exception so that the workflow doesn't crash + LOG.warning(f"Conflict at {'.'.join(path + [str(key)])}. Ignoring the update to key '{key}'.") else: dict_a[key] = dict_b[key] @@ -619,6 +635,11 @@ def convert_to_timedelta(timestr: Union[str, int]) -> timedelta: """ # make sure it's a string in case we get an int timestr = str(timestr) + + # remove time unit characters (if any exist) + time_unit_chars = r"[dhms]" + timestr = re.sub(time_unit_chars, "", timestr) + nfields = len(timestr.split(":")) if nfields > 4: raise ValueError(f"Cannot convert {timestr} to a timedelta. Valid format: days:hours:minutes:seconds.") diff --git a/requirements/dev.txt b/requirements/dev.txt index 6e8722b4b..3695c6164 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -12,3 +12,4 @@ alabaster johnnydep deepdiff pytest-order +pytest-mock diff --git a/tests/README.md b/tests/README.md index a6bf7005a..22efc5470 100644 --- a/tests/README.md +++ b/tests/README.md @@ -58,17 +58,47 @@ not connected> quit ## The Fixture Process Explained -Pytest fixtures play a fundamental role in establishing a consistent foundation for test execution, -thus ensuring reliable and predictable test outcomes. This section will delve into essential aspects -of these fixtures, including how to integrate fixtures into tests, the utilization of fixtures within other fixtures, -their scope, and the yielding of fixture results. 
+In the world of pytest testing, fixtures are like the building blocks that create a sturdy foundation for your tests. +They ensure that every test starts from the same fresh ground, leading to reliable and consistent results. This section +will dive into the nitty-gritty of these fixtures, showing you how they're architected in this test suite, how to use +them in your tests here, how to combine them for more complex scenarios, how long they stick around during testing, and +what it means to yield a fixture. + +### Fixture Architecture + +Fixtures can be defined in two locations: + +1. `tests/conftest.py`: This file located at the root of the test suite houses common fixtures that are utilized +across various test modules +2. `tests/fixtures/`: This directory contains specific test module fixtures. Each fixture file is named according +to the module(s) that the fixtures defined within are for. + +Credit for this setup must be given to [this Medium article](https://medium.com/@nicolaikozel/modularizing-pytest-fixtures-fd40315c5a93). + +#### Fixture Naming Conventions + +For fixtures defined within the `tests/fixtures/` directory, the fixture name should be prefixed by the name of the +fixture file they are defined in. + +#### Importing Fixtures as Plugins + +Fixtures located in the `tests/fixtures/` directory are technically plugins. Therefore, to use them we must +register them as plugins within the `conftest.py` file (see the top of said file for the implementation). +This allows them to be discovered and used by test modules throughout the suite. + +**You do not have to register the fixtures you define as plugins in `conftest.py` since the registration there +uses `glob` to grab everything from the `tests/fixtures/` directory automatically.** ### How to Integrate Fixtures Into Tests Probably the most important part of fixtures is understanding how to use them. Luckily, this process is very -simple and can be dumbed down to 2 steps: +simple and can be dumbed down to just a couple steps: + +0. **[Module-specific fixtures only]** If you're creating a module-specific fixture (i.e. a fixture that won't be used throughout the entire test +suite), then create a file in the `tests/fixtures/` directory. -1. Create a fixture in the `conftest.py` file by using the `@pytest.fixture` decorator. For example: +1. Create a fixture in either the `conftest.py` file or the file you created in the `tests/fixtures/` directory +by using the `@pytest.fixture` decorator. For example: ``` @pytest.fixture @@ -131,10 +161,10 @@ scopes come to save the day. ### Fixture Scopes -There are several different scopes that you can set for fixtures. The majority of our fixtures use a `session` -scope so that we only have to create the fixtures one time (as some of them can take a few seconds to set up). -The goal is to create fixtures with the most general use-case in mind so that we can re-use them for larger -scopes, which helps with efficiency. +There are several different scopes that you can set for fixtures. The majority of our fixtures in `conftest.py` +use a `session` scope so that we only have to create the fixtures one time (as some of them can take a few seconds +to set up). The goal is to create fixtures with the most general use-case in mind so that we can re-use them for +larger scopes, which helps with efficiency. For more info on scopes, see [Pytest's Fixture Scope documentation](https://docs.pytest.org/en/6.2.x/fixture.html#scope-sharing-fixtures-across-classes-modules-packages-or-session). 
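To make the naming convention and plugin registration described above concrete, here is a small hedged sketch. The `example` module, the `example_temp_config` fixture, and the `tests/unit/test_example.py` file are hypothetical and not part of this PR; the sketch only reuses the real `temp_output_dir` fixture from `conftest.py` the same way `tests/fixtures/status.py` does. A fixture defined in `tests/fixtures/example.py` is auto-registered by the `glob` call in `conftest.py`, so any test can request it by name:

```
# tests/fixtures/example.py (hypothetical)
import pytest


@pytest.fixture(scope="class")
def example_temp_config(temp_output_dir: str) -> str:
    """Write a throwaway config file for the example tests."""
    config_path = f"{temp_output_dir}/example_config.yaml"
    with open(config_path, "w") as config_file:
        config_file.write("test: true\n")
    return config_path


# tests/unit/test_example.py (hypothetical)
import os


class TestExample:
    """The class scope above means the fixture is set up once for this whole class."""

    def test_config_exists(self, example_temp_config: str):
        # No import or manual registration is needed here; conftest.py already
        # registered tests/fixtures/example.py as a plugin via its glob call.
        assert os.path.exists(example_temp_config)
```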
diff --git a/tests/conftest.py b/tests/conftest.py index e180ad910..e795e1836 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,6 +33,7 @@ """ import os import subprocess +from glob import glob from time import sleep from typing import Dict @@ -45,6 +46,14 @@ from tests.celery_test_workers import CeleryTestWorkersManager +####################################### +# Loading in Module Specific Fixtures # +####################################### +pytest_plugins = [ + fixture_file.replace("/", ".").replace(".py", "") for fixture_file in glob("tests/fixtures/[!__]*.py", recursive=True) +] + + class RedisServerError(Exception): """ Exception to signal that the server wasn't pinged properly. diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 000000000..ab3e56590 --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1,16 @@ +""" +This directory is for help modularizing fixture definitions so that we don't have to +store every single fixture in the `conftest.py` file. + +Fixtures must start with the same name as the file they're defined in. For instance, +if our fixture file was named `example.py` then our fixtures in this file would have +to start with "example_": + +```title="example.py" +import pytest + +@pytest.fixture +def example_test_data(): + return {"key": "val"} +``` +""" diff --git a/tests/fixtures/status.py b/tests/fixtures/status.py new file mode 100644 index 000000000..ab0de5d1e --- /dev/null +++ b/tests/fixtures/status.py @@ -0,0 +1,37 @@ +""" +Fixtures specifically for help testing the functionality related to +status/detailed-status. +""" + +import os +from pathlib import Path + +import pytest + + +@pytest.fixture(scope="class") +def status_testing_dir(temp_output_dir: str) -> str: + """ + A pytest fixture to set up a temporary directory to write files to for testing status. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + testing_dir = f"{temp_output_dir}/status_testing/" + if not os.path.exists(testing_dir): + os.mkdir(testing_dir) + + return testing_dir + + +@pytest.fixture(scope="class") +def status_empty_file(status_testing_dir: str) -> str: # pylint: disable=W0621 + """ + A pytest fixture to create an empty status file. + + :param status_testing_dir: A pytest fixture that defines a path to the the output directory we'll write to + """ + empty_file = Path(f"{status_testing_dir}/empty_status.json") + if not empty_file.exists(): + empty_file.touch() + + return empty_file diff --git a/tests/unit/study/test_detailed_status.py b/tests/unit/study/test_detailed_status.py index 8c7f0f600..4db32a7af 100644 --- a/tests/unit/study/test_detailed_status.py +++ b/tests/unit/study/test_detailed_status.py @@ -212,16 +212,31 @@ def test_json_dump_with_filters(self): dump functionalities. The file needs to exist already for an append so it's better to keep these tests together. 
""" - # Set filters for failed and cancelled tasks, and then reload the requested_statuses - self.detailed_status_obj.args.task_status = ["FAILED", "CANCELLED"] - self.detailed_status_obj.load_requested_statuses() - - # Set the dump file - json_dump_file = f"{status_test_variables.PATH_TO_TEST_FILES}/detailed_dump_test.json" - self.detailed_status_obj.args.dump = json_dump_file + # Need to create a new DetailedStatus object so that filters are loaded from the beginning + args = Namespace( + subparsers="detailed-status", + level="INFO", + detailed=True, + output_path=None, + task_server="celery", + dump=f"{status_test_variables.PATH_TO_TEST_FILES}/detailed_dump_test.json", # Set the dump file + no_prompts=True, + max_tasks=None, + return_code=None, + steps=["all"], + task_queues=None, + task_status=["FAILED", "CANCELLED"], # Set filters for failed and cancelled tasks + workers=None, + disable_pager=True, + disable_theme=False, + layout="default", + ) + detailed_status_obj = DetailedStatus( + args=args, spec_display=False, file_or_ws=status_test_variables.VALID_WORKSPACE_PATH + ) # Run the json dump test (we should only get failed and cancelled statuses) - shared_tests.run_json_dump_test(self.detailed_status_obj, status_test_variables.REQUESTED_STATUSES_FAIL_AND_CANCEL) + shared_tests.run_json_dump_test(detailed_status_obj, status_test_variables.REQUESTED_STATUSES_FAIL_AND_CANCEL) def test_csv_dump_with_filters(self): """ @@ -229,17 +244,32 @@ def test_csv_dump_with_filters(self): dump functionalities. The file needs to exist already for an append so it's better to keep these tests together. """ - # Set filters for failed and cancelled tasks, and then reload the requested_statuses - self.detailed_status_obj.args.task_status = ["FAILED", "CANCELLED"] - self.detailed_status_obj.load_requested_statuses() - - # Set the dump file - csv_dump_file = f"{status_test_variables.PATH_TO_TEST_FILES}/detailed_dump_test.csv" - self.detailed_status_obj.args.dump = csv_dump_file + # Need to create a new DetailedStatus object so that filters are loaded from the beginning + args = Namespace( + subparsers="detailed-status", + level="INFO", + detailed=True, + output_path=None, + task_server="celery", + dump=f"{status_test_variables.PATH_TO_TEST_FILES}/detailed_dump_test.csv", # Set the dump file + no_prompts=True, + max_tasks=None, + return_code=None, + steps=["all"], + task_queues=None, + task_status=["FAILED", "CANCELLED"], # Set filters for failed and cancelled tasks + workers=None, + disable_pager=True, + disable_theme=False, + layout="default", + ) + detailed_status_obj = DetailedStatus( + args=args, spec_display=False, file_or_ws=status_test_variables.VALID_WORKSPACE_PATH + ) # Run the csv dump test (we should only get failed and cancelled statuses) expected_output = shared_tests.build_row_list(status_test_variables.FORMATTED_STATUSES_FAIL_AND_CANCEL) - shared_tests.run_csv_dump_test(self.detailed_status_obj, expected_output) + shared_tests.run_csv_dump_test(detailed_status_obj, expected_output) class TestPromptFunctionality(TestBaseDetailedStatus): diff --git a/tests/unit/study/test_status.py b/tests/unit/study/test_status.py index 695af17f3..ef34eec02 100644 --- a/tests/unit/study/test_status.py +++ b/tests/unit/study/test_status.py @@ -30,19 +30,469 @@ """ Tests for the Status class in the status.py module """ +import json +import os import unittest from argparse import Namespace from copy import deepcopy from datetime import datetime +from json.decoder import JSONDecodeError +import pytest import yaml 
from deepdiff import DeepDiff +from filelock import Timeout from merlin.spec.expansion import get_spec_with_expansion -from merlin.study.status import Status +from merlin.study.status import Status, read_status, status_conflict_handler, write_status from tests.unit.study.status_test_files import shared_tests, status_test_variables +class TestStatusReading: + """Test the logic for reading in status files""" + + cancel_step_dir = f"{status_test_variables.VALID_WORKSPACE_PATH}/cancel_step" + status_file = f"{cancel_step_dir}/MERLIN_STATUS.json" + lock_file = f"{cancel_step_dir}/status.lock" + + def test_basic_read(self): + """ + Test the basic reading functionality of `read_status`. There should + be no errors thrown and the correct status dict should be returned. + """ + actual_statuses = read_status(self.status_file, self.lock_file) + read_status_diff = DeepDiff( + actual_statuses, status_test_variables.REQUESTED_STATUSES_JUST_CANCELLED_STEP, ignore_order=True + ) + assert read_status_diff == {} + + def test_timeout_raise_errors_disabled(self, mocker: "Fixture", caplog: "Fixture"): # noqa: F821 + """ + Test the timeout functionality of the `read_status` function with + `raise_errors` set to False. This should log a warning message and + return an empty dict. + This test will create a mock of the FileLock object in order to + force a timeout to be raised. + + :param mocker: A built-in fixture from the pytest-mock library to create a Mock object + :param caplog: A built-in fixture from the pytest library to capture logs + """ + + # Set the mock to raise a timeout + mock_filelock = mocker.patch("merlin.study.status.FileLock") + mock_lock = mocker.MagicMock() + mock_lock.acquire.side_effect = Timeout(self.lock_file) + mock_filelock.return_value = mock_lock + + # Check that the return is as we expect + actual_status = read_status(self.status_file, self.lock_file) + assert actual_status == {} + + # Check that a warning is logged + expected_log = f"Timed out when trying to read status from '{self.status_file}'" + assert expected_log in caplog.text, "Missing expected log message" + + def test_timeout_raise_errors_enabled(self, mocker: "Fixture", caplog: "Fixture"): # noqa: F821 + """ + Test the timeout functionality of the `read_status` function with + `raise_errors` set to True. This should log a warning message and + raise a Timeout exception. + This test will create a mock of the FileLock object in order to + force a timeout to be raised. + + :param mocker: A built-in fixture from the pytest-mock library to create a Mock object + :param caplog: A built-in fixture from the pytest library to capture logs + """ + + # Set the mock to raise a timeout + mock_filelock = mocker.patch("merlin.study.status.FileLock") + mock_lock = mocker.MagicMock() + mock_lock.acquire.side_effect = Timeout(self.lock_file) + mock_filelock.return_value = mock_lock + + # Check that a Timeout exception is raised + with pytest.raises(Timeout): + read_status(self.status_file, self.lock_file, raise_errors=True) + + # Check that a warning is logged + expected_log = f"Timed out when trying to read status from '{self.status_file}'" + assert expected_log in caplog.text, "Missing expected log message" + + def test_file_not_found_no_fnf_no_errors(self, caplog: "Fixture"): # noqa: F821 + """ + Test the file not found functionality with the `display_fnf_message` + and `raise_errors` options both set to False. This should just return + an empty dict and not log anything. 
+ + :param caplog: A built-in fixture from the pytest library to capture logs + """ + dummy_file = "i_dont_exist.json" + actual_status = read_status(dummy_file, self.lock_file, display_fnf_message=False, raise_errors=False) + assert actual_status == {} + assert caplog.text == "" + + def test_file_not_found_with_fnf_no_errors(self, caplog: "Fixture"): # noqa: F821 + """ + Test the file not found functionality with the `display_fnf_message` + set to True and the `raise_errors` option set to False. This should + return an empty dict and log a warning. + + :param caplog: A built-in fixture from the pytest library to capture logs + """ + dummy_file = "i_dont_exist.json" + actual_status = read_status(dummy_file, self.lock_file, display_fnf_message=True, raise_errors=False) + assert actual_status == {} + assert f"Could not find '{dummy_file}'" in caplog.text + + def test_file_not_found_no_fnf_with_errors(self, caplog: "Fixture"): # noqa: F821 + """ + Test the file not found functionality with the `display_fnf_message` + set to False and the `raise_errors` option set to True. This should + raise a FileNotFound error and not log anything. + + :param caplog: A built-in fixture from the pytest library to capture logs + """ + dummy_file = "i_dont_exist.json" + with pytest.raises(FileNotFoundError): + read_status(dummy_file, self.lock_file, display_fnf_message=False, raise_errors=True) + assert caplog.text == "" + + def test_file_not_found_with_fnf_and_errors(self, caplog: "Fixture"): # noqa: F821 + """ + Test the file not found functionality with the `display_fnf_message` + and `raise_errors` options both set to True. This should raise a FileNotFound + error and log a warning. + + :param caplog: A built-in fixture from the pytest library to capture logs + """ + dummy_file = "i_dont_exist.json" + with pytest.raises(FileNotFoundError): + read_status(dummy_file, self.lock_file, display_fnf_message=True, raise_errors=True) + assert f"Could not find '{dummy_file}'" in caplog.text + + def test_json_decode_raise_errors_disabled(self, caplog: "Fixture", status_empty_file: str): # noqa: F821 + """ + Test the json decode error functionality with `raise_errors` disabled. + This should log a warning and return an empty dict. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param status_empty_file: A pytest fixture to give us an empty status file + """ + actual_status = read_status(status_empty_file, self.lock_file, raise_errors=False) + assert actual_status == {} + assert f"JSONDecodeError raised when trying to read status from '{status_empty_file}'" in caplog.text + + def test_json_decode_raise_errors_enabled(self, caplog: "Fixture", status_empty_file: str): # noqa: F821 + """ + Test the json decode error functionality with `raise_errors` enabled. + This should log a warning and raise a JSONDecodeError. 
+ + :param caplog: A built-in fixture from the pytest library to capture logs + :param status_empty_file: A pytest fixture to give us an empty status file + """ + with pytest.raises(JSONDecodeError): + read_status(status_empty_file, self.lock_file, raise_errors=True) + assert f"JSONDecodeError raised when trying to read status from '{status_empty_file}'" in caplog.text + + @pytest.mark.parametrize("exception", [TypeError, ValueError, NotImplementedError, IOError, UnicodeError, OSError]) + def test_broad_exception_handler_raise_errors_disabled( + self, mocker: "Fixture", caplog: "Fixture", exception: Exception # noqa: F821 + ): + """ + Test the broad exception handler with `raise_errors` disabled. This should + log a warning and return an empty dict. + + :param mocker: A built-in fixture from the pytest-mock library to create a Mock object + :param caplog: A built-in fixture from the pytest library to capture logs + :param exception: An exception to force `read_status` to raise. + Values for this are obtained from parametrized list above. + """ + + # Set the mock to raise an exception + mock_filelock = mocker.patch("merlin.study.status.FileLock") + mock_lock = mocker.MagicMock() + mock_lock.acquire.side_effect = exception() + mock_filelock.return_value = mock_lock + + actual_status = read_status(self.status_file, self.lock_file, raise_errors=False) + assert actual_status == {} + assert f"An exception was raised while trying to read status from '{self.status_file}'!" in caplog.text + + @pytest.mark.parametrize("exception", [TypeError, ValueError, NotImplementedError, IOError, UnicodeError, OSError]) + def test_broad_exception_handler_raise_errors_enabled( + self, mocker: "Fixture", caplog: "Fixture", exception: Exception # noqa: F821 + ): + """ + Test the broad exception handler with `raise_errors` enabled. This should + log a warning and raise whichever exception is passed in (see list of + parametrized exceptions in the decorator above). + + :param mocker: A built-in fixture from the pytest-mock library to create a Mock object + :param caplog: A built-in fixture from the pytest library to capture logs + :param exception: An exception to force `read_status` to raise. + Values for this are obtained from parametrized list above. + """ + + # Set the mock to raise an exception + mock_filelock = mocker.patch("merlin.study.status.FileLock") + mock_lock = mocker.MagicMock() + mock_lock.acquire.side_effect = exception() + mock_filelock.return_value = mock_lock + + with pytest.raises(exception): + read_status(self.status_file, self.lock_file, raise_errors=True) + assert f"An exception was raised while trying to read status from '{self.status_file}'!" in caplog.text + + +class TestStatusWriting: + """Test the logic for writing to status files""" + + status_to_write = {"status": "TESTING"} + + def test_basic_write(self, status_testing_dir: str): + """ + Test the basic functionality of the `write_status` function. This + should write status to a file. 
+ + :param status_testing_dir: A pytest fixture defined in `tests/fixtures/status.py` + that defines a path to the the output directory we'll write to + """ + + # Test variables + status_filepath = f"{status_testing_dir}/basic_write.json" + lock_file = f"{status_testing_dir}/basic_write.lock" + + # Run the test + write_status(self.status_to_write, status_filepath, lock_file) + + # Check that the path exists and that it contains the dummy status content + assert os.path.exists(status_filepath) + with open(status_filepath, "r") as sfp: + dummy_status = json.load(sfp) + assert dummy_status == self.status_to_write + + @pytest.mark.parametrize("exception", [TypeError, ValueError, NotImplementedError, IOError, UnicodeError, OSError]) + def test_exception_raised( + self, mocker: "Fixture", caplog: "Fixture", status_testing_dir: str, exception: Exception # noqa: F821 + ): + """ + Test the exception handler using several different exceptions defined in the + parametrized list in the decorator above. This should log a warning and not + create the status file that we provide. + + :param mocker: A built-in fixture from the pytest-mock library to create a Mock object + :param caplog: A built-in fixture from the pytest library to capture logs + :param status_testing_dir: A pytest fixture defined in `tests/fixtures/status.py` + that defines a path to the the output directory we'll write to + :param exception: An exception to force `read_status` to raise. + Values for this are obtained from parametrized list above. + """ + + # Set the mock to raise an exception + mock_filelock = mocker.patch("merlin.study.status.FileLock") + mock_lock = mocker.MagicMock() + mock_lock.acquire.side_effect = exception() + mock_filelock.return_value = mock_lock + + # Test variables + status_filepath = f"{status_testing_dir}/exception_{exception.__name__}.json" + lock_file = f"{status_testing_dir}/exception_{exception.__name__}.lock" + + write_status(self.status_to_write, status_filepath, lock_file) + assert f"An exception was raised while trying to write status to '{status_filepath}'!" in caplog.text + assert not os.path.exists(status_filepath) + + +class TestStatusConflictHandler: + """Test the functionality of the `status_conflict_handler` function.""" + + def test_parameter_conflict(self, caplog: "Fixture"): # noqa: F821 + """ + Test that conflicting parameters are handled properly. This is a special + case of the use-initial-and-log-warning rule since parameter tokens vary + and have to be added to the `merge_rules` dict on the fly. + + :param caplog: A built-in fixture from the pytest library to capture logs + """ + + # Create two dicts with conflicting parameter values + key = "TOKEN" + dict_a = {"parameters": {"cmd": {key: "value"}, "restart": None}} + dict_b = {"parameters": {"cmd": {key: "new_value"}, "restart": None}} + path = ["parameters", "cmd"] + + # Run the test + merged_val = status_conflict_handler( + dict_a_val=dict_a["parameters"]["cmd"][key], dict_b_val=dict_b["parameters"]["cmd"][key], key=key, path=path + ) + + # Check that everything ran properly + expected_log = ( + f"Conflict at key '{key}' while merging status files. Defaulting to initial value. " + "This could lead to incorrect status information, you may want to re-run in debug mode and " + "check the files in the output directory for this task." 
+ ) + assert merged_val == "value" + assert expected_log in caplog.text + + def test_non_existent_key(self, caplog: "Fixture"): # noqa: F821 + """ + Test providing `status_conflict_handler` a key that doesn't exist in + the `merge_rule` dict. This should log a warning and return None. + + :param caplog: A built-in fixture from the pytest library to capture logs + """ + key = "i_dont_exist" + merged_val = status_conflict_handler(key=key) + assert merged_val is None + assert f"The key '{key}' does not have a merge rule defined. Setting this merge to None." in caplog.text + + def test_rule_string_concatenate(self): + """ + Test the string-concatenate merge rule. This should combine + the strings provided in `dict_a_val` and `dict_b_val` into one + comma-delimited string. + """ + + # Create two dicts with conflicting task-queue values + key = "task_queue" + val1 = "existing_task_queue" + val2 = "new_task_queue" + dict_a = {key: val1} + dict_b = {key: val2} + + # Run the test and make sure the values are being concatenated + merged_val = status_conflict_handler( + dict_a_val=dict_a[key], + dict_b_val=dict_b[key], + key=key, + ) + assert merged_val == f"{val1}, {val2}" + + def test_rule_use_initial_and_log_warning(self, caplog: "Fixture"): # noqa: F821 + """ + Test the use-initial-and-log-warning merge rule. This should + return the value passed in to `dict_a_val` and log a warning + message. + + :param caplog: A built-in fixture from the pytest library to capture logs + """ + + # Create two dicts with conflicting status values + key = "status" + dict_a = {key: "SUCCESS"} + dict_b = {key: "FAILED"} + + # Run the test + merged_val = status_conflict_handler( + dict_a_val=dict_a[key], + dict_b_val=dict_b[key], + key=key, + ) + + # Check that everything ran properly + expected_log = ( + f"Conflict at key '{key}' while merging status files. Defaulting to initial value. " + "This could lead to incorrect status information, you may want to re-run in debug mode and " + "check the files in the output directory for this task." + ) + assert merged_val == "SUCCESS" + assert expected_log in caplog.text + + def test_rule_use_longest_time_no_dict_a_time(self): + """ + Test the use-longest-time merge rule with no time set for `dict_a_val`. + This should default to using the time in `dict_b_val`. + """ + key = "elapsed_time" + expected_time = "12h:34m:56s" + dict_a = {key: "--:--:--"} + dict_b = {key: expected_time} + + merged_val = status_conflict_handler( + dict_a_val=dict_a[key], + dict_b_val=dict_b[key], + key=key, + ) + assert merged_val == expected_time + + def test_rule_use_longest_time_no_dict_b_time(self): + """ + Test the use-longest-time merge rule with no time set for `dict_b_val`. + This should default to using the time in `dict_a_val`. + """ + key = "run_time" + expected_time = "12h:34m:56s" + dict_a = {key: expected_time} + dict_b = {key: "--:--:--"} + + merged_val = status_conflict_handler( + dict_a_val=dict_a[key], + dict_b_val=dict_b[key], + key=key, + ) + assert merged_val == expected_time + + def test_rule_use_longest_time(self): + """ + Test the use-longest-time merge rule with times set for both `dict_a_val` + and `dict_b_val`. This should use whichever time is longer. 
+ """ + + # Set up test variables + key = "run_time" + short_time = "01h:04m:33s" + long_time = "12h:34m:56s" + + # Run test with dict b having the longer time + dict_a = {key: short_time} + dict_b = {key: long_time} + merged_val = status_conflict_handler( + dict_a_val=dict_a[key], + dict_b_val=dict_b[key], + key=key, + ) + assert merged_val == "0d:" + long_time # Time manipulation in status_conflict_handler will prepend '0d:' + + # Run test with dict a having the longer time + dict_a_2 = {key: long_time} + dict_b_2 = {key: short_time} + merged_val_2 = status_conflict_handler( + dict_a_val=dict_a_2[key], + dict_b_val=dict_b_2[key], + key=key, + ) + assert merged_val_2 == "0d:" + long_time + + @pytest.mark.parametrize( + "dict_a_val, dict_b_val, expected", + [ + (0, 0, 0), + (0, 1, 1), + (1, 0, 1), + (-1, 0, 0), + (0, -1, 0), + (23, 20, 23), + (17, 21, 21), + ], + ) + def test_rule_use_max(self, dict_a_val: int, dict_b_val: int, expected: int): + """ + Test the use-max merge rule. This should take the maximum of 2 values. + + :param dict_a_val: The value to pass in for dict_a_val + :param dict_b_val: The value to pass in for dict_b_val + :param expected: The expected value from this test + """ + key = "restarts" + merged_val = status_conflict_handler( + dict_a_val=dict_a_val, + dict_b_val=dict_b_val, + key=key, + ) + assert merged_val == expected + + class TestMerlinStatus(unittest.TestCase): """Test the logic for methods in the Status class.""" diff --git a/tests/unit/utils/test_dict_deep_merge.py b/tests/unit/utils/test_dict_deep_merge.py new file mode 100644 index 000000000..133897f36 --- /dev/null +++ b/tests/unit/utils/test_dict_deep_merge.py @@ -0,0 +1,276 @@ +""" +Tests for the `dict_deep_merge` function defined in the `utils.py` module. +""" + +from typing import Any, Dict, List + +import pytest + +from merlin.utils import dict_deep_merge + + +def run_invalid_check(dict_a: Any, dict_b: Any, expected_log: str, caplog: "Fixture"): # noqa: F821 + """ + Helper function to run invalid input tests on the `dict_deep_merge` function. 
+ + :param dict_a: The value of dict_a that we're testing against + :param dict_b: The value of dict_b that we're testing against + :param expected_log: The log that we're expecting `dict_deep_merge` to write + :param caplog: A built-in fixture from the pytest library to capture logs + """ + + # Store initial value of dict_a + if isinstance(dict_a, list): + dict_a_initial = dict_a.copy() + else: + dict_a_initial = dict_a + + # Check that dict_deep_merge returns None and that dict_a wasn't modified + assert dict_deep_merge(dict_a, dict_b) is None + assert dict_a_initial == dict_a + + # Check that dict_deep_merge logs a warning + print(f"caplog.text: {caplog.text}") + assert expected_log in caplog.text, "Missing expected log message" + + +@pytest.mark.parametrize( + "dict_a, dict_b", + [ + (None, None), + (None, ["no lists allowed!"]), + (["no lists allowed!"], None), + (["no lists allowed!"], ["no lists allowed!"]), + ("no strings allowed!", None), + (None, "no strings allowed!"), + ("no strings allowed!", "no strings allowed!"), + (10, None), + (None, 10), + (10, 10), + (10.5, None), + (None, 10.5), + (10.5, 10.5), + (("no", "tuples"), None), + (None, ("no", "tuples")), + (("no", "tuples"), ("no", "tuples")), + (True, None), + (None, True), + (True, True), + ], +) +def test_dict_deep_merge_both_dicts_invalid(dict_a: Any, dict_b: Any, caplog: "Fixture"): # noqa: F821 + """ + Test the `dict_deep_merge` function with both `dict_a` and `dict_b` + parameters being an invalid type. This should log a message and do + nothing. + + :param dict_a: The value of dict_a that we're testing against + :param dict_b: The value of dict_b that we're testing against + :param caplog: A built-in fixture from the pytest library to capture logs + """ + + # The expected log that's output by dict_deep_merge + expected_log = f"Problem with dict_deep_merge: dict_a '{dict_a}' is not a dict, dict_b '{dict_b}' is not a dict. Ignoring this merge call." + + # Run the actual test + run_invalid_check(dict_a, dict_b, expected_log, caplog) + + +@pytest.mark.parametrize( + "dict_a, dict_b", + [ + (None, {"test_key": "test_val"}), + (["no lists allowed!"], {"test_key": "test_val"}), + ("no strings allowed!", {"test_key": "test_val"}), + (10, {"test_key": "test_val"}), + (10.5, {"test_key": "test_val"}), + (("no", "tuples"), {"test_key": "test_val"}), + (True, {"test_key": "test_val"}), + ], +) +def test_dict_deep_merge_dict_a_invalid(dict_a: Any, dict_b: Dict[str, str], caplog: "Fixture"): # noqa: F821 + """ + Test the `dict_deep_merge` function with the `dict_a` parameter + being an invalid type. This should log a message and do nothing. + + :param dict_a: The value of dict_a that we're testing against + :param dict_b: The value of dict_b that we're testing against + :param caplog: A built-in fixture from the pytest library to capture logs + """ + + # The expected log that's output by dict_deep_merge + expected_log = f"Problem with dict_deep_merge: dict_a '{dict_a}' is not a dict. Ignoring this merge call." 
+ + # Run the actual test + run_invalid_check(dict_a, dict_b, expected_log, caplog) + + +@pytest.mark.parametrize( + "dict_a, dict_b", + [ + ({"test_key": "test_val"}, None), + ({"test_key": "test_val"}, ["no lists allowed!"]), + ({"test_key": "test_val"}, "no strings allowed!"), + ({"test_key": "test_val"}, 10), + ({"test_key": "test_val"}, 10.5), + ({"test_key": "test_val"}, ("no", "tuples")), + ({"test_key": "test_val"}, True), + ], +) +def test_dict_deep_merge_dict_b_invalid(dict_a: Dict[str, str], dict_b: Any, caplog: "Fixture"): # noqa: F821 + """ + Test the `dict_deep_merge` function with the `dict_b` parameter + being an invalid type. This should log a message and do nothing. + + :param dict_a: The value of dict_a that we're testing against + :param dict_b: The value of dict_b that we're testing against + :param caplog: A built-in fixture from the pytest library to capture logs + """ + + # The expected log that's output by dict_deep_merge + expected_log = f"Problem with dict_deep_merge: dict_b '{dict_b}' is not a dict. Ignoring this merge call." + + # Run the actual test + run_invalid_check(dict_a, dict_b, expected_log, caplog) + + +@pytest.mark.parametrize( + "dict_a, dict_b, expected", + [ + ({"test_key": {}}, {"test_key": {}}, {}), # Testing merge of two empty dicts + ({"test_key": {}}, {"test_key": {"new_key": "new_val"}}, {"new_key": "new_val"}), # Testing dict_a empty dict merge + ( + {"test_key": {"existing_key": "existing_val"}}, + {"test_key": {}}, + {"existing_key": "existing_val"}, + ), # Testing dict_b empty dict merge + ( + {"test_key": {"existing_key": "existing_val"}}, + {"test_key": {"new_key": "new_val"}}, + {"existing_key": "existing_val", "new_key": "new_val"}, + ), # Testing merge of dicts with content + ], +) +def test_dict_deep_merge_dict_merge( + dict_a: Dict[str, Dict[Any, Any]], dict_b: Dict[str, Dict[Any, Any]], expected: Dict[Any, Any] +): + """ + Test the `dict_deep_merge` function with dicts that need to be merged. + NOTE we're keeping the test values of this function simple since the other tests + related to `dict_deep_merge` should be hitting the other possible scenarios. 
+ + :param dict_a: The value of dict_a that we're testing against + :param dict_b: The value of dict_b that we're testing against + :param expected: The dict that we're expecting to now be in dict_a at 'test_key' + """ + dict_deep_merge(dict_a, dict_b) + assert dict_a["test_key"] == expected + + +@pytest.mark.parametrize( + "dict_a, dict_b, expected", + [ + ({"test_key": []}, {"test_key": []}, []), # Testing merge of two empty lists + ({"test_key": []}, {"test_key": ["new_val"]}, ["new_val"]), # Testing dict_a empty list merge + ({"test_key": ["existing_val"]}, {"test_key": []}, ["existing_val"]), # Testing dict_b empty list merge + ( + {"test_key": ["existing_val"]}, + {"test_key": ["new_val"]}, + ["existing_val", "new_val"], + ), # Testing merge of list of strings + ({"test_key": [None]}, {"test_key": [None]}, [None, None]), # Testing merge of list of None + ({"test_key": [0]}, {"test_key": [1]}, [0, 1]), # Testing merge of list of integers + ({"test_key": [True]}, {"test_key": [False]}, [True, False]), # Testing merge of list of bools + ({"test_key": [0.0]}, {"test_key": [1.0]}, [0.0, 1.0]), # Testing merge of list of floats + ( + {"test_key": [(True, False)]}, + {"test_key": [(False, True)]}, + [(True, False), (False, True)], + ), # Testing merge of list of tuples + ( + {"test_key": [{"existing_key": "existing_val"}]}, + {"test_key": [{"new_key": "new_val"}]}, + [{"existing_key": "existing_val"}, {"new_key": "new_val"}], + ), # Testing merge of list of dicts + ( + {"test_key": ["existing_val", 0]}, + {"test_key": [True, 1.0, None]}, + ["existing_val", 0, True, 1.0, None], + ), # Testing merge of list of multiple types + ], +) +def test_dict_deep_merge_list_merge(dict_a: Dict[str, List[Any]], dict_b: Dict[str, List[Any]], expected: List[Any]): + """ + Test the `dict_deep_merge` function with lists that need to be merged. + + :param dict_a: The value of dict_a that we're testing against + :param dict_b: The value of dict_b that we're testing against + :param expected: The list that we're expecting to now be in dict_a at 'test_key' + """ + dict_deep_merge(dict_a, dict_b) + assert dict_a["test_key"] == expected + + +@pytest.mark.parametrize( + "dict_a, dict_b, expected", + [ + ({"test_key": None}, {"test_key": None}, None), # Testing merge of None + ({"test_key": "test_val"}, {"test_key": "test_val"}, "test_val"), # Testing merge of string + ({"test_key": 1}, {"test_key": 1}, 1), # Testing merge of int + ({"test_key": 1.0}, {"test_key": 1.0}, 1.0), # Testing merge of float + ({"test_key": False}, {"test_key": False}, False), # Testing merge of bool + ], +) +def test_dict_deep_merge_same_leaf(dict_a: Dict[str, Any], dict_b: Dict[str, Any], expected: Any): + """ + Test the `dict_deep_merge` function with equivalent values in dict_a and dict_b. + Nothing should happen here so dict_a["test_key"] should be the exact same. + + :param dict_a: The value of dict_a that we're testing against + :param dict_b: The value of dict_b that we're testing against + :param expected: The value that we're expecting to now be in dict_a at 'test_key' + """ + dict_deep_merge(dict_a, dict_b) + assert dict_a["test_key"] == expected + + +def test_dict_deep_merge_conflict_no_conflict_handler(caplog: "Fixture"): # noqa: F821 + """ + Test the `dict_deep_merge` function with a conflicting value in dict_b + and no conflict handler. Since there's no conflict handler this should + log a warning and ignore any merge for the key that has the conflict. 
+ + :param caplog: A built-in fixture from the pytest library to capture logs + """ + dict_a = {"test_key": "existing_value"} + dict_b = {"test_key": "new_value"} + + # Call deep merge and make sure "test_key" in dict_a wasn't updated + dict_deep_merge(dict_a, dict_b) + assert dict_a["test_key"] == "existing_value" + + # Check that dict_deep_merge logs a warning + assert "Conflict at test_key. Ignoring the update to key 'test_key'." in caplog.text, "Missing expected log message" + + +def test_dict_deep_merge_conflict_with_conflict_handler(): + """ + Test the `dict_deep_merge` function with a conflicting value in dict_b + and a conflict handler. Our conflict handler will just concatenate the + conflicting strings. + """ + dict_a = {"test_key": "existing_value"} + dict_b = {"test_key": "new_value"} + + def conflict_handler(*args, **kwargs): + """ + The conflict handler that we'll be passing in to `dict_deep_merge`. + This will concatenate the conflicting strings. + """ + dict_a_val = kwargs.get("dict_a_val", None) + dict_b_val = kwargs.get("dict_b_val", None) + return ", ".join([dict_a_val, dict_b_val]) + + # Call deep merge and make sure "test_key" in dict_a wasn't updated + dict_deep_merge(dict_a, dict_b, conflict_handler=conflict_handler) + assert dict_a["test_key"] == "existing_value, new_value"
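
As a closing illustration of how the pieces in this PR fit together, the sketch below merges two conflicting status fragments using the new `conflict_handler` hook with `status_conflict_handler`. It assumes a Merlin installation that includes the changes in this PR; the dictionary values (`queue_a`, `FINISHED`, the run times, etc.) are purely illustrative, and the expected results in the comments follow from the merge rules defined in `status.py`:

```
# A minimal usage sketch of the new conflict handler (values are illustrative).
from merlin.study.status import status_conflict_handler
from merlin.utils import dict_deep_merge

# Two status fragments for the same task that disagree on several fields
existing = {"task_queue": "queue_a", "status": "FINISHED", "restarts": 1, "run_time": "0d:00h:05m:10s"}
incoming = {"task_queue": "queue_b", "status": "FAILED", "restarts": 2, "run_time": "0d:00h:02m:30s"}

# Merge incoming into existing, resolving conflicts with the status merge rules
dict_deep_merge(existing, incoming, conflict_handler=status_conflict_handler)

# Expected result, based on the merge rules in status_conflict_handler:
#   task_queue -> "queue_a, queue_b"   (string-concatenate)
#   status     -> "FINISHED"           (use-initial-and-log-warning; a warning is logged)
#   restarts   -> 2                    (use-max)
#   run_time   -> "0d:00h:05m:10s"     (use-longest-time)
print(existing)
```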