Skip to content

Commit

Permalink
bugfix/deep-merge-existing-keys (LLNL#476)
Browse files Browse the repository at this point in the history
* remove a merge conflict statement that was missed

* add a 'pip freeze' call in github workflow to view reqs versions

* remove DeepMergeException and add conflict_handler to dict_deep_merge

* add conflict handler to dict_deep_merge

* fix broken tests for detailed-status

* use caplog fixture rather than IO stream

* add ability to define module-specific fixtures

* add tests for read/write status files and conflict handling

* add caplog explanation to docstrings

* update CHANGELOG

* run fix-style

* add pytest-mock as dependency for test suite

* clean up input check in dict_deep_merge
  • Loading branch information
bgunnar5 authored May 10, 2024
1 parent eace86f commit f476a98
Show file tree
Hide file tree
Showing 13 changed files with 1,026 additions and 66 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Added
- Conflict handler option to the `dict_deep_merge` function in `utils.py`
- Ability to add module-specific pytest fixtures
- Added fixtures specifically for testing status functionality
- Added tests for reading and writing status files, and status conflict handling
- Added tests for the `dict_deep_merge` function
- Pytest-mock as a dependency for the test suite (necessary for using mocks and fixtures in the same test)

### Changed

Expand Down
6 changes: 3 additions & 3 deletions merlin/common/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from merlin.exceptions import HardFailException, InvalidChainException, RestartException, RetryException
from merlin.router import stop_workers
from merlin.spec.expansion import parameter_substitutions_for_cmd, parameter_substitutions_for_sample
from merlin.study.status import read_status
from merlin.study.status import read_status, status_conflict_handler
from merlin.utils import dict_deep_merge


Expand Down Expand Up @@ -484,7 +484,7 @@ def gather_statuses(
# Make sure the status for this sample workspace is in a finished state (not initialized or running)
if status[step_name][f"{condensed_workspace}/{path}"]["status"] not in ("INITIALIZED", "RUNNING"):
# Add the status data to the statuses we'll write to the condensed file and remove this status file
dict_deep_merge(condensed_statuses, status)
dict_deep_merge(condensed_statuses, status, conflict_handler=status_conflict_handler)
files_to_remove.append(status_filepath)
files_to_remove.append(lock_filepath) # Remove the lock files as well as the status files
except KeyError:
Expand Down Expand Up @@ -556,7 +556,7 @@ def condense_status_files(self, *args: Any, **kwargs: Any) -> ReturnCode: # pyl
existing_condensed_statuses = json.load(condensed_status_file)
# Merging the statuses we're condensing into the already existing statuses
# because it's faster at scale than vice versa
dict_deep_merge(existing_condensed_statuses, condensed_statuses)
dict_deep_merge(existing_condensed_statuses, condensed_statuses, conflict_handler=status_conflict_handler)
condensed_statuses = existing_condensed_statuses

# Write the condensed statuses to the condensed status file
Expand Down
11 changes: 0 additions & 11 deletions merlin/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
"HardFailException",
"InvalidChainException",
"RestartException",
"DeepMergeException",
"NoWorkersException",
)

Expand Down Expand Up @@ -96,16 +95,6 @@ def __init__(self):
super().__init__()


class DeepMergeException(Exception):
"""
Exception to signal that there's a conflict when trying
to merge two dicts together
"""

def __init__(self, message):
super().__init__(message)


class NoWorkersException(Exception):
"""
Exception to signal that no workers were started
Expand Down
125 changes: 110 additions & 15 deletions merlin/study/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@
from datetime import datetime
from glob import glob
from traceback import print_exception
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from filelock import FileLock, Timeout
from maestrowf.utils import get_duration
from tabulate import tabulate

from merlin.common.dumper import dump_handler
Expand Down Expand Up @@ -363,7 +364,7 @@ def get_step_statuses(self, step_workspace: str, started_step_name: str) -> Dict
statuses_read = read_status(status_filepath, f"{root}/status.lock")

# Merge the statuses we read with the dict tracking all statuses for this step
dict_deep_merge(step_statuses, statuses_read)
dict_deep_merge(step_statuses, statuses_read, conflict_handler=status_conflict_handler)

# Add full step name to the tracker and count number of statuses we just read in
for full_step_name, status_info in statuses_read.items():
Expand Down Expand Up @@ -391,7 +392,7 @@ def load_requested_statuses(self):
for sstep in self.step_tracker["started_steps"]:
step_workspace = f"{self.workspace}/{sstep}"
step_statuses = self.get_step_statuses(step_workspace, sstep)
dict_deep_merge(self.requested_statuses, step_statuses)
dict_deep_merge(self.requested_statuses, step_statuses, conflict_handler=status_conflict_handler)

# Calculate run time average and standard deviation for this step
self.get_runtime_avg_std_dev(step_statuses, sstep)
Expand Down Expand Up @@ -531,8 +532,9 @@ def format_status_for_csv(self) -> Dict:

# Loop through information for each step
for step_info_key, step_info_value in overall_step_info.items():
# Skip the workers entry at the top level; this will be added in the else statement below on a task-by-task basis
if step_info_key == "workers":
# Skip the workers entry at the top level;
# this will be added in the else statement below on a task-by-task basis
if step_info_key in ("workers", "worker_name"):
continue
# Format task queue entry
if step_info_key == "task_queue":
Expand Down Expand Up @@ -833,13 +835,15 @@ def apply_filters(self):
filtered_statuses = {}
for step_name, overall_step_info in self.requested_statuses.items():
filtered_statuses[step_name] = {}
# Add the non-workspace keys to the filtered_status dict so we don't accidentally miss any of this information while filtering
# Add the non-workspace keys to the filtered_status dict so we
# don't accidentally miss any of this information while filtering
for non_ws_key in NON_WORKSPACE_KEYS:
try:
filtered_statuses[step_name][non_ws_key] = overall_step_info[non_ws_key]
except KeyError:
LOG.debug(
f"Tried to add {non_ws_key} to filtered_statuses dict but it was not found in requested_statuses[{step_name}]"
f"Tried to add {non_ws_key} to filtered_statuses dict "
f"but it was not found in requested_statuses[{step_name}]"
)

# Go through the actual statuses and filter them as necessary
Expand Down Expand Up @@ -916,7 +920,7 @@ def apply_max_tasks_limit(self):
self.args.max_tasks -= len(sub_step_workspaces)

# Merge in the task statuses that we're allowing
dict_deep_merge(new_status_dict[step_name], overall_step_info)
dict_deep_merge(new_status_dict[step_name], overall_step_info, conflict_handler=status_conflict_handler)

LOG.info(f"Limited the number of tasks to display to {max_tasks} tasks.")

Expand Down Expand Up @@ -1099,6 +1103,95 @@ def display(self, test_mode: Optional[bool] = False):
LOG.warning("No statuses to display.")


# Pylint complains that args is unused but we can ignore that
def status_conflict_handler(*args, **kwargs) -> Any:  # pylint: disable=W0613
    """
    Decide which value wins when two status files disagree at the same key
    while being merged together by `dict_deep_merge`.

    Expected kwargs:
    - dict_a_val: The conflicting value from the dictionary being merged into
    - dict_b_val: The conflicting value from the dictionary being pulled from
    - key: The key into each dictionary that has a conflict
    - path: The path down the dictionary tree that `dict_deep_merge` is currently at

    List and dict conflicts never reach this function since `dict_deep_merge`
    from `utils.py` already handles those scenarios itself. There are currently
    4 merge rules:
    - string-concatenate: take the two conflicting values and concatenate them in a string
    - use-initial-and-log-warning: use the value from dict_a and log a warning message
    - use-longest-time: use the longest time between the two conflicting values
    - use-max: use the larger integer between the two conflicting values

    :returns: The value to merge into dict_a at `key`
    """
    # Pull the values we need out of kwargs
    initial_val = kwargs.get("dict_a_val", None)
    incoming_val = kwargs.get("dict_b_val", None)
    key = kwargs.get("key", None)
    path = kwargs.get("path", None)

    merge_rules = {
        "task_queue": "string-concatenate",
        "worker_name": "string-concatenate",
        "status": "use-initial-and-log-warning",
        "return_code": "use-initial-and-log-warning",
        "elapsed_time": "use-longest-time",
        "run_time": "use-longest-time",
        "restarts": "use-max",
    }

    # TODO
    # - make status tracking more modular (see https://lc.llnl.gov/gitlab/weave/merlin/-/issues/58)
    # - once it's more modular, move the below code and the above merge_rules dict to a property in
    #   one of the new status classes (the one that has condensing maybe? or upstream from that?)

    # params = self.spec.get_parameters()
    # for token in params.parameters:
    #     merge_rules[token] = "use-initial-and-log-warning"

    # Parameter token keys get the keep-initial rule (the commented loop above
    # would be preferable but it's only possible if this conflict handler lived
    # on the Status object; since this function needs to be imported outside of
    # this file we can't do that)
    if path is not None and "parameters" in path:
        merge_rules[key] = "use-initial-and-log-warning"

    if key not in merge_rules:
        LOG.warning(f"The key '{key}' does not have a merge rule defined. Setting this merge to None.")
        return None
    merge_rule = merge_rules[key]

    if merge_rule == "string-concatenate":
        return f"{initial_val}, {incoming_val}"

    if merge_rule == "use-initial-and-log-warning":
        LOG.warning(
            f"Conflict at key '{key}' while merging status files. Defaulting to initial value. "
            "This could lead to incorrect status information, you may want to re-run in debug mode and "
            "check the files in the output directory for this task."
        )
        return initial_val

    if merge_rule == "use-longest-time":
        # "--:--:--" is the unset placeholder, so any real time beats it
        if initial_val == "--:--:--":
            return incoming_val
        if incoming_val == "--:--:--":
            return initial_val
        longest = max(convert_to_timedelta(initial_val), convert_to_timedelta(incoming_val))
        return get_duration(longest)

    if merge_rule == "use-max":
        return max(initial_val, incoming_val)

    LOG.warning(f"The merge_rule '{merge_rule}' was provided but it has no implementation.")
    return None


def read_status(
status_filepath: str, lock_file: str, display_fnf_message: bool = True, raise_errors: bool = False, timeout: int = 10
) -> Dict:
Expand All @@ -1112,6 +1205,8 @@ def read_status(
:param timeout: An integer representing how long to hold a lock for before timing out.
:returns: A dict of the contents in the status file
"""
statuses_read = {}

# Pylint complains that we're instantiating an abstract class but this is correct usage
lock = FileLock(lock_file) # pylint: disable=abstract-class-instantiated
try:
Expand All @@ -1122,25 +1217,24 @@ def read_status(
# Handle timeouts
except Timeout as to_exc:
LOG.warning(f"Timed out when trying to read status from '{status_filepath}'")
statuses_read = {}
if raise_errors:
raise Timeout from to_exc
raise to_exc
# Handle FNF errors
except FileNotFoundError as fnf_exc:
if display_fnf_message:
LOG.warning(f"Could not find '{status_filepath}'")
statuses_read = {}
if raise_errors:
raise FileNotFoundError from fnf_exc
raise fnf_exc
# Handle JSONDecode errors (this is likely due to an empty status file)
except json.decoder.JSONDecodeError as json_exc:
LOG.warning(f"JSONDecodeError raised when trying to read status from '{status_filepath}'")
if raise_errors:
raise json.decoder.JSONDecodeError from json_exc
raise json_exc
# Catch all exceptions so that we don't crash the workers
except Exception as exc: # pylint: disable=broad-except
LOG.warning(
f"An exception was raised while trying to read status from '{status_filepath}'!\n{print_exception(type(exc), exc, exc.__traceback__)}"
f"An exception was raised while trying to read status from '{status_filepath}'!\n"
f"{print_exception(type(exc), exc, exc.__traceback__)}"
)
if raise_errors:
raise exc
Expand All @@ -1167,5 +1261,6 @@ def write_status(status_to_write: Dict, status_filepath: str, lock_file: str, ti
# Catch all exceptions so that we don't crash the workers
except Exception as exc: # pylint: disable=broad-except
LOG.warning(
f"An exception was raised while trying to write status to '{status_filepath}'!\n{print_exception(type(exc), exc, exc.__traceback__)}"
f"An exception was raised while trying to write status to '{status_filepath}'!\n"
f"{print_exception(type(exc), exc, exc.__traceback__)}"
)
41 changes: 31 additions & 10 deletions merlin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@
from copy import deepcopy
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import numpy as np
import psutil
import yaml

from merlin.exceptions import DeepMergeException


try:
import cPickle as pickle
Expand Down Expand Up @@ -559,33 +557,51 @@ def needs_merlin_expansion(
return False


def dict_deep_merge(dict_a: dict, dict_b: dict, path: List[str] = None, conflict_handler: Callable = None):
    """
    This function recursively merges dict_b into dict_a. The built-in
    merge of dictionaries in python (dict(dict_a) | dict(dict_b)) does not do a
    deep merge so this function is necessary. This will only merge in new keys,
    it will NOT update existing ones, unless you specify a conflict handler function.
    Credit to this stack overflow post: https://stackoverflow.com/a/7205107.

    :param `dict_a`: A dict that we'll merge dict_b into (modified in place)
    :param `dict_b`: A dict that we want to merge into dict_a
    :param `path`: The path down the dictionary tree that we're currently at
    :param `conflict_handler`: An optional function to handle conflicts between values at the same key.
                               The function should return the value to be used in the merged dictionary.
                               The default behavior without this argument is to log a warning.
    """

    # Check to make sure we have valid dict_a and dict_b input
    msgs = [
        f"{name} '{actual_dict}' is not a dict"
        for name, actual_dict in [("dict_a", dict_a), ("dict_b", dict_b)]
        if not isinstance(actual_dict, dict)
    ]
    if len(msgs) > 0:
        LOG.warning(f"Problem with dict_deep_merge: {', '.join(msgs)}. Ignoring this merge call.")
        return

    if path is None:
        path = []
    for key in dict_b:
        if key in dict_a:
            if isinstance(dict_a[key], dict) and isinstance(dict_b[key], dict):
                # Both values are dicts -> recurse one level deeper
                dict_deep_merge(dict_a[key], dict_b[key], path=path + [str(key)], conflict_handler=conflict_handler)
            # BUGFIX: the second isinstance previously re-checked dict_a[key]; it must check
            # dict_b[key], otherwise a list merged with a non-list raised a TypeError here
            elif isinstance(dict_a[key], list) and isinstance(dict_b[key], list):
                dict_a[key] += dict_b[key]
            elif dict_a[key] == dict_b[key]:
                pass  # same leaf value
            else:
                if conflict_handler is not None:
                    merged_val = conflict_handler(
                        dict_a_val=dict_a[key], dict_b_val=dict_b[key], key=key, path=path + [str(key)]
                    )
                    dict_a[key] = merged_val
                else:
                    # Want to just output a warning instead of raising an exception so that the workflow doesn't crash
                    LOG.warning(f"Conflict at {'.'.join(path + [str(key)])}. Ignoring the update to key '{key}'.")
        else:
            dict_a[key] = dict_b[key]

Expand Down Expand Up @@ -619,6 +635,11 @@ def convert_to_timedelta(timestr: Union[str, int]) -> timedelta:
"""
# make sure it's a string in case we get an int
timestr = str(timestr)

# remove time unit characters (if any exist)
time_unit_chars = r"[dhms]"
timestr = re.sub(time_unit_chars, "", timestr)

nfields = len(timestr.split(":"))
if nfields > 4:
raise ValueError(f"Cannot convert {timestr} to a timedelta. Valid format: days:hours:minutes:seconds.")
Expand Down
1 change: 1 addition & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ alabaster
johnnydep
deepdiff
pytest-order
pytest-mock
Loading

0 comments on commit f476a98

Please sign in to comment.