Skip to content

Commit

Permalink
remove DeepMergeException and add conflict_handler to dict_deep_merge
Browse files Browse the repository at this point in the history
  • Loading branch information
bgunnar5 committed May 7, 2024
1 parent 3a3a2ac commit 9050ce2
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 22 deletions.
11 changes: 0 additions & 11 deletions merlin/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
"HardFailException",
"InvalidChainException",
"RestartException",
"DeepMergeException",
"NoWorkersException",
)

Expand Down Expand Up @@ -96,16 +95,6 @@ def __init__(self):
super().__init__()


class DeepMergeException(Exception):
    """
    Raised to report that two dictionaries could not be merged
    because they hold conflicting values.
    """

    def __init__(self, message):
        # Explicit two-argument form of super(); resolves to the same MRO entry
        super(DeepMergeException, self).__init__(message)


class NoWorkersException(Exception):
"""
Exception to signal that no workers were started
Expand Down
44 changes: 33 additions & 11 deletions merlin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,12 @@
from copy import deepcopy
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import numpy as np
import psutil
import yaml

from merlin.exceptions import DeepMergeException


try:
import cPickle as pickle
except ImportError:
Expand Down Expand Up @@ -559,33 +556,53 @@ def needs_merlin_expansion(
return False


def dict_deep_merge(
    dict_a: dict,
    dict_b: dict,
    path: Optional[List[str]] = None,
    conflict_handler: Optional[Callable] = None,
):
    """
    This function recursively merges dict_b into dict_a (dict_a is modified in place
    and nothing is returned). The built-in merge of dictionaries in python
    (dict(dict_a) | dict(dict_b)) does not do a deep merge so this function is necessary.
    This will only merge in new keys, it will NOT update existing ones, unless you
    specify a conflict handler function.
    Credit to this stack overflow post: https://stackoverflow.com/a/7205107.

    For a key that is present in both dictionaries:
    - if both values are dicts, they are merged recursively
    - if both values are lists, dict_b's list is appended onto dict_a's list
    - if the values are equal, nothing happens
    - anything else is a conflict (see `conflict_handler`)

    :param `dict_a`: A dict that we'll merge dict_b into
    :param `dict_b`: A dict that we want to merge into dict_a
    :param `path`: The path down the dictionary tree that we're currently at, as a list
                   of stringified keys. Top-level callers can leave this as None.
    :param `conflict_handler`: An optional function to handle conflicts between values at the same key.
                               It is called with the keyword arguments `dict_a_val`, `dict_b_val`,
                               `key`, and `path`. The function should return the value to be used
                               in the merged dictionary.
                               The default behavior without this argument is to log a warning.
    """

    # Check to make sure we have valid dict_a and dict_b input
    dict_a_is_dict = isinstance(dict_a, dict)
    dict_b_is_dict = isinstance(dict_b, dict)
    if not dict_a_is_dict or not dict_b_is_dict:
        if not dict_a_is_dict and not dict_b_is_dict:
            problem_str = f"both dict_a '{dict_a}' and dict_b '{dict_b}' are not dictionaries"
        elif not dict_a_is_dict:
            problem_str = f"dict_a '{dict_a}' is not a dictionary"
        else:
            problem_str = f"dict_b '{dict_b}' is not a dictionary"

        LOG.warning(f"Problem with dict_deep_merge: {problem_str}. Ignoring this merge call.")
        return

    if path is None:
        path = []

    for key in dict_b:
        # Brand new key: nothing to reconcile, just take dict_b's value
        if key not in dict_a:
            dict_a[key] = dict_b[key]
            continue

        val_a = dict_a[key]
        val_b = dict_b[key]
        key_path = path + [str(key)]

        if isinstance(val_a, dict) and isinstance(val_b, dict):
            dict_deep_merge(val_a, val_b, path=key_path, conflict_handler=conflict_handler)
        elif isinstance(val_a, list) and isinstance(val_b, list):
            # Both sides must be lists; a list paired with a non-list is a conflict,
            # not something to splice together element by element
            dict_a[key] += val_b
        elif val_a == val_b:
            pass  # same leaf value
        elif conflict_handler is not None:
            dict_a[key] = conflict_handler(dict_a_val=val_a, dict_b_val=val_b, key=key, path=key_path)
        else:
            # Want to just output a warning instead of raising an exception so that the workflow doesn't crash
            LOG.warning(f"Conflict at {'.'.join(key_path)}. Ignoring the update to key '{key}'.")

Expand Down Expand Up @@ -619,6 +636,11 @@ def convert_to_timedelta(timestr: Union[str, int]) -> timedelta:
"""
# make sure it's a string in case we get an int
timestr = str(timestr)

# remove time unit characters (if any exist)
time_unit_chars = r"[dhms]"
timestr = re.sub(time_unit_chars, "", timestr)

nfields = len(timestr.split(":"))
if nfields > 4:
raise ValueError(f"Cannot convert {timestr} to a timedelta. Valid format: days:hours:minutes:seconds.")
Expand Down
255 changes: 255 additions & 0 deletions tests/unit/utils/test_dict_deep_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
"""
Tests for the `dict_deep_merge` function defined in the `utils.py` module.
"""
import logging
import pytest
from io import StringIO
from typing import Any, Dict, List

from merlin.utils import dict_deep_merge

def run_invalid_check(dict_a: Any, dict_b: Any, expected_log: str):
    """
    Helper function to run invalid input tests on the `dict_deep_merge` function.
    It checks that the call returns None, leaves `dict_a` untouched, and writes
    the expected warning to the root logger.

    :param dict_a: The value of dict_a that we're testing against
    :param dict_b: The value of dict_b that we're testing against
    :param expected_log: The log that we're expecting `dict_deep_merge` to write
    """
    # Local import keeps this helper self-contained
    from copy import deepcopy

    # Create a capture stream to capture logs
    capture_stream = StringIO()
    handler = logging.StreamHandler(capture_stream)
    logger = logging.getLogger()
    logger.addHandler(handler)

    # Snapshot dict_a with a deep copy; a plain alias would make the
    # "wasn't modified" comparison below trivially true for mutable inputs
    dict_a_initial = deepcopy(dict_a)

    try:
        # Check that dict_deep_merge returns None and that dict_a wasn't modified
        assert dict_deep_merge(dict_a, dict_b) is None
        assert dict_a_initial == dict_a
    finally:
        # Always detach the handler so a failing assertion can't leak it onto the root logger
        logger.removeHandler(handler)

    # Check that dict_deep_merge logs a warning
    assert expected_log in capture_stream.getvalue(), "Missing expected log message"


@pytest.mark.parametrize(
    "dict_a, dict_b",
    [
        (None, None),
        (None, ["no lists allowed!"]),
        (["no lists allowed!"], None),
        (["no lists allowed!"], ["no lists allowed!"]),
        ("no strings allowed!", None),
        (None, "no strings allowed!"),
        ("no strings allowed!", "no strings allowed!"),
        (10, None),
        (None, 10),
        (10, 10),
        (10.5, None),
        (None, 10.5),
        (10.5, 10.5),
        (("no", "tuples"), None),
        (None, ("no", "tuples")),
        (("no", "tuples"), ("no", "tuples")),
        (True, None),
        (None, True),
        (True, True),
    ],
)
def test_dict_deep_merge_both_dicts_invalid(dict_a: Any, dict_b: Any):
    """
    Call `dict_deep_merge` with neither `dict_a` nor `dict_b` being a
    dictionary. The expectation is a logged message and no other effect.

    :param dict_a: The value of dict_a that we're testing against
    :param dict_b: The value of dict_b that we're testing against
    """
    # The message dict_deep_merge is expected to log for this pair of inputs
    expected_log = (
        f"Problem with dict_deep_merge: both dict_a '{dict_a}' "
        f"and dict_b '{dict_b}' are not dictionaries. Ignoring this merge call."
    )

    run_invalid_check(dict_a, dict_b, expected_log)


@pytest.mark.parametrize(
    "dict_a, dict_b",
    [
        (None, {"test_key": "test_val"}),
        (["no lists allowed!"], {"test_key": "test_val"}),
        ("no strings allowed!", {"test_key": "test_val"}),
        (10, {"test_key": "test_val"}),
        (10.5, {"test_key": "test_val"}),
        (("no", "tuples"), {"test_key": "test_val"}),
        (True, {"test_key": "test_val"}),
    ],
)
def test_dict_deep_merge_dict_a_invalid(dict_a: Any, dict_b: Dict[str, str]):
    """
    Call `dict_deep_merge` with a valid `dict_b` but a `dict_a` that is
    not a dictionary. The expectation is a logged message and no other effect.

    :param dict_a: The value of dict_a that we're testing against
    :param dict_b: The value of dict_b that we're testing against
    """
    # The message dict_deep_merge is expected to log for this dict_a
    warning_text = f"Problem with dict_deep_merge: dict_a '{dict_a}' is not a dictionary. Ignoring this merge call."

    run_invalid_check(dict_a, dict_b, warning_text)


@pytest.mark.parametrize(
    "dict_a, dict_b",
    [
        ({"test_key": "test_val"}, None),
        ({"test_key": "test_val"}, ["no lists allowed!"]),
        ({"test_key": "test_val"}, "no strings allowed!"),
        ({"test_key": "test_val"}, 10),
        ({"test_key": "test_val"}, 10.5),
        ({"test_key": "test_val"}, ("no", "tuples")),
        ({"test_key": "test_val"}, True),
    ],
)
def test_dict_deep_merge_dict_b_invalid(dict_a: Dict[str, str], dict_b: Any):
    """
    Call `dict_deep_merge` with a valid `dict_a` but a `dict_b` that is
    not a dictionary. The expectation is a logged message and no other effect.

    :param dict_a: The value of dict_a that we're testing against
    :param dict_b: The value of dict_b that we're testing against
    """
    # The message dict_deep_merge is expected to log for this dict_b
    warning_text = f"Problem with dict_deep_merge: dict_b '{dict_b}' is not a dictionary. Ignoring this merge call."

    run_invalid_check(dict_a, dict_b, warning_text)

@pytest.mark.parametrize(
    "dict_a, dict_b, expected",
    [
        # Testing merge of two empty dicts
        ({"test_key": {}}, {"test_key": {}}, {}),
        # Testing dict_a empty dict merge
        ({"test_key": {}}, {"test_key": {"new_key": "new_val"}}, {"new_key": "new_val"}),
        # Testing dict_b empty dict merge
        ({"test_key": {"existing_key": "existing_val"}}, {"test_key": {}}, {"existing_key": "existing_val"}),
        # Testing merge of dicts with content
        (
            {"test_key": {"existing_key": "existing_val"}},
            {"test_key": {"new_key": "new_val"}},
            {"existing_key": "existing_val", "new_key": "new_val"},
        ),
    ],
)
def test_dict_deep_merge_dict_merge(
    dict_a: Dict[str, Dict[Any, Any]],
    dict_b: Dict[str, Dict[Any, Any]],
    expected: Dict[Any, Any],
):
    """
    Check that `dict_deep_merge` merges nested dicts stored under the same key.
    NOTE the inputs here are deliberately simple; the other `dict_deep_merge`
    tests are meant to cover the remaining scenarios.

    :param dict_a: The value of dict_a that we're testing against
    :param dict_b: The value of dict_b that we're testing against
    :param expected: The dict that we're expecting to now be in dict_a at 'test_key'
    """
    dict_deep_merge(dict_a, dict_b)

    merged_value = dict_a["test_key"]
    assert merged_value == expected

@pytest.mark.parametrize(
    "dict_a, dict_b, expected",
    [
        # Testing merge of two empty lists
        ({"test_key": []}, {"test_key": []}, []),
        # Testing dict_a empty list merge
        ({"test_key": []}, {"test_key": ["new_val"]}, ["new_val"]),
        # Testing dict_b empty list merge
        ({"test_key": ["existing_val"]}, {"test_key": []}, ["existing_val"]),
        # Testing merge of list of strings
        ({"test_key": ["existing_val"]}, {"test_key": ["new_val"]}, ["existing_val", "new_val"]),
        # Testing merge of list of None
        ({"test_key": [None]}, {"test_key": [None]}, [None, None]),
        # Testing merge of list of integers
        ({"test_key": [0]}, {"test_key": [1]}, [0, 1]),
        # Testing merge of list of bools
        ({"test_key": [True]}, {"test_key": [False]}, [True, False]),
        # Testing merge of list of floats
        ({"test_key": [0.0]}, {"test_key": [1.0]}, [0.0, 1.0]),
        # Testing merge of list of tuples
        ({"test_key": [(True, False)]}, {"test_key": [(False, True)]}, [(True, False), (False, True)]),
        # Testing merge of list of dicts
        (
            {"test_key": [{"existing_key": "existing_val"}]},
            {"test_key": [{"new_key": "new_val"}]},
            [{"existing_key": "existing_val"}, {"new_key": "new_val"}],
        ),
        # Testing merge of list of multiple types
        (
            {"test_key": ["existing_val", 0]},
            {"test_key": [True, 1.0, None]},
            ["existing_val", 0, True, 1.0, None],
        ),
    ],
)
def test_dict_deep_merge_list_merge(dict_a: Dict[str, List[Any]], dict_b: Dict[str, List[Any]], expected: List[Any]):
    """
    Check that `dict_deep_merge` combines lists stored under the same key,
    with dict_b's entries following dict_a's.

    :param dict_a: The value of dict_a that we're testing against
    :param dict_b: The value of dict_b that we're testing against
    :param expected: The list that we're expecting to now be in dict_a at 'test_key'
    """
    dict_deep_merge(dict_a, dict_b)

    merged_list = dict_a["test_key"]
    assert merged_list == expected


@pytest.mark.parametrize(
    "dict_a, dict_b, expected",
    [
        # Testing merge of None
        ({"test_key": None}, {"test_key": None}, None),
        # Testing merge of string
        ({"test_key": "test_val"}, {"test_key": "test_val"}, "test_val"),
        # Testing merge of int
        ({"test_key": 1}, {"test_key": 1}, 1),
        # Testing merge of float
        ({"test_key": 1.0}, {"test_key": 1.0}, 1.0),
        # Testing merge of bool
        ({"test_key": False}, {"test_key": False}, False),
    ],
)
def test_dict_deep_merge_same_leaf(dict_a: Dict[str, Any], dict_b: Dict[str, Any], expected: Any):
    """
    Check `dict_deep_merge` when dict_a and dict_b hold equal values under
    the same key. The merge should be a no-op, so dict_a["test_key"] must
    still equal the starting value.

    :param dict_a: The value of dict_a that we're testing against
    :param dict_b: The value of dict_b that we're testing against
    :param expected: The value that we're expecting to now be in dict_a at 'test_key'
    """
    dict_deep_merge(dict_a, dict_b)

    leaf_value = dict_a["test_key"]
    assert leaf_value == expected

def test_dict_deep_merge_conflict_no_conflict_handler():
    """
    Test the `dict_deep_merge` function with a conflicting value in dict_b
    and no conflict handler. Since there's no conflict handler this should
    log a warning and ignore any merge for the key that has the conflict.
    """
    dict_a = {"test_key": "existing_value"}
    dict_b = {"test_key": "new_value"}

    # Create a capture stream to capture logs
    capture_stream = StringIO()
    handler = logging.StreamHandler(capture_stream)
    logger = logging.getLogger()
    logger.addHandler(handler)

    try:
        dict_deep_merge(dict_a, dict_b)
    finally:
        # Always detach the handler so an exception can't leak it onto the root logger
        logger.removeHandler(handler)

    # Make sure "test_key" in dict_a wasn't updated
    assert dict_a["test_key"] == "existing_value"

    # Check that dict_deep_merge logs a warning
    expected_log = "Conflict at test_key. Ignoring the update to key 'test_key'."
    assert expected_log in capture_stream.getvalue(), "Missing expected log message"


def test_dict_deep_merge_conflict_with_conflict_handler():
    """
    Test the `dict_deep_merge` function with a conflicting value in dict_b
    and a conflict handler. Our conflict handler will just concatenate the
    conflicting strings.
    """
    dict_a = {"test_key": "existing_value"}
    dict_b = {"test_key": "new_value"}

    def conflict_handler(dict_a_val: str, dict_b_val: str, **_unused) -> str:
        """
        The conflict handler that we'll be passing in to `dict_deep_merge`.
        This will concatenate the conflicting strings. The two values are
        required keyword arguments so a missing one fails at the call site;
        any other keyword arguments are accepted and ignored.
        """
        return ", ".join([dict_a_val, dict_b_val])

    # Call deep merge and make sure "test_key" in dict_a now holds the handler's result
    dict_deep_merge(dict_a, dict_b, conflict_handler=conflict_handler)
    assert dict_a["test_key"] == "existing_value, new_value"

0 comments on commit 9050ce2

Please sign in to comment.