From 29573d43b17f97f90d93ce44773770754d7359de Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 20 Nov 2023 14:42:08 -0800 Subject: [PATCH 01/44] remove a merge conflict statement that was missed --- merlin/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/merlin/__init__.py b/merlin/__init__.py index dda10809c..c1ad21b22 100644 --- a/merlin/__init__.py +++ b/merlin/__init__.py @@ -38,11 +38,7 @@ import sys -<<<<<<< HEAD -__version__ = "1.10.2" -======= __version__ = "1.11.1" ->>>>>>> 38651f2650e8aba97552c4575e97d66be3205545 VERSION = __version__ PATH_TO_PROJ = os.path.join(os.path.dirname(__file__), "") From f10c896d9f67397fc3e8e6111742a2ba5e3257d1 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 11 Dec 2023 15:30:05 -0800 Subject: [PATCH 02/44] add pytest coverage library and add sample_index coverage --- .gitignore | 3 +- merlin/common/sample_index.py | 2 +- requirements/dev.txt | 1 + tests/unit/common/test_sample_index.py | 672 ++++++++++++++++++------- 4 files changed, 508 insertions(+), 170 deletions(-) diff --git a/.gitignore b/.gitignore index c22521934..cec577a85 100644 --- a/.gitignore +++ b/.gitignore @@ -39,8 +39,9 @@ flux.out slurm*.out docs/build/ -# Tox files +# Test files .tox/* +.coverage # Jupyter jupyter/.ipynb_checkpoints diff --git a/merlin/common/sample_index.py b/merlin/common/sample_index.py index 4e3ac3a52..00295dab0 100644 --- a/merlin/common/sample_index.py +++ b/merlin/common/sample_index.py @@ -225,8 +225,8 @@ def __setitem__(self, full_address, sub_tree): # Replace if we already have something at this address. if delete_me is not None: - self.children.__delitem__(full_address) SampleIndex.check_valid_addresses_for_insertion(full_address, sub_tree) + self.children.__delitem__(full_address) self.children[full_address] = sub_tree return raise KeyError diff --git a/requirements/dev.txt b/requirements/dev.txt index 895a89249..ccf00e112 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -5,6 +5,7 @@ dep-license flake8 isort pytest +pytest-cov pylint twine sphinx>=2.0.0 diff --git a/tests/unit/common/test_sample_index.py b/tests/unit/common/test_sample_index.py index c693827f0..1237c52a1 100644 --- a/tests/unit/common/test_sample_index.py +++ b/tests/unit/common/test_sample_index.py @@ -1,178 +1,514 @@ import os +import pytest import shutil from contextlib import suppress +from merlin.common.sample_index import SampleIndex, uniform_directories, new_dir from merlin.common.sample_index_factory import create_hierarchy, read_hierarchy -TEST_DIR = "UNIT_TEST_SPACE" - - -def clear_test_tree(): - with suppress(FileNotFoundError): - shutil.rmtree(TEST_DIR) - - -def clear(func): - def wrapper(): - clear_test_tree() - func() - clear_test_tree() - - return wrapper - - -@clear -def test_index_file_writing(): - indx = create_hierarchy(1000000000, 10000, [100000000, 10000000, 1000000], root=TEST_DIR) - indx.write_directories() - indx.write_multiple_sample_index_files() - indx2 = read_hierarchy(TEST_DIR) - assert indx2.get_path_to_sample(123000123) == indx.get_path_to_sample(123000123) - - -def test_bundle_retrieval(): - indx = create_hierarchy(1000000000, 10000, [100000000, 10000000, 1000000], root=TEST_DIR) - expected = f"{TEST_DIR}/0/0/0/samples0-10000.ext" - result = indx.get_path_to_sample(123) - assert expected == result - - expected = f"{TEST_DIR}/0/0/0/samples10000-20000.ext" - result = indx.get_path_to_sample(10000) - assert expected == result - - expected = f"{TEST_DIR}/1/2/3/samples123000000-123010000.ext" - result = indx.get_path_to_sample(123000123) - assert expected == result - - -def test_start_sample_id(): - expected = """: DIRECTORY MIN 203 MAX 303 NUM_BUNDLES 10 - 0: BUNDLE 0 MIN 203 MAX 213 - 1: BUNDLE 1 MIN 213 MAX 223 - 2: BUNDLE 2 MIN 223 MAX 233 - 3: BUNDLE 3 MIN 233 MAX 243 - 4: BUNDLE 4 MIN 243 MAX 253 - 5: BUNDLE 5 MIN 253 MAX 263 - 6: BUNDLE 6 MIN 263 MAX 273 - 7: BUNDLE 7 MIN 273 MAX 283 - 8: BUNDLE 8 MIN 283 MAX 293 - 9: BUNDLE 9 MIN 293 MAX 303 -""" - idx203 = create_hierarchy(100, 10, start_sample_id=203) - assert expected == str(idx203) - - -@clear -def test_directory_writing(): - path = os.path.join(TEST_DIR) - indx = create_hierarchy(2, 1, [1], root=path) - expected = """: DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2 - 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1 - 0.0: BUNDLE 0 MIN 0 MAX 1 - 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1 - 1.0: BUNDLE 1 MIN 1 MAX 2 -""" - assert expected == str(indx) - indx.write_directories() - assert os.path.isdir(f"{TEST_DIR}/0") - assert os.path.isdir(f"{TEST_DIR}/1") - indx.write_multiple_sample_index_files() - - clear_test_tree() - - path = os.path.join(TEST_DIR) - indx = create_hierarchy(1000000000, 10000, [100000000, 10000000], root=path) - indx.write_directories() - path = indx.get_path_to_sample(123000123) - assert os.path.exists(os.path.dirname(path)) - assert path != TEST_DIR - path = indx.get_path_to_sample(10000000000) - assert path == TEST_DIR - - clear_test_tree() - - path = os.path.join(TEST_DIR) - indx = create_hierarchy(1000000000, 10000, [100000000, 10000000, 1000000], root=path) - indx.write_directories() - - -def test_directory_path(): - indx = create_hierarchy(20, 1, [20, 5, 1], root="") - leaves = indx.make_directory_string() - expected_leaves = "0/0/0 0/0/1 0/0/2 0/0/3 0/0/4 0/1/0 0/1/1 0/1/2 0/1/3 0/1/4 0/2/0 0/2/1 0/2/2 0/2/3 0/2/4 0/3/0 0/3/1 0/3/2 0/3/3 0/3/4" - assert leaves == expected_leaves - all_dirs = indx.make_directory_string(just_leaf_directories=False) - expected_all_dirs = " 0 0/0 0/0/0 0/0/1 0/0/2 0/0/3 0/0/4 0/1 0/1/0 0/1/1 0/1/2 0/1/3 0/1/4 0/2 0/2/0 0/2/1 0/2/2 0/2/3 0/2/4 0/3 0/3/0 0/3/1 0/3/2 0/3/3 0/3/4" - assert all_dirs == expected_all_dirs - - -@clear -def test_subhierarchy_insertion(): - indx = create_hierarchy(2, 1, [1], root=TEST_DIR) - print("Writing directories") - indx.write_directories() - indx.write_multiple_sample_index_files() - print("reading heirarchy") - top = read_hierarchy(os.path.abspath(TEST_DIR)) - expected = """: DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2 - 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1 - 0.0: BUNDLE -1 MIN 0 MAX 1 - 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1 - 1.0: BUNDLE -1 MIN 1 MAX 2 -""" - assert str(top) == expected - print("creating sub_heirarchy") - sub_h = create_hierarchy(100, 10, address="1.0") - print("inserting sub_heirarchy") - top["1.0"] = sub_h - print(str(indx)) - print("after insertion") - print(str(top)) - expected = """: DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2 - 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1 - 0.0: BUNDLE -1 MIN 0 MAX 1 - 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1 - 1.0: DIRECTORY MIN 0 MAX 100 NUM_BUNDLES 10 - 1.0.0: BUNDLE 0 MIN 0 MAX 10 - 1.0.1: BUNDLE 1 MIN 10 MAX 20 - 1.0.2: BUNDLE 2 MIN 20 MAX 30 - 1.0.3: BUNDLE 3 MIN 30 MAX 40 - 1.0.4: BUNDLE 4 MIN 40 MAX 50 - 1.0.5: BUNDLE 5 MIN 50 MAX 60 - 1.0.6: BUNDLE 6 MIN 60 MAX 70 - 1.0.7: BUNDLE 7 MIN 70 MAX 80 - 1.0.8: BUNDLE 8 MIN 80 MAX 90 - 1.0.9: BUNDLE 9 MIN 90 MAX 100 -""" - assert str(top) == expected - - -def test_sample_index(): - """Run through some basic testing of the SampleIndex class.""" +def test_uniform_directories(): + """ + Test the `uniform_directories` function with different inputs. + """ + # Create the tests and the expected outputs tests = [ - (10, 1, []), - (10, 3, []), - (11, 2, [5]), - (10, 3, [3]), - (10, 3, [1]), - (10, 1, [3]), - (10, 3, [1, 3]), - (10, 1, [2]), - (1000, 100, [500]), - (1000, 50, [500, 100]), - (1000000000, 100000132, []), + # SMALL SAMPLE SIZE + (10, 1, 100), # Bundle size of 1 and max dir level of 100 is default + (10, 1, 2), + (10, 2, 100), + (10, 2, 2), + # MEDIUM SAMPLE SIZE + (10000, 1, 100), # Bundle size of 1 and max dir level of 100 is default + (10000, 1, 5), + (10000, 5, 100), + (10000, 5, 10), + # LARGE SAMPLE SIZE + (1000000000, 1, 100), # Bundle size of 1 and max dir level of 100 is default + (1000000000, 1, 5), + (1000000000, 5, 100), + (1000000000, 5, 10), ] + expected_outputs = [ + # SMALL SAMPLE SIZE + [1], + [8, 4, 2, 1], + [2], + [8, 4, 2], + # MEDIUM SAMPLE SIZE + [100, 1], + [3125, 625, 125, 25, 5, 1], + [500, 5], + [5000, 500, 50, 5], + # LARGE SAMPLE SIZE + [100000000, 1000000, 10000, 100, 1], + [244140625, 48828125, 9765625, 1953125, 390625, 78125, 15625, 3125, 625, 125, 25, 5, 1], + [500000000, 5000000, 50000, 500, 5], + [500000000, 50000000, 5000000, 500000, 50000, 5000, 500, 50, 5], + ] + + # Run the tests and compare outputs + for i, test in enumerate(tests): + actual = uniform_directories(num_samples=test[0], bundle_size=test[1], level_max_dirs=test[2]) + assert actual == expected_outputs[i] + + +def test_new_dir(temp_output_dir: str): + """ + Test the `new_dir` function. This will test a valid path and also raising an OSError during + creation. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Test basic functionality + test_path = f"{os.getcwd()}/test_new_dir" + new_dir(test_path) + assert os.path.exists(test_path) + + # Test OSError functionality + new_dir(test_path) + + + +class TestSampleIndex: + """ + These tests focus on testing the SampleIndex class used for creating the + sample hierarchy. + + NOTE to see output of creating any hierarchy, change `write_all_hierarchies` to True. + The results of each hierarchy will be written to: + /tmp/`whoami`/pytest/pytest-of-`whoami`/pytest-current/integration_outfiles_current/test_sample_index/ + """ + + write_all_hierarchies = False + + def get_working_dir(self, test_workspace: str): + """ + This method is called for every test to get a unique workspace in the temporary + directory for the test output. + + :param test_workspace: The unique name for this workspace + (all tests use their unique test name for this value usually) + """ + return f"{os.getcwd()}/test_sample_index/{test_workspace}" + + def write_hierarchy_for_debug(self, indx: SampleIndex): + """ + This method is for debugging purposes. It will cause all tests that don't write + hierarchies to write them so the output can be investigated. + + :param indx: The `SampleIndex` object to write the hierarchy for + """ + if self.write_all_hierarchies: + indx.write_directories() + indx.write_multiple_sample_index_files() + + def test_invalid_children(self): + """ + This will test that an invalid type for the `children` argument will raise + an error. + """ + tests = [ + ["a", "b", "c"], + True, + "a b c", + ] + for test in tests: + with pytest.raises(TypeError): + SampleIndex(0, 10, test, "name") + + def test_is_parent_of_leaf(self, temp_output_dir: str): + """ + Test the `is_parent_of_leaf` property. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Create a hierarchy to test + working_dir = self.get_working_dir("test_is_parent_of_leaf") + indx = create_hierarchy(10, 1, [2], root=working_dir) + self.write_hierarchy_for_debug(indx) + + # Test to see if parent of leaf is recognized + assert indx.is_parent_of_leaf is False + assert indx.children["0"].is_parent_of_leaf is True + + # Test to see if leaf is recognized + leaf_node = indx.children["0"].children["0.0"] + assert leaf_node.is_parent_of_leaf is False + + def test_is_grandparent_of_leaf(self, temp_output_dir: str): + """ + Test the `is_grandparent_of_leaf` property. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Create a hierarchy to test + working_dir = self.get_working_dir("test_is_grandparent_of_leaf") + indx = create_hierarchy(10, 1, [2], root=working_dir) + self.write_hierarchy_for_debug(indx) + + # Test to see if grandparent of leaf is recognized + assert indx.is_grandparent_of_leaf is True + assert indx.children["0"].is_grandparent_of_leaf is False + + # Test to see if leaf is recognized + leaf_node = indx.children["0"].children["0.0"] + assert leaf_node.is_grandparent_of_leaf is False + + def test_is_great_grandparent_of_leaf(self, temp_output_dir: str): + """ + Test the `is_great_grandparent_of_leaf` property. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Create a hierarchy to test + working_dir = self.get_working_dir("test_is_great_grandparent_of_leaf") + indx = create_hierarchy(10, 1, [5, 1], root=working_dir) + self.write_hierarchy_for_debug(indx) + + # Test to see if great grandparent of leaf is recognized + assert indx.is_great_grandparent_of_leaf is True + assert indx.children["0"].is_great_grandparent_of_leaf is False + assert indx.children["0"].children["0.0"].is_great_grandparent_of_leaf is False + + # Test to see if leaf is recognized + leaf_node = indx.children["0"].children["0.0"].children["0.0.0"] + assert leaf_node.is_great_grandparent_of_leaf is False + + def test_traverse_bundle(self, temp_output_dir: str): + """ + Test the `traverse_bundle` method to make sure it's just returning leaves. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Create a hierarchy to test + working_dir = self.get_working_dir("test_is_grandparent_of_leaf") + indx = create_hierarchy(10, 1, [2], root=working_dir) + self.write_hierarchy_for_debug(indx) + + # Ensure all nodes in the traversal are leaves + for _, node in indx.traverse_bundles(): + assert node.is_leaf + + def test_getitem(self, temp_output_dir: str): + """ + Test the `__getitem__` magic method. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Create a hierarchy to test + working_dir = self.get_working_dir("test_is_grandparent_of_leaf") + indx = create_hierarchy(10, 1, [2], root=working_dir) + self.write_hierarchy_for_debug(indx) + + # Test getting that requesting the root returns itself + assert indx[""] == indx + + # Test a valid address + assert indx["0"] == indx.children["0"] + + # Test an invalid address + with pytest.raises(KeyError): + indx["10"] + + def test_setitem(self, temp_output_dir: str): + """ + Test the `__setitem__` magic method. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Create a hierarchy to test + working_dir = self.get_working_dir("test_is_grandparent_of_leaf") + indx = create_hierarchy(10, 1, [2], root=working_dir) + self.write_hierarchy_for_debug(indx) + + invalid_indx = SampleIndex(1, 3, {}, "invalid_indx") + + # Ensure that trying to change the root raises an error + with pytest.raises(KeyError): + indx[""] = invalid_indx + + # Ensure we can't just add a new subtree to a level + with pytest.raises(KeyError): + indx["10"] = invalid_indx + + # Test that invalid subtrees are caught + with pytest.raises(TypeError): + indx["0"] = invalid_indx + + # Test a valid set operation + dummy_indx = SampleIndex(0, 1, {}, "dummy_indx", leafid=0, address="0.0") + indx["0"]["0.0"] = dummy_indx + + + def test_index_file_writing(self, temp_output_dir: str): + """ + Test the functionality of writing multiple index files. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + working_dir = self.get_working_dir("test_index_file_writing") + indx = create_hierarchy(1000000000, 10000, [100000000, 10000000, 1000000], root=working_dir) + indx.write_directories() + indx.write_multiple_sample_index_files() + indx2 = read_hierarchy(working_dir) + assert indx2.get_path_to_sample(123000123) == indx.get_path_to_sample(123000123) + + def test_directory_writing_small(self, temp_output_dir: str): + """ + Test that writing a small directory functions properly. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Create the directory and ensure it has the correct format + working_dir = self.get_working_dir("test_directory_writing_small/") + indx = create_hierarchy(2, 1, [1], root=working_dir) + expected = ": DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2\n" \ + " 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1\n" \ + " 0.0: BUNDLE 0 MIN 0 MAX 1\n" \ + " 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1\n" \ + " 1.0: BUNDLE 1 MIN 1 MAX 2\n" \ + + assert expected == str(indx) + + # Write the directories and ensure the paths are actually written + indx.write_directories() + assert os.path.isdir(f"{working_dir}/0") + assert os.path.isdir(f"{working_dir}/1") + indx.write_multiple_sample_index_files() + + def test_directory_writing_large(self, temp_output_dir: str): + """ + Test that writing a large directory functions properly. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + working_dir = self.get_working_dir("test_directory_writing_large") + indx = create_hierarchy(1000000000, 10000, [100000000, 10000000, 1000000], root=working_dir) + indx.write_directories() + path = indx.get_path_to_sample(123000123) + assert os.path.exists(os.path.dirname(path)) + assert path != working_dir + path = indx.get_path_to_sample(10000000000) + assert path == working_dir + + def test_bundle_retrieval(self, temp_output_dir: str): + """ + Test the functionality to get a bundle of samples when providing a sample id to find. + This will test a large sample hierarchy to ensure this scales properly. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Create the hierarchy + working_dir = self.get_working_dir("test_bundle_retrieval") + indx = create_hierarchy(1000000000, 10000, [100000000, 10000000, 1000000], root=working_dir) + self.write_hierarchy_for_debug(indx) + + # Test for a small sample id + expected = f"{working_dir}/0/0/0/samples0-10000.ext" + result = indx.get_path_to_sample(123) + assert expected == result + + # Test for a mid size sample id + expected = f"{working_dir}/0/0/0/samples10000-20000.ext" + result = indx.get_path_to_sample(10000) + assert expected == result + + # Test for a large sample id + expected = f"{working_dir}/1/2/3/samples123000000-123010000.ext" + result = indx.get_path_to_sample(123000123) + assert expected == result + + def test_start_sample_id(self, temp_output_dir: str): + """ + Test creating a hierarchy using a starting sample id. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + working_dir = self.get_working_dir("test_start_sample_id") + expected = ": DIRECTORY MIN 203 MAX 303 NUM_BUNDLES 10\n" \ + " 0: BUNDLE 0 MIN 203 MAX 213\n" \ + " 1: BUNDLE 1 MIN 213 MAX 223\n" \ + " 2: BUNDLE 2 MIN 223 MAX 233\n" \ + " 3: BUNDLE 3 MIN 233 MAX 243\n" \ + " 4: BUNDLE 4 MIN 243 MAX 253\n" \ + " 5: BUNDLE 5 MIN 253 MAX 263\n" \ + " 6: BUNDLE 6 MIN 263 MAX 273\n" \ + " 7: BUNDLE 7 MIN 273 MAX 283\n" \ + " 8: BUNDLE 8 MIN 283 MAX 293\n" \ + " 9: BUNDLE 9 MIN 293 MAX 303\n" \ + + idx203 = create_hierarchy(100, 10, start_sample_id=203, root=working_dir) + self.write_hierarchy_for_debug(idx203) + + assert expected == str(idx203) + + def test_make_directory_string(self, temp_output_dir: str): + """ + Test the `make_directory_string` method of `SampleIndex`. This will check + both the normal functionality where we just request paths to the leaves and + also the inverse functionality where we request all paths that are not leaves. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Creating the hierarchy + working_dir = self.get_working_dir("test_make_directory_string") + indx = create_hierarchy(20, 1, [20, 5, 1], root=working_dir) + self.write_hierarchy_for_debug(indx) + + # Testing normal functionality (just leaf directories) + leaves = indx.make_directory_string() + expected_leaves_list = [ + f"{working_dir}/0/0/0", + f"{working_dir}/0/0/1", + f"{working_dir}/0/0/2", + f"{working_dir}/0/0/3", + f"{working_dir}/0/0/4", + f"{working_dir}/0/1/0", + f"{working_dir}/0/1/1", + f"{working_dir}/0/1/2", + f"{working_dir}/0/1/3", + f"{working_dir}/0/1/4", + f"{working_dir}/0/2/0", + f"{working_dir}/0/2/1", + f"{working_dir}/0/2/2", + f"{working_dir}/0/2/3", + f"{working_dir}/0/2/4", + f"{working_dir}/0/3/0", + f"{working_dir}/0/3/1", + f"{working_dir}/0/3/2", + f"{working_dir}/0/3/3", + f"{working_dir}/0/3/4", + ] + expected_leaves = " ".join(expected_leaves_list) + assert leaves == expected_leaves + + # Testing no leaf functionality + all_dirs = indx.make_directory_string(just_leaf_directories=False) + expected_all_dirs_list = [ + working_dir, + f"{working_dir}/0", + f"{working_dir}/0/0", + f"{working_dir}/0/0/0", + f"{working_dir}/0/0/1", + f"{working_dir}/0/0/2", + f"{working_dir}/0/0/3", + f"{working_dir}/0/0/4", + f"{working_dir}/0/1", + f"{working_dir}/0/1/0", + f"{working_dir}/0/1/1", + f"{working_dir}/0/1/2", + f"{working_dir}/0/1/3", + f"{working_dir}/0/1/4", + f"{working_dir}/0/2", + f"{working_dir}/0/2/0", + f"{working_dir}/0/2/1", + f"{working_dir}/0/2/2", + f"{working_dir}/0/2/3", + f"{working_dir}/0/2/4", + f"{working_dir}/0/3", + f"{working_dir}/0/3/0", + f"{working_dir}/0/3/1", + f"{working_dir}/0/3/2", + f"{working_dir}/0/3/3", + f"{working_dir}/0/3/4" + ] + expected_all_dirs = " ".join(expected_all_dirs_list) + assert all_dirs == expected_all_dirs + + def test_subhierarchy_insertion(self, temp_output_dir: str): + """ + Test that a subhierarchy can be inserted into our `SampleIndex` properly. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Create the hierarchy and read it + working_dir = self.get_working_dir("test_subhierarchy_insertion") + indx = create_hierarchy(2, 1, [1], root=working_dir) + indx.write_directories() + indx.write_multiple_sample_index_files() + top = read_hierarchy(os.path.abspath(working_dir)) + + # Compare results + expected = ": DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2\n" \ + " 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1\n" \ + " 0.0: BUNDLE -1 MIN 0 MAX 1\n" \ + " 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1\n" \ + " 1.0: BUNDLE -1 MIN 1 MAX 2\n" \ + + assert str(top) == expected + + # Create and insert the sub hierarchy + sub_h = create_hierarchy(100, 10, address="1.0") + top["1.0"] = sub_h + + # Compare results + expected = ": DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2\n" \ + " 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1\n" \ + " 0.0: BUNDLE -1 MIN 0 MAX 1\n" \ + " 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1\n" \ + " 1.0: DIRECTORY MIN 0 MAX 100 NUM_BUNDLES 10\n" \ + " 1.0.0: BUNDLE 0 MIN 0 MAX 10\n" \ + " 1.0.1: BUNDLE 1 MIN 10 MAX 20\n" \ + " 1.0.2: BUNDLE 2 MIN 20 MAX 30\n" \ + " 1.0.3: BUNDLE 3 MIN 30 MAX 40\n" \ + " 1.0.4: BUNDLE 4 MIN 40 MAX 50\n" \ + " 1.0.5: BUNDLE 5 MIN 50 MAX 60\n" \ + " 1.0.6: BUNDLE 6 MIN 60 MAX 70\n" \ + " 1.0.7: BUNDLE 7 MIN 70 MAX 80\n" \ + " 1.0.8: BUNDLE 8 MIN 80 MAX 90\n" \ + " 1.0.9: BUNDLE 9 MIN 90 MAX 100\n" \ + + assert str(top) == expected + + def test_sample_index_creation_and_insertion(self, temp_output_dir: str): + """ + Run through some basic testing of the SampleIndex class. This will try + creating hierarchies of different sizes and inserting subhierarchies of + different sizes as well. + + :param temp_output_dir: A pytest fixture defined in conftest.py that creates a + temporary output path for our tests + """ + # Define the tests for hierarchies of varying sizes + tests = [ + (10, 1, []), + (10, 3, []), + (11, 2, [5]), + (10, 3, [3]), + (10, 3, [1]), + (10, 1, [3]), + (10, 3, [1, 3]), + (10, 1, [2]), + (1000, 100, [500]), + (1000, 50, [500, 100]), + (1000000000, 100000132, []), + ] + + # Run all the tests we defined above + for i, args in enumerate(tests): + working_dir = self.get_working_dir(f"test_sample_index_creation_and_insertion/{i}") + + # Put at root address of "0" to guarantee insertion at "0.1" later is valid + idx = create_hierarchy(args[0], args[1], args[2], address="0", root=working_dir) + self.write_hierarchy_for_debug(idx) - for args in tests: - print(f"############ TEST {args[0]} {args[1]} {args[2]} ###########") - # put at root address of "0" to guarantee insertion at "0.1" later is valid - idx = create_hierarchy(args[0], args[1], args[2], address="0") - print(str(idx)) - try: - idx["0.1"] = create_hierarchy(args[0], args[1], args[2], address="0.1") - print("successful set") - print(str(idx)) - except KeyError as error: - print(error) - assert False + # Inserting hierarchy at 0.1 + try: + idx["0.1"] = create_hierarchy(args[0], args[1], args[2], address="0.1") + except KeyError as error: + assert False From 362478ef084c44876d54571616e0e78e346b7bc5 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 12 Dec 2023 09:31:22 -0800 Subject: [PATCH 03/44] run fix style and add module header --- tests/unit/common/test_sample_index.py | 100 +++++++++++++------------ 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/tests/unit/common/test_sample_index.py b/tests/unit/common/test_sample_index.py index 1237c52a1..296783273 100644 --- a/tests/unit/common/test_sample_index.py +++ b/tests/unit/common/test_sample_index.py @@ -1,9 +1,11 @@ +""" +Tests for the `sample_index.py` and `sample_index_factory.py` files. +""" import os + import pytest -import shutil -from contextlib import suppress -from merlin.common.sample_index import SampleIndex, uniform_directories, new_dir +from merlin.common.sample_index import SampleIndex, new_dir, uniform_directories from merlin.common.sample_index_factory import create_hierarchy, read_hierarchy @@ -70,7 +72,6 @@ def test_new_dir(temp_output_dir: str): new_dir(test_path) - class TestSampleIndex: """ These tests focus on testing the SampleIndex class used for creating the @@ -228,7 +229,7 @@ def test_setitem(self, temp_output_dir: str): working_dir = self.get_working_dir("test_is_grandparent_of_leaf") indx = create_hierarchy(10, 1, [2], root=working_dir) self.write_hierarchy_for_debug(indx) - + invalid_indx = SampleIndex(1, 3, {}, "invalid_indx") # Ensure that trying to change the root raises an error @@ -247,7 +248,6 @@ def test_setitem(self, temp_output_dir: str): dummy_indx = SampleIndex(0, 1, {}, "dummy_indx", leafid=0, address="0.0") indx["0"]["0.0"] = dummy_indx - def test_index_file_writing(self, temp_output_dir: str): """ Test the functionality of writing multiple index files. @@ -272,12 +272,13 @@ def test_directory_writing_small(self, temp_output_dir: str): # Create the directory and ensure it has the correct format working_dir = self.get_working_dir("test_directory_writing_small/") indx = create_hierarchy(2, 1, [1], root=working_dir) - expected = ": DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2\n" \ - " 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1\n" \ - " 0.0: BUNDLE 0 MIN 0 MAX 1\n" \ - " 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1\n" \ - " 1.0: BUNDLE 1 MIN 1 MAX 2\n" \ - + expected = ( + ": DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2\n" + " 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1\n" + " 0.0: BUNDLE 0 MIN 0 MAX 1\n" + " 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1\n" + " 1.0: BUNDLE 1 MIN 1 MAX 2\n" + ) assert expected == str(indx) # Write the directories and ensure the paths are actually written @@ -338,18 +339,19 @@ def test_start_sample_id(self, temp_output_dir: str): temporary output path for our tests """ working_dir = self.get_working_dir("test_start_sample_id") - expected = ": DIRECTORY MIN 203 MAX 303 NUM_BUNDLES 10\n" \ - " 0: BUNDLE 0 MIN 203 MAX 213\n" \ - " 1: BUNDLE 1 MIN 213 MAX 223\n" \ - " 2: BUNDLE 2 MIN 223 MAX 233\n" \ - " 3: BUNDLE 3 MIN 233 MAX 243\n" \ - " 4: BUNDLE 4 MIN 243 MAX 253\n" \ - " 5: BUNDLE 5 MIN 253 MAX 263\n" \ - " 6: BUNDLE 6 MIN 263 MAX 273\n" \ - " 7: BUNDLE 7 MIN 273 MAX 283\n" \ - " 8: BUNDLE 8 MIN 283 MAX 293\n" \ - " 9: BUNDLE 9 MIN 293 MAX 303\n" \ - + expected = ( + ": DIRECTORY MIN 203 MAX 303 NUM_BUNDLES 10\n" + " 0: BUNDLE 0 MIN 203 MAX 213\n" + " 1: BUNDLE 1 MIN 213 MAX 223\n" + " 2: BUNDLE 2 MIN 223 MAX 233\n" + " 3: BUNDLE 3 MIN 233 MAX 243\n" + " 4: BUNDLE 4 MIN 243 MAX 253\n" + " 5: BUNDLE 5 MIN 253 MAX 263\n" + " 6: BUNDLE 6 MIN 263 MAX 273\n" + " 7: BUNDLE 7 MIN 273 MAX 283\n" + " 8: BUNDLE 8 MIN 283 MAX 293\n" + " 9: BUNDLE 9 MIN 293 MAX 303\n" + ) idx203 = create_hierarchy(100, 10, start_sample_id=203, root=working_dir) self.write_hierarchy_for_debug(idx203) @@ -424,7 +426,7 @@ def test_make_directory_string(self, temp_output_dir: str): f"{working_dir}/0/3/1", f"{working_dir}/0/3/2", f"{working_dir}/0/3/3", - f"{working_dir}/0/3/4" + f"{working_dir}/0/3/4", ] expected_all_dirs = " ".join(expected_all_dirs_list) assert all_dirs == expected_all_dirs @@ -444,12 +446,13 @@ def test_subhierarchy_insertion(self, temp_output_dir: str): top = read_hierarchy(os.path.abspath(working_dir)) # Compare results - expected = ": DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2\n" \ - " 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1\n" \ - " 0.0: BUNDLE -1 MIN 0 MAX 1\n" \ - " 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1\n" \ - " 1.0: BUNDLE -1 MIN 1 MAX 2\n" \ - + expected = ( + ": DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2\n" + " 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1\n" + " 0.0: BUNDLE -1 MIN 0 MAX 1\n" + " 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1\n" + " 1.0: BUNDLE -1 MIN 1 MAX 2\n" + ) assert str(top) == expected # Create and insert the sub hierarchy @@ -457,22 +460,23 @@ def test_subhierarchy_insertion(self, temp_output_dir: str): top["1.0"] = sub_h # Compare results - expected = ": DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2\n" \ - " 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1\n" \ - " 0.0: BUNDLE -1 MIN 0 MAX 1\n" \ - " 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1\n" \ - " 1.0: DIRECTORY MIN 0 MAX 100 NUM_BUNDLES 10\n" \ - " 1.0.0: BUNDLE 0 MIN 0 MAX 10\n" \ - " 1.0.1: BUNDLE 1 MIN 10 MAX 20\n" \ - " 1.0.2: BUNDLE 2 MIN 20 MAX 30\n" \ - " 1.0.3: BUNDLE 3 MIN 30 MAX 40\n" \ - " 1.0.4: BUNDLE 4 MIN 40 MAX 50\n" \ - " 1.0.5: BUNDLE 5 MIN 50 MAX 60\n" \ - " 1.0.6: BUNDLE 6 MIN 60 MAX 70\n" \ - " 1.0.7: BUNDLE 7 MIN 70 MAX 80\n" \ - " 1.0.8: BUNDLE 8 MIN 80 MAX 90\n" \ - " 1.0.9: BUNDLE 9 MIN 90 MAX 100\n" \ - + expected = ( + ": DIRECTORY MIN 0 MAX 2 NUM_BUNDLES 2\n" + " 0: DIRECTORY MIN 0 MAX 1 NUM_BUNDLES 1\n" + " 0.0: BUNDLE -1 MIN 0 MAX 1\n" + " 1: DIRECTORY MIN 1 MAX 2 NUM_BUNDLES 1\n" + " 1.0: DIRECTORY MIN 0 MAX 100 NUM_BUNDLES 10\n" + " 1.0.0: BUNDLE 0 MIN 0 MAX 10\n" + " 1.0.1: BUNDLE 1 MIN 10 MAX 20\n" + " 1.0.2: BUNDLE 2 MIN 20 MAX 30\n" + " 1.0.3: BUNDLE 3 MIN 30 MAX 40\n" + " 1.0.4: BUNDLE 4 MIN 40 MAX 50\n" + " 1.0.5: BUNDLE 5 MIN 50 MAX 60\n" + " 1.0.6: BUNDLE 6 MIN 60 MAX 70\n" + " 1.0.7: BUNDLE 7 MIN 70 MAX 80\n" + " 1.0.8: BUNDLE 8 MIN 80 MAX 90\n" + " 1.0.9: BUNDLE 9 MIN 90 MAX 100\n" + ) assert str(top) == expected def test_sample_index_creation_and_insertion(self, temp_output_dir: str): @@ -510,5 +514,5 @@ def test_sample_index_creation_and_insertion(self, temp_output_dir: str): # Inserting hierarchy at 0.1 try: idx["0.1"] = create_hierarchy(args[0], args[1], args[2], address="0.1") - except KeyError as error: + except KeyError: assert False From 9339b6bd439ba08ff6094eda7ae207e88d69c589 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 12 Dec 2023 09:31:43 -0800 Subject: [PATCH 04/44] add tests for encryption modules --- merlin/common/security/encrypt.py | 9 +- tests/conftest.py | 37 ++++++++ tests/encryption_manager.py | 49 ++++++++++ tests/unit/common/test_encryption.py | 129 +++++++++++++++++++++++++++ 4 files changed, 219 insertions(+), 5 deletions(-) create mode 100644 tests/encryption_manager.py create mode 100644 tests/unit/common/test_encryption.py diff --git a/merlin/common/security/encrypt.py b/merlin/common/security/encrypt.py index 806d42e0c..a9f4a7107 100644 --- a/merlin/common/security/encrypt.py +++ b/merlin/common/security/encrypt.py @@ -52,11 +52,10 @@ def _get_key_path(): except AttributeError: key_filepath = "~/.merlin/encrypt_data_key" - try: - key_filepath = os.path.abspath(os.path.expanduser(key_filepath)) - except KeyError as e: - raise ValueError("Error! No password provided for RabbitMQ") from e - return key_filepath + if key_filepath is None: + raise ValueError("Error! No password provided for RabbitMQ") + + return os.path.abspath(os.path.expanduser(key_filepath)) def _gen_key(key_path): diff --git a/tests/conftest.py b/tests/conftest.py index 88932c5db..ca4237319 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,6 +42,8 @@ from _pytest.tmpdir import TempPathFactory from celery import Celery +from tests.encryption_manager import EncryptionManager + class RedisServerError(Exception): """ @@ -300,3 +302,38 @@ def launch_workers(celery_app: Celery, worker_queue_map: Dict[str, str]): # pyl # Shut down the workers and terminate the processes celery_app.control.broadcast("shutdown", destination=list(worker_queue_map.keys())) shutdown_processes(worker_processes, echo_processes) + + +@pytest.fixture(scope="session") +def encryption_output_dir(temp_output_dir: str) -> str: # pylint: disable=redefined-outer-name + """ + Get a temporary output directory for our encryption tests. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + encryption_dir = f"{temp_output_dir}/encryption_tests" + os.mkdir(encryption_dir) + return encryption_dir + + +@pytest.fixture(scope="session") +def test_encryption_key() -> bytes: + """An encryption key to be used for tests that need it""" + return b"Q3vLp07Ljm60ahfU9HwOOnfgGY91lSrUmqcTiP0v9i0=" + + +@pytest.fixture(scope="class") +def use_fake_encrypt_data_key(encryption_output_dir: str, test_encryption_key: bytes): # pylint: disable=redefined-outer-name + """ + Create a fake encrypt data key to use for these tests. This will save the + current data key so we can set it back to what it was prior to running + the tests. + + :param encryption_output_dir: The path to the temporary output directory we'll be using for this test run + """ + # Use a context manager to ensure cleanup runs even if an error occurs + with EncryptionManager(encryption_output_dir, test_encryption_key) as encrypt_manager: + # Set the fake encryption key + encrypt_manager.set_fake_key() + # Yield control to the tests + yield diff --git a/tests/encryption_manager.py b/tests/encryption_manager.py new file mode 100644 index 000000000..883b1a184 --- /dev/null +++ b/tests/encryption_manager.py @@ -0,0 +1,49 @@ +""" +Module to define functionality for managing encryption settings +while running our test suite. +""" +import os +from types import TracebackType +from typing import Type + +from merlin.config.configfile import CONFIG + + +class EncryptionManager: + """ + A class to handle safe setup and teardown of encryption tests. + """ + + def __init__(self, temp_output_dir: str, test_encryption_key: bytes): + self.temp_output_dir = temp_output_dir + self.key_path = os.path.abspath(os.path.expanduser(f"{self.temp_output_dir}/encrypt_data_key")) + self.test_encryption_key = test_encryption_key + self.orig_results_backend = CONFIG.results_backend + + def __enter__(self): + """This magic method is necessary for allowing this class to be sued as a context manager""" + return self + + def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: TracebackType): + """ + This will always run at the end of a context with statement, even if an error is raised. + It's a safe way to ensure all of our encryption settings at the start of the tests are reset. + """ + self.reset_encryption_settings() + + def set_fake_key(self): + """ + Create a fake encrypt data key to use for tests. This will save the fake encryption key to + our temporary output directory located at: + /tmp/`whoami`/pytest-of-`whoami`/pytest-current/integration_outfiles_current/encryption_tests/ + """ + with open(self.key_path, "w") as key_file: + key_file.write(self.test_encryption_key.decode("utf-8")) + + CONFIG.results_backend.encryption_key = self.key_path + + def reset_encryption_settings(self): + """ + Reset the encryption settings to what they were prior to running our encryption tests. + """ + CONFIG.results_backend = self.orig_results_backend diff --git a/tests/unit/common/test_encryption.py b/tests/unit/common/test_encryption.py new file mode 100644 index 000000000..6daa53817 --- /dev/null +++ b/tests/unit/common/test_encryption.py @@ -0,0 +1,129 @@ +""" +Tests for the `encrypt.py` and `encrypt_backend_traffic.py` files. +""" +import os + +import celery +import pytest + +from merlin.common.security.encrypt import _gen_key, _get_key, _get_key_path, decrypt, encrypt +from merlin.common.security.encrypt_backend_traffic import _decrypt_decode, _encrypt_encode, set_backend_funcs +from merlin.config.configfile import CONFIG + + +class TestEncryption: + """ + This class will house all tests necessary for our encryption modules. + """ + + def test_encrypt(self, use_fake_encrypt_data_key: "fixture"): # noqa: F821 + """ + Test that our encryption function is encrypting the bytes that we're + passing to it. + + :param use_fake_encrypt_data_key: A fixture to set up a fake encryption key for testing + """ + str_to_encrypt = b"super secret string shhh" + encrypted_str = encrypt(str_to_encrypt) + for word in str_to_encrypt.decode("utf-8").split(" "): + assert word not in encrypted_str.decode("utf-8") + + def test_decrypt(self, use_fake_encrypt_data_key: "fixture"): # noqa: F821 + """ + Test that our decryption function is decrypting the bytes that we're + passing to it. + + :param use_fake_encrypt_data_key: A fixture to set up a fake encryption key for testing + """ + # This is the output of the bytes from the encrypt test + str_to_decrypt = b"gAAAAABld6k-jEncgCW5AePgrwn-C30dhr7dzGVhqzcqskPqFyA2Hdg3VWmo0qQnLklccaUYzAGlB4PMxyp4T-1gAYlAOf_7sC_bJOEcYOIkhZFoH6cX4Uw=" + decrypted_str = decrypt(str_to_decrypt) + assert decrypted_str == b"super secret string shhh" + + def test_get_key_path(self, use_fake_encrypt_data_key: "fixture"): # noqa F821 + """ + Test the `_get_key_path` function. + + :param use_fake_encrypt_data_key: A fixture to set up a fake encryption key for testing + """ + # Test the default behavior (`_get_key_path` will pull from CONFIG.results_backend which + # will be set to the temporary output path for our tests in the `use_fake_encrypt_data_key` fixture) + user = os.getlogin() + actual_default = _get_key_path() + assert actual_default.startswith(f"/tmp/{user}/") and actual_default.endswith("/encryption_tests/encrypt_data_key") + + # Test with having the encryption key set to None + temp = CONFIG.results_backend.encryption_key + CONFIG.results_backend.encryption_key = None + with pytest.raises(ValueError) as excinfo: + _get_key_path() + assert "Error! No password provided for RabbitMQ" in str(excinfo.value) + CONFIG.results_backend.encryption_key = temp + + # Test with having the entire results_backend wiped from CONFIG + orig_results_backend = CONFIG.results_backend + CONFIG.results_backend = None + actual_no_results_backend = _get_key_path() + assert actual_no_results_backend == os.path.abspath(os.path.expanduser("~/.merlin/encrypt_data_key")) + CONFIG.results_backend = orig_results_backend + + def test_gen_key(self, encryption_output_dir: str): + """ + Test the `_gen_key` function. + + :param encryption_output_dir: A fixture to create a temporary output directory for our encryption tests + """ + # Create the file but don't put anything in it + key_gen_test_file = f"{encryption_output_dir}/key_gen_test" + with open(key_gen_test_file, "w"): + pass + + # Ensure nothing is in the file + with open(key_gen_test_file, "r") as key_gen_file: + key_gen_contents = key_gen_file.read() + assert key_gen_contents == "" + + # Run the test and then check to make sure the file is now populated + _gen_key(key_gen_test_file) + with open(key_gen_test_file, "r") as key_gen_file: + key_gen_contents = key_gen_file.read() + assert key_gen_contents != "" + + def test_get_key(self, use_fake_encrypt_data_key: str, encryption_output_dir: str, test_encryption_key: bytes): + """ + Test the `_get_key` function. + + :param use_fake_encrypt_data_key: A fixture to set up a fake encryption key for testing + :param encryption_output_dir: A fixture to create a temporary output directory for our encryption tests + :param test_encryption_key: A fixture to establish a fixed encryption key for testing + """ + # Test the default functionality + actual_default = _get_key() + assert actual_default == test_encryption_key + + # Modify the permission of the key file so that it can't be read by anyone + # (we're purposefully trying to raise an IOError) + key_path = f"{encryption_output_dir}/encrypt_data_key" + orig_file_permissions = os.stat(key_path).st_mode + os.chmod(key_path, 0o222) + with pytest.raises(IOError): + _get_key() + os.chmod(key_path, orig_file_permissions) + + # Reset the key value to our test value since the IOError test will rewrite the key + with open(key_path, "w") as key_file: + key_file.write(test_encryption_key.decode("utf-8")) + + def test_set_backend_funcs(self): + """ + Test the `set_backend_funcs` function. + """ + # Make sure these values haven't been set yet + assert celery.backends.base.Backend.encode != _encrypt_encode + assert celery.backends.base.Backend.decode != _decrypt_decode + + set_backend_funcs() + + # Ensure the new functions have been set + assert celery.backends.base.Backend.encode == _encrypt_encode + assert celery.backends.base.Backend.decode == _decrypt_decode From 54a31bce2bf5a5bad5dd9ff25b7bfd32476e0aa8 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 12 Dec 2023 10:52:34 -0800 Subject: [PATCH 05/44] add unit tests for util_sampling --- merlin/common/util_sampling.py | 1 + tests/unit/common/test_util_sampling.py | 44 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 tests/unit/common/test_util_sampling.py diff --git a/merlin/common/util_sampling.py b/merlin/common/util_sampling.py index 134d0b66c..1309448ef 100644 --- a/merlin/common/util_sampling.py +++ b/merlin/common/util_sampling.py @@ -35,6 +35,7 @@ import numpy as np +# TODO should we move this to merlin-spellbook? def scale_samples(samples_norm, limits, limits_norm=(0, 1), do_log=False): """Scale samples to new limits, either log10 or linearly. diff --git a/tests/unit/common/test_util_sampling.py b/tests/unit/common/test_util_sampling.py new file mode 100644 index 000000000..0fd77739f --- /dev/null +++ b/tests/unit/common/test_util_sampling.py @@ -0,0 +1,44 @@ +""" +Tests for the `util_sampling.py` file. +""" +import numpy as np +import pytest + +from merlin.common.util_sampling import scale_samples + + +class TestUtilSampling: + """ + This class will hold all of the tests for the `util_sampling.py` file. + """ + + def test_scale_samples_basic(self): + """Test basic functionality without logging""" + samples_norm = np.array([[0.2, 0.4], [0.6, 0.8]]) + limits = [(-1, 1), (2, 6)] + result = scale_samples(samples_norm, limits) + expected_result = np.array([[-0.6, 3.6], [0.2, 5.2]]) + np.testing.assert_array_almost_equal(result, expected_result) + + def test_scale_samples_logarithmic(self): + """Test functionality with log enabled""" + samples_norm = np.array([[0.2, 0.4], [0.6, 0.8]]) + limits = [(1, 5), (1, 100)] + result = scale_samples(samples_norm, limits, do_log=[False, True]) + expected_result = np.array([[1.8, 6.309573], [3.4, 39.810717]]) + np.testing.assert_array_almost_equal(result, expected_result) + + def test_scale_samples_invalid_input(self): + """Test that function raises ValueError for invalid input""" + with pytest.raises(ValueError): + # Invalid input: samples_norm should be a 2D array + scale_samples([0.2, 0.4, 0.6], [(1, 5), (2, 6)]) + + def test_scale_samples_with_custom_limits_norm(self): + """Test functionality with custom limits_norm""" + samples_norm = np.array([[0.2, 0.4], [0.6, 0.8]]) + limits = [(1, 5), (2, 6)] + limits_norm = (-1, 1) + result = scale_samples(samples_norm, limits, limits_norm=limits_norm) + expected_result = np.array([[3.4, 4.8], [4.2, 5.6]]) + np.testing.assert_array_almost_equal(result, expected_result) \ No newline at end of file From be02611b4a436cc1973e8b83d52c30451254453d Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 12 Dec 2023 10:54:33 -0800 Subject: [PATCH 06/44] run fix-style and fix typo --- tests/unit/common/test_util_sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/common/test_util_sampling.py b/tests/unit/common/test_util_sampling.py index 0fd77739f..c957ac105 100644 --- a/tests/unit/common/test_util_sampling.py +++ b/tests/unit/common/test_util_sampling.py @@ -13,7 +13,7 @@ class TestUtilSampling: """ def test_scale_samples_basic(self): - """Test basic functionality without logging""" + """Test basic functionality""" samples_norm = np.array([[0.2, 0.4], [0.6, 0.8]]) limits = [(-1, 1), (2, 6)] result = scale_samples(samples_norm, limits) @@ -41,4 +41,4 @@ def test_scale_samples_with_custom_limits_norm(self): limits_norm = (-1, 1) result = scale_samples(samples_norm, limits, limits_norm=limits_norm) expected_result = np.array([[3.4, 4.8], [4.2, 5.6]]) - np.testing.assert_array_almost_equal(result, expected_result) \ No newline at end of file + np.testing.assert_array_almost_equal(result, expected_result) From 63d22f063a755a5ac99095a160550bc32c51008c Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 12 Dec 2023 11:52:48 -0800 Subject: [PATCH 07/44] create directory for context managers and fix issue with an encryption test --- tests/conftest.py | 7 +++---- tests/context_managers/__init__.py | 0 .../celery_workers_manager.py} | 5 +++-- tests/{ => context_managers}/encryption_manager.py | 0 tests/unit/common/test_encryption.py | 6 ++++++ 5 files changed, 12 insertions(+), 6 deletions(-) create mode 100644 tests/context_managers/__init__.py rename tests/{celery_test_workers.py => context_managers/celery_workers_manager.py} (98%) rename tests/{ => context_managers}/encryption_manager.py (100%) diff --git a/tests/conftest.py b/tests/conftest.py index 5aec494e8..33fea4ce1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,9 +42,8 @@ from celery import Celery from celery.canvas import Signature -from tests.celery_test_workers import CeleryTestWorkersManager - -from tests.encryption_manager import EncryptionManager +from tests.context_managers.celery_workers_manager import CeleryWorkersManager +from tests.context_managers.encryption_manager import EncryptionManager class RedisServerError(Exception): @@ -206,7 +205,7 @@ def launch_workers(celery_app: Celery, worker_queue_map: Dict[str, str]): # pyl # (basically just add in concurrency value to worker_queue_map) worker_info = {worker_name: {"concurrency": 1, "queues": [queue]} for worker_name, queue in worker_queue_map.items()} - with CeleryTestWorkersManager(celery_app) as workers_manager: + with CeleryWorkersManager(celery_app) as workers_manager: workers_manager.launch_workers(worker_info) yield diff --git a/tests/context_managers/__init__.py b/tests/context_managers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/celery_test_workers.py b/tests/context_managers/celery_workers_manager.py similarity index 98% rename from tests/celery_test_workers.py rename to tests/context_managers/celery_workers_manager.py index 39eb2a39b..70a01c8a0 100644 --- a/tests/celery_test_workers.py +++ b/tests/context_managers/celery_workers_manager.py @@ -40,9 +40,9 @@ from typing import Dict, List, Type from celery import Celery +from merlin.config.configfile import CONFIG - -class CeleryTestWorkersManager: +class CeleryWorkersManager: """ A class to handle the setup and teardown of celery workers. This should be treated as a context and used with python's @@ -198,6 +198,7 @@ def launch_workers(self, worker_info: Dict[str, Dict]): :param worker_info: A dict of worker info with the form {"worker_name": {"concurrency": , "queues": }} """ + # CONFIG.results_backend.encryption_key = "~/.merlin/encrypt_data_key" for worker_name, worker_settings in worker_info.items(): self.launch_worker(worker_name, worker_settings["queues"], worker_settings["concurrency"]) diff --git a/tests/encryption_manager.py b/tests/context_managers/encryption_manager.py similarity index 100% rename from tests/encryption_manager.py rename to tests/context_managers/encryption_manager.py diff --git a/tests/unit/common/test_encryption.py b/tests/unit/common/test_encryption.py index 6daa53817..6f978ddfe 100644 --- a/tests/unit/common/test_encryption.py +++ b/tests/unit/common/test_encryption.py @@ -118,6 +118,9 @@ def test_set_backend_funcs(self): """ Test the `set_backend_funcs` function. """ + orig_encode = celery.backends.base.Backend.encode + orig_decode = celery.backends.base.Backend.decode + # Make sure these values haven't been set yet assert celery.backends.base.Backend.encode != _encrypt_encode assert celery.backends.base.Backend.decode != _decrypt_decode @@ -127,3 +130,6 @@ def test_set_backend_funcs(self): # Ensure the new functions have been set assert celery.backends.base.Backend.encode == _encrypt_encode assert celery.backends.base.Backend.decode == _decrypt_decode + + celery.backends.base.Backend.encode = orig_encode + celery.backends.base.Backend.decode = orig_decode From b2a997628066665b72166722b5b4ce7af40b8f0f Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 12 Dec 2023 17:31:41 -0800 Subject: [PATCH 08/44] add a context manager for spinning up/down the redis server --- tests/conftest.py | 85 ++------------ .../celery_workers_manager.py | 3 +- tests/context_managers/encryption_manager.py | 2 +- tests/context_managers/server_manager.py | 105 ++++++++++++++++++ 4 files changed, 117 insertions(+), 78 deletions(-) create mode 100644 tests/context_managers/server_manager.py diff --git a/tests/conftest.py b/tests/conftest.py index 33fea4ce1..037d5868f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,30 +32,17 @@ integration test suite. """ import os -import subprocess from time import sleep from typing import Dict import pytest -import redis from _pytest.tmpdir import TempPathFactory from celery import Celery from celery.canvas import Signature from tests.context_managers.celery_workers_manager import CeleryWorkersManager from tests.context_managers.encryption_manager import EncryptionManager - - -class RedisServerError(Exception): - """ - Exception to signal that the server wasn't pinged properly. - """ - - -class ServerInitError(Exception): - """ - Exception to signal that there was an error initializing the server. - """ +from tests.context_managers.server_manager import RedisServerManager @pytest.fixture(scope="session") @@ -80,73 +67,20 @@ def temp_output_dir(tmp_path_factory: TempPathFactory) -> str: @pytest.fixture(scope="session") -def redis_pass() -> str: - """ - This fixture represents the password to the merlin test server. - - :returns: The redis password for our test server - """ - return "merlin-test-server" - - -@pytest.fixture(scope="session") -def merlin_server_dir(temp_output_dir: str, redis_pass: str) -> str: # pylint: disable=redefined-outer-name - """ - This fixture will initialize the merlin server (i.e. create all the files we'll - need to start up a local redis server). It will return the path to the directory - containing the files needed for the server to start up. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - :param redis_pass: The password to the test redis server that we'll create here - :returns: The path to the merlin_server directory with the server configurations - """ - # Initialize the setup for the local redis server - # We'll also set the password to 'merlin-test-server' so it'll be easy to shutdown if there's an issue - subprocess.run(f"merlin server init; merlin server config -pwd {redis_pass}", shell=True, capture_output=True, text=True) - - # Check that the merlin server was initialized properly - server_dir = f"{temp_output_dir}/merlin_server" - if not os.path.exists(server_dir): - raise ServerInitError("The merlin server was not initialized properly.") - - return server_dir - - -@pytest.fixture(scope="session") -def redis_server(merlin_server_dir: str, redis_pass: str) -> str: # pylint: disable=redefined-outer-name,unused-argument +def redis_server(temp_output_dir: str) -> str: # pylint: disable=redefined-outer-name """ Start a redis server instance that runs on localhost:6379. This will yield the redis server uri that can be used to create a connection with celery. - :param merlin_server_dir: The directory to the merlin test server configuration. - This will not be used here but we need the server configurations before we can - start the server. - :param redis_pass: The raw redis password stored in the redis.pass file + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run :yields: The local redis server uri """ - # Start the local redis server - try: - # Need to set LC_ALL='C' before starting the server or else redis causes a failure - subprocess.run("export LC_ALL='C'; merlin server start", shell=True, timeout=5) - except subprocess.TimeoutExpired: - pass - - # Ensure the server started properly - host = "localhost" - port = 6379 - database = 0 - username = "default" - redis_client = redis.Redis(host=host, port=port, db=database, password=redis_pass, username=username) - if not redis_client.ping(): - raise RedisServerError("The redis server could not be pinged. Check that the server is running with 'ps ux'.") - - # Hand over the redis server url to any other fixtures/tests that need it - redis_server_uri = f"redis://{username}:{redis_pass}@{host}:{port}/{database}" - yield redis_server_uri - - # Kill the server; don't run this until all tests are done (accomplished with 'yield' above) - kill_process = subprocess.run("merlin server stop", shell=True, capture_output=True, text=True) - assert "Merlin server terminated." in kill_process.stderr + with RedisServerManager(temp_output_dir) as redis_server_manager: + redis_server_manager.initialize_server() + redis_server_manager.start_server() + # Yield the redis_server uri to any fixtures/tests that may need it + yield redis_server_manager.redis_server_uri + # The server will be stopped once this context reaches the end of it's execution here @pytest.fixture(scope="session") @@ -242,3 +176,4 @@ def use_fake_encrypt_data_key(encryption_output_dir: str, test_encryption_key: b # Set the fake encryption key encrypt_manager.set_fake_key() # Yield control to the tests + yield diff --git a/tests/context_managers/celery_workers_manager.py b/tests/context_managers/celery_workers_manager.py index 70a01c8a0..38526bc1b 100644 --- a/tests/context_managers/celery_workers_manager.py +++ b/tests/context_managers/celery_workers_manager.py @@ -40,7 +40,7 @@ from typing import Dict, List, Type from celery import Celery -from merlin.config.configfile import CONFIG + class CeleryWorkersManager: """ @@ -198,7 +198,6 @@ def launch_workers(self, worker_info: Dict[str, Dict]): :param worker_info: A dict of worker info with the form {"worker_name": {"concurrency": , "queues": }} """ - # CONFIG.results_backend.encryption_key = "~/.merlin/encrypt_data_key" for worker_name, worker_settings in worker_info.items(): self.launch_worker(worker_name, worker_settings["queues"], worker_settings["concurrency"]) diff --git a/tests/context_managers/encryption_manager.py b/tests/context_managers/encryption_manager.py index 883b1a184..84b2e4a1e 100644 --- a/tests/context_managers/encryption_manager.py +++ b/tests/context_managers/encryption_manager.py @@ -21,7 +21,7 @@ def __init__(self, temp_output_dir: str, test_encryption_key: bytes): self.orig_results_backend = CONFIG.results_backend def __enter__(self): - """This magic method is necessary for allowing this class to be sued as a context manager""" + """This magic method is necessary for allowing this class to be used as a context manager""" return self def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: TracebackType): diff --git a/tests/context_managers/server_manager.py b/tests/context_managers/server_manager.py new file mode 100644 index 000000000..d373c1f1c --- /dev/null +++ b/tests/context_managers/server_manager.py @@ -0,0 +1,105 @@ +""" +Module to define functionality for managing the containerized +server used for testing. +""" +import os +import signal +import subprocess +from types import TracebackType +from typing import Type + +import redis +import yaml + + +class RedisServerError(Exception): + """ + Exception to signal that the server wasn't pinged properly. + """ + + +class ServerInitError(Exception): + """ + Exception to signal that there was an error initializing the server. + """ + + +class RedisServerManager: + """ + A class to handle the setup and teardown of a containerized redis server. + This should be treated as a context and used with python's built-in 'with' + statement. If you use it without this statement, beware that the processes + spun up here may never be stopped. + """ + + def __init__(self, temp_output_dir: str): + self._redis_pass = "merlin-test-server" + self.server_dir = f"{temp_output_dir}/merlin_server" + self.host = "localhost" + self.port = 6379 + self.database = 0 + self.username = "default" + self.redis_server_uri = f"redis://{self.username}:{self._redis_pass}@{self.host}:{self.port}/{self.database}" + + def __enter__(self): + """This magic method is necessary for allowing this class to be used as a context manager""" + return self + + def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: TracebackType): + """ + This will always run at the end of a context with statement, even if an error is raised. + It's a safe way to ensure all of our server gets stopped no matter what. + """ + self.stop_server() + + def initialize_server(self): + """ + Initialize the setup for the local redis server. We'll write the folder to: + /tmp/`whoami`/pytest-of-`whoami`/pytest-current/integration_outfiles_current/ + We'll set the password to be 'merlin-test-server' so it'll be easy to shutdown if necessary + """ + subprocess.run( + f"merlin server init; merlin server config -pwd {self._redis_pass}", shell=True, capture_output=True, text=True + ) + + # Check that the merlin server was initialized properly + if not os.path.exists(self.server_dir): + raise ServerInitError("The merlin server was not initialized properly.") + + def start_server(self): + """Attempt to start the local redis server.""" + try: + # Need to set LC_ALL='C' before starting the server or else redis causes a failure + subprocess.run("export LC_ALL='C'; merlin server start", shell=True, timeout=5) + except subprocess.TimeoutExpired: + pass + + # Ensure the server started properly + redis_client = redis.Redis( + host=self.host, port=self.port, db=self.database, password=self._redis_pass, username=self.username + ) + if not redis_client.ping(): + raise RedisServerError("The redis server could not be pinged. Check that the server is running with 'ps ux'.") + + def stop_server(self): + """Stop the server.""" + # Attempt to stop the server gracefully with `merlin server` + kill_process = subprocess.run("merlin server stop", shell=True, capture_output=True, text=True) + + # Check that the server was terminated + if "Merlin server terminated." not in kill_process.stderr: + # If it wasn't, try to kill the process by using the pid stored in a file created by `merlin server` + try: + with open(f"{self.server_dir}/merlin_server.pf", "r") as process_file: + server_process_info = yaml.load(process_file, yaml.Loader) + os.kill(int(server_process_info["image_pid"]), signal.SIGKILL) + # If the file can't be found then let's make sure there's even a redis-server process running + except FileNotFoundError as exc: + process_query = subprocess.run("ps ux", shell=True, text=True, capture_output=True) + # If there is a file running we didn't start it in this test run so we can't kill it + if "redis-server" in process_query.stdout: + raise RedisServerError( + "Found an active redis server but cannot stop it since there is no process file (merlin_server.pf). " + "Did you start this server before running tests?" + ) from exc + # No else here. If there's no redis-server process found then there's nothing to stop From 35c614a4b8efc63e944db4ce54eeb257af22cf52 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Wed, 13 Dec 2023 14:42:41 -0800 Subject: [PATCH 09/44] fix issue with path in one test --- tests/unit/common/test_sample_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/common/test_sample_index.py b/tests/unit/common/test_sample_index.py index 296783273..cdb5b2f4f 100644 --- a/tests/unit/common/test_sample_index.py +++ b/tests/unit/common/test_sample_index.py @@ -64,7 +64,7 @@ def test_new_dir(temp_output_dir: str): temporary output path for our tests """ # Test basic functionality - test_path = f"{os.getcwd()}/test_new_dir" + test_path = f"{os.getcwd()}/test_sample_index/test_new_dir" new_dir(test_path) assert os.path.exists(test_path) From 638a27e47ef1d7ace2dc2045d3564465d450a6d1 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Wed, 13 Dec 2023 14:44:12 -0800 Subject: [PATCH 10/44] rework CONFIG functionality for testing --- merlin/config/__init__.py | 30 +++++++ tests/conftest.py | 91 +++++++++++++------- tests/context_managers/encryption_manager.py | 49 ----------- tests/context_managers/server_manager.py | 29 ++++++- tests/unit/common/test_encryption.py | 31 ++++--- 5 files changed, 133 insertions(+), 97 deletions(-) delete mode 100644 tests/context_managers/encryption_manager.py diff --git a/merlin/config/__init__.py b/merlin/config/__init__.py index b58e3b2a9..c2dd4d12b 100644 --- a/merlin/config/__init__.py +++ b/merlin/config/__init__.py @@ -31,6 +31,7 @@ """ Used to store the application configuration. """ +from copy import copy from types import SimpleNamespace from typing import Dict, List, Optional @@ -56,6 +57,35 @@ def __init__(self, app_dict): self.results_backend: Optional[SimpleNamespace] self.load_app_into_namespaces(app_dict) + def __copy__(self): + """ + A magic method to allow this class to be copied with copy(instance_of_Config). + """ + cls = self.__class__ + result = cls.__new__(cls) + copied_attrs = { + "celery": copy(self.__dict__["celery"]), + "broker": copy(self.__dict__["broker"]), + "results_backend": copy(self.__dict__["results_backend"]), + } + result.__dict__.update(copied_attrs) + return result + + def __str__(self): + """ + A magic method so we can print the CONFIG class. + """ + formatted_str = "config:" + attrs = {"celery": self.celery, "broker": self.broker, "results_backend": self.results_backend} + for name, attr in attrs.items(): + if attr is not None: + items = (f" {k}: {v!r}" for k, v in attr.__dict__.items()) + joined_items = "\n".join(items) + formatted_str += f"\n {name}: \n{joined_items}" + else: + formatted_str += f"\n {name}: \n None" + return formatted_str + def load_app_into_namespaces(self, app_dict: Dict) -> None: """ Makes the application dictionary into a namespace, sets the attributes of the Config from the namespace values. diff --git a/tests/conftest.py b/tests/conftest.py index 037d5868f..992b5203b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,22 +28,25 @@ # SOFTWARE. ############################################################################### """ -This module contains pytest fixtures to be used throughout the entire -integration test suite. +This module contains pytest fixtures to be used throughout the entire test suite. """ import os +import yaml +from copy import copy from time import sleep -from typing import Dict +from typing import Any, Dict import pytest from _pytest.tmpdir import TempPathFactory from celery import Celery from celery.canvas import Signature +from merlin.config.configfile import CONFIG from tests.context_managers.celery_workers_manager import CeleryWorkersManager -from tests.context_managers.encryption_manager import EncryptionManager from tests.context_managers.server_manager import RedisServerManager +REDIS_PASS = "merlin-test-server" + @pytest.fixture(scope="session") def temp_output_dir(tmp_path_factory: TempPathFactory) -> str: @@ -67,15 +70,27 @@ def temp_output_dir(tmp_path_factory: TempPathFactory) -> str: @pytest.fixture(scope="session") -def redis_server(temp_output_dir: str) -> str: # pylint: disable=redefined-outer-name +def merlin_server_dir(temp_output_dir: str) -> str: + """ + The path to the merlin_server directory that will be created by the `redis_server` fixture. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :returns: The path to the merlin_server directory that will be created by the `redis_server` fixture + """ + return f"{temp_output_dir}/merlin_server" + + +@pytest.fixture(scope="session") +def redis_server(merlin_server_dir: str, test_encryption_key: bytes) -> str: # pylint: disable=redefined-outer-name """ Start a redis server instance that runs on localhost:6379. This will yield the redis server uri that can be used to create a connection with celery. - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :param merlin_server_dir: The directory to the merlin test server configuration + :param test_encryption_key: An encryption key to be used for testing :yields: The local redis server uri """ - with RedisServerManager(temp_output_dir) as redis_server_manager: + with RedisServerManager(merlin_server_dir, REDIS_PASS, test_encryption_key) as redis_server_manager: redis_server_manager.initialize_server() redis_server_manager.start_server() # Yield the redis_server uri to any fixtures/tests that may need it @@ -145,35 +160,53 @@ def launch_workers(celery_app: Celery, worker_queue_map: Dict[str, str]): # pyl @pytest.fixture(scope="session") -def encryption_output_dir(temp_output_dir: str) -> str: # pylint: disable=redefined-outer-name +def test_encryption_key() -> bytes: """ - Get a temporary output directory for our encryption tests. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + An encryption key to be used for tests that need it. + + :returns: The test encryption key """ - encryption_dir = f"{temp_output_dir}/encryption_tests" - os.mkdir(encryption_dir) - return encryption_dir + return b"Q3vLp07Ljm60ahfU9HwOOnfgGY91lSrUmqcTiP0v9i0=" @pytest.fixture(scope="session") -def test_encryption_key() -> bytes: - """An encryption key to be used for tests that need it""" - return b"Q3vLp07Ljm60ahfU9HwOOnfgGY91lSrUmqcTiP0v9i0=" +def app_yaml(merlin_server_dir: str, redis_server: str) -> Dict[str, Any]: # pylint: disable=redefined-outer-name + """ + Load in the app.yaml file generated by starting the redis server. + :param merlin_server_dir: The directory to the merlin test server configuration + :param redis_server: The fixture that starts up the redis server + :returns: The contents of the app.yaml file created by starting the redis server + """ + with open(f"{merlin_server_dir}/app.yaml", "r") as app_yaml_file: + app_yaml = yaml.load(app_yaml_file, yaml.Loader) + return app_yaml -@pytest.fixture(scope="class") -def use_fake_encrypt_data_key(encryption_output_dir: str, test_encryption_key: bytes): # pylint: disable=redefined-outer-name + +@pytest.fixture(scope="function") +def config(app_yaml: str): # pylint: disable=redefined-outer-name """ - Create a fake encrypt data key to use for these tests. This will save the - current data key so we can set it back to what it was prior to running - the tests. + This fixture is intended to be used for testing any functionality in the codebase + that uses the CONFIG object. This will modify the CONFIG object to use static test values + that shouldn't change. - :param encryption_output_dir: The path to the temporary output directory we'll be using for this test run + :param app_yaml: The contents of the app.yaml created by starting the containerized redis server """ - # Use a context manager to ensure cleanup runs even if an error occurs - with EncryptionManager(encryption_output_dir, test_encryption_key) as encrypt_manager: - # Set the fake encryption key - encrypt_manager.set_fake_key() - # Yield control to the tests - yield + global CONFIG + orig_config = copy(CONFIG) + + CONFIG.broker.password = app_yaml["broker"]["password"] + CONFIG.broker.port = app_yaml["broker"]["port"] + CONFIG.broker.server = app_yaml["broker"]["server"] + CONFIG.broker.username = app_yaml["broker"]["username"] + CONFIG.broker.vhost = app_yaml["broker"]["vhost"] + + CONFIG.results_backend.password = app_yaml["results_backend"]["password"] + CONFIG.results_backend.port = app_yaml["results_backend"]["port"] + CONFIG.results_backend.server = app_yaml["results_backend"]["server"] + CONFIG.results_backend.username = app_yaml["results_backend"]["username"] + CONFIG.results_backend.encryption_key = app_yaml["results_backend"]["encryption_key"] + + yield + + CONFIG = orig_config diff --git a/tests/context_managers/encryption_manager.py b/tests/context_managers/encryption_manager.py deleted file mode 100644 index 84b2e4a1e..000000000 --- a/tests/context_managers/encryption_manager.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -Module to define functionality for managing encryption settings -while running our test suite. -""" -import os -from types import TracebackType -from typing import Type - -from merlin.config.configfile import CONFIG - - -class EncryptionManager: - """ - A class to handle safe setup and teardown of encryption tests. - """ - - def __init__(self, temp_output_dir: str, test_encryption_key: bytes): - self.temp_output_dir = temp_output_dir - self.key_path = os.path.abspath(os.path.expanduser(f"{self.temp_output_dir}/encrypt_data_key")) - self.test_encryption_key = test_encryption_key - self.orig_results_backend = CONFIG.results_backend - - def __enter__(self): - """This magic method is necessary for allowing this class to be used as a context manager""" - return self - - def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: TracebackType): - """ - This will always run at the end of a context with statement, even if an error is raised. - It's a safe way to ensure all of our encryption settings at the start of the tests are reset. - """ - self.reset_encryption_settings() - - def set_fake_key(self): - """ - Create a fake encrypt data key to use for tests. This will save the fake encryption key to - our temporary output directory located at: - /tmp/`whoami`/pytest-of-`whoami`/pytest-current/integration_outfiles_current/encryption_tests/ - """ - with open(self.key_path, "w") as key_file: - key_file.write(self.test_encryption_key.decode("utf-8")) - - CONFIG.results_backend.encryption_key = self.key_path - - def reset_encryption_settings(self): - """ - Reset the encryption settings to what they were prior to running our encryption tests. - """ - CONFIG.results_backend = self.orig_results_backend diff --git a/tests/context_managers/server_manager.py b/tests/context_managers/server_manager.py index d373c1f1c..9a10e0cbf 100644 --- a/tests/context_managers/server_manager.py +++ b/tests/context_managers/server_manager.py @@ -32,9 +32,10 @@ class RedisServerManager: spun up here may never be stopped. """ - def __init__(self, temp_output_dir: str): - self._redis_pass = "merlin-test-server" - self.server_dir = f"{temp_output_dir}/merlin_server" + def __init__(self, server_dir: str, redis_pass: str, test_encryption_key: bytes): + self._redis_pass = redis_pass + self._test_encryption_key = test_encryption_key + self.server_dir = server_dir self.host = "localhost" self.port = 6379 self.database = 0 @@ -66,6 +67,26 @@ def initialize_server(self): if not os.path.exists(self.server_dir): raise ServerInitError("The merlin server was not initialized properly.") + def _create_fake_encryption_key(self): + """ + For testing we'll use a specific encryption key. We'll create a file for that and + save it to the app.yaml created for testing. + """ + # Create a fake encryption key file for testing purposes + encryption_file = f"{self.server_dir}/encrypt_data_key" + with open(encryption_file, "w") as key_file: + key_file.write(self._test_encryption_key.decode("utf-8")) + + # Load up the app.yaml that was created by starting the server + server_app_yaml = f"{self.server_dir}/app.yaml" + with open(server_app_yaml, "r") as app_yaml_file: + app_yaml = yaml.load(app_yaml_file, yaml.Loader) + + # Modify the path to the encryption key and then save it + app_yaml["results_backend"]["encryption_key"] = encryption_file + with open(server_app_yaml, "w") as app_yaml_file: + yaml.dump(app_yaml, app_yaml_file) + def start_server(self): """Attempt to start the local redis server.""" try: @@ -81,6 +102,8 @@ def start_server(self): if not redis_client.ping(): raise RedisServerError("The redis server could not be pinged. Check that the server is running with 'ps ux'.") + self._create_fake_encryption_key() + def stop_server(self): """Stop the server.""" # Attempt to stop the server gracefully with `merlin server` diff --git a/tests/unit/common/test_encryption.py b/tests/unit/common/test_encryption.py index 6f978ddfe..012c5c540 100644 --- a/tests/unit/common/test_encryption.py +++ b/tests/unit/common/test_encryption.py @@ -16,41 +16,40 @@ class TestEncryption: This class will house all tests necessary for our encryption modules. """ - def test_encrypt(self, use_fake_encrypt_data_key: "fixture"): # noqa: F821 + def test_encrypt(self, config: "fixture"): # noqa: F821 """ Test that our encryption function is encrypting the bytes that we're passing to it. - :param use_fake_encrypt_data_key: A fixture to set up a fake encryption key for testing + :param config: A fixture to set the CONFIG object to a test configuration that we'll use here """ str_to_encrypt = b"super secret string shhh" encrypted_str = encrypt(str_to_encrypt) for word in str_to_encrypt.decode("utf-8").split(" "): assert word not in encrypted_str.decode("utf-8") - def test_decrypt(self, use_fake_encrypt_data_key: "fixture"): # noqa: F821 + def test_decrypt(self, config: "fixture"): # noqa: F821 """ - Test that our decryption function is decrypting the bytes that we're - passing to it. + Test that our decryption function is decrypting the bytes that we're passing to it. - :param use_fake_encrypt_data_key: A fixture to set up a fake encryption key for testing + :param config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # This is the output of the bytes from the encrypt test str_to_decrypt = b"gAAAAABld6k-jEncgCW5AePgrwn-C30dhr7dzGVhqzcqskPqFyA2Hdg3VWmo0qQnLklccaUYzAGlB4PMxyp4T-1gAYlAOf_7sC_bJOEcYOIkhZFoH6cX4Uw=" decrypted_str = decrypt(str_to_decrypt) assert decrypted_str == b"super secret string shhh" - def test_get_key_path(self, use_fake_encrypt_data_key: "fixture"): # noqa F821 + def test_get_key_path(self, config: "fixture"): # noqa: F821 """ Test the `_get_key_path` function. - :param use_fake_encrypt_data_key: A fixture to set up a fake encryption key for testing + :param config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Test the default behavior (`_get_key_path` will pull from CONFIG.results_backend which # will be set to the temporary output path for our tests in the `use_fake_encrypt_data_key` fixture) user = os.getlogin() actual_default = _get_key_path() - assert actual_default.startswith(f"/tmp/{user}/") and actual_default.endswith("/encryption_tests/encrypt_data_key") + assert actual_default.startswith(f"/tmp/{user}/") and actual_default.endswith("/encrypt_data_key") # Test with having the encryption key set to None temp = CONFIG.results_backend.encryption_key @@ -67,14 +66,14 @@ def test_get_key_path(self, use_fake_encrypt_data_key: "fixture"): # noqa F821 assert actual_no_results_backend == os.path.abspath(os.path.expanduser("~/.merlin/encrypt_data_key")) CONFIG.results_backend = orig_results_backend - def test_gen_key(self, encryption_output_dir: str): + def test_gen_key(self, temp_output_dir: str): """ Test the `_gen_key` function. - :param encryption_output_dir: A fixture to create a temporary output directory for our encryption tests + :param temp_output_dir: The path to the temporary output directory for this test run """ # Create the file but don't put anything in it - key_gen_test_file = f"{encryption_output_dir}/key_gen_test" + key_gen_test_file = f"{temp_output_dir}/key_gen_test" with open(key_gen_test_file, "w"): pass @@ -89,13 +88,13 @@ def test_gen_key(self, encryption_output_dir: str): key_gen_contents = key_gen_file.read() assert key_gen_contents != "" - def test_get_key(self, use_fake_encrypt_data_key: str, encryption_output_dir: str, test_encryption_key: bytes): + def test_get_key(self, merlin_server_dir: str, test_encryption_key: bytes, config: "fixture"): # noqa: F821 """ Test the `_get_key` function. - :param use_fake_encrypt_data_key: A fixture to set up a fake encryption key for testing - :param encryption_output_dir: A fixture to create a temporary output directory for our encryption tests + :param merlin_server_dir: The directory to the merlin test server configuration :param test_encryption_key: A fixture to establish a fixed encryption key for testing + :param config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Test the default functionality actual_default = _get_key() @@ -103,7 +102,7 @@ def test_get_key(self, use_fake_encrypt_data_key: str, encryption_output_dir: st # Modify the permission of the key file so that it can't be read by anyone # (we're purposefully trying to raise an IOError) - key_path = f"{encryption_output_dir}/encrypt_data_key" + key_path = f"{merlin_server_dir}/encrypt_data_key" orig_file_permissions = os.stat(key_path).st_mode os.chmod(key_path, 0o222) with pytest.raises(IOError): From 9b342ab0325a724d39a101e066740b84dbdcb407 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 14 Dec 2023 11:52:08 -0800 Subject: [PATCH 11/44] refactor config fixture so it doesn't depend on redis server to be started --- tests/conftest.py | 159 ++++++++++++++++++----- tests/context_managers/server_manager.py | 25 +--- tests/unit/common/test_encryption.py | 19 +-- 3 files changed, 135 insertions(+), 68 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 992b5203b..4c970992c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,7 +45,81 @@ from tests.context_managers.celery_workers_manager import CeleryWorkersManager from tests.context_managers.server_manager import RedisServerManager -REDIS_PASS = "merlin-test-server" +SERVER_PASS = "merlin-test-server" + + +####################################### +#### Helper Functions for Fixtures #### +####################################### + + +def create_pass_file(pass_filepath: str): + """ + Check if a password file already exists (it will if the redis server has been started) + and if it hasn't then create one and write the password to the file. + + :param pass_filepath: The path to the password file that we need to check for/create + """ + if not os.path.exists(pass_filepath): + with open(pass_filepath, "w") as pass_file: + pass_file.write(SERVER_PASS) + + +def create_encryption_file(key_filepath: str, encryption_key: bytes, app_yaml_filepath: str = None): + """ + Check if an encryption file already exists (it will if the redis server has been started) + and if it hasn't then create one and write the encryption key to the file. If an app.yaml + filepath has been passed to this function then we'll need to update it so that the encryption + key points to the `key_filepath`. + + :param key_filepath: The path to the file that will store our encryption key + :param encryption_key: An encryption key to be used for testing + :param app_yaml_filepath: A path to the app.yaml file that needs to be updated + """ + if not os.path.exists(key_filepath): + with open(key_filepath, "w") as key_file: + key_file.write(encryption_key.decode("utf-8")) + + if app_yaml_filepath is not None: + # Load up the app.yaml that was created by starting the server + with open(app_yaml_filepath, "r") as app_yaml_file: + app_yaml = yaml.load(app_yaml_file, yaml.Loader) + + # Modify the path to the encryption key and then save it + app_yaml["results_backend"]["encryption_key"] = key_filepath + with open(app_yaml_filepath, "w") as app_yaml_file: + yaml.dump(app_yaml, app_yaml_file) + + +def set_config(broker: Dict[str, str], results_backend: Dict[str, str]): + """ + Given configuration options for the broker and results_backend, update + the CONFIG object. + + :param broker: A dict of the configuration settings for the broker + :param results_backend: A dict of configuration settings for the results_backend + """ + global CONFIG + + # Set the broker configuration for testing + CONFIG.broker.password = broker["password"] + CONFIG.broker.port = broker["port"] + CONFIG.broker.server = broker["server"] + CONFIG.broker.username = broker["username"] + CONFIG.broker.vhost = broker["vhost"] + CONFIG.broker.name = broker["name"] + + # Set the results_backend configuration for testing + CONFIG.results_backend.password = results_backend["password"] + CONFIG.results_backend.port = results_backend["port"] + CONFIG.results_backend.server = results_backend["server"] + CONFIG.results_backend.username = results_backend["username"] + CONFIG.results_backend.encryption_key = results_backend["encryption_key"] + + +####################################### +######### Fixture Definitions ######### +####################################### @pytest.fixture(scope="session") @@ -77,7 +151,10 @@ def merlin_server_dir(temp_output_dir: str) -> str: :param temp_output_dir: The path to the temporary output directory we'll be using for this test run :returns: The path to the merlin_server directory that will be created by the `redis_server` fixture """ - return f"{temp_output_dir}/merlin_server" + server_dir = f"{temp_output_dir}/merlin_server" + if not os.path.exists(server_dir): + os.mkdir(server_dir) + return server_dir @pytest.fixture(scope="session") @@ -90,9 +167,10 @@ def redis_server(merlin_server_dir: str, test_encryption_key: bytes) -> str: # :param test_encryption_key: An encryption key to be used for testing :yields: The local redis server uri """ - with RedisServerManager(merlin_server_dir, REDIS_PASS, test_encryption_key) as redis_server_manager: + with RedisServerManager(merlin_server_dir, SERVER_PASS) as redis_server_manager: redis_server_manager.initialize_server() redis_server_manager.start_server() + create_encryption_file(f"{merlin_server_dir}/encrypt_data_key", test_encryption_key, app_yaml_filepath=f"{merlin_server_dir}/app.yaml") # Yield the redis_server uri to any fixtures/tests that may need it yield redis_server_manager.redis_server_uri # The server will be stopped once this context reaches the end of it's execution here @@ -163,50 +241,61 @@ def launch_workers(celery_app: Celery, worker_queue_map: Dict[str, str]): # pyl def test_encryption_key() -> bytes: """ An encryption key to be used for tests that need it. - + :returns: The test encryption key """ return b"Q3vLp07Ljm60ahfU9HwOOnfgGY91lSrUmqcTiP0v9i0=" -@pytest.fixture(scope="session") -def app_yaml(merlin_server_dir: str, redis_server: str) -> Dict[str, Any]: # pylint: disable=redefined-outer-name - """ - Load in the app.yaml file generated by starting the redis server. - - :param merlin_server_dir: The directory to the merlin test server configuration - :param redis_server: The fixture that starts up the redis server - :returns: The contents of the app.yaml file created by starting the redis server - """ - with open(f"{merlin_server_dir}/app.yaml", "r") as app_yaml_file: - app_yaml = yaml.load(app_yaml_file, yaml.Loader) - return app_yaml - - @pytest.fixture(scope="function") -def config(app_yaml: str): # pylint: disable=redefined-outer-name +def redis_config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disable=redefined-outer-name """ This fixture is intended to be used for testing any functionality in the codebase - that uses the CONFIG object. This will modify the CONFIG object to use static test values - that shouldn't change. + that uses the CONFIG object with a Redis broker and results_backend. - :param app_yaml: The contents of the app.yaml created by starting the containerized redis server + :param merlin_server_dir: The directory to the merlin test server configuration + :param test_encryption_key: An encryption key to be used for testing """ global CONFIG - orig_config = copy(CONFIG) - CONFIG.broker.password = app_yaml["broker"]["password"] - CONFIG.broker.port = app_yaml["broker"]["port"] - CONFIG.broker.server = app_yaml["broker"]["server"] - CONFIG.broker.username = app_yaml["broker"]["username"] - CONFIG.broker.vhost = app_yaml["broker"]["vhost"] - - CONFIG.results_backend.password = app_yaml["results_backend"]["password"] - CONFIG.results_backend.port = app_yaml["results_backend"]["port"] - CONFIG.results_backend.server = app_yaml["results_backend"]["server"] - CONFIG.results_backend.username = app_yaml["results_backend"]["username"] - CONFIG.results_backend.encryption_key = app_yaml["results_backend"]["encryption_key"] + # Create a copy of the CONFIG option so we can reset it after the test + orig_config = copy(CONFIG) + # Create a password file and encryption key file (if they don't already exist) + pass_file = f"{merlin_server_dir}/redis.pass" + key_file = f"{merlin_server_dir}/encrypt_data_key" + create_pass_file(pass_file) + create_encryption_file(key_file, test_encryption_key) + + # Create the broker and results_backend configuration to use + broker = { + "cert_reqs": "none", + "password": pass_file, + "port": 6379, + "server": "127.0.0.1", + "username": "default", + "vhost": "host4testing", + "name": "redis", + } + + results_backend = { + "cert_reqs": "none", + "db_num": 0, + "encryption_key": key_file, + "password": pass_file, + "port": 6379, + "server": "127.0.0.1", + "username": "default", + "name": "redis", + } + + # Set the configuration + set_config(broker, results_backend) + + # Go run the tests yield - CONFIG = orig_config + # Reset the configuration + CONFIG.celery = orig_config.celery + CONFIG.broker = orig_config.broker + CONFIG.results_backend = orig_config.results_backend diff --git a/tests/context_managers/server_manager.py b/tests/context_managers/server_manager.py index 9a10e0cbf..ea6a731ff 100644 --- a/tests/context_managers/server_manager.py +++ b/tests/context_managers/server_manager.py @@ -32,9 +32,8 @@ class RedisServerManager: spun up here may never be stopped. """ - def __init__(self, server_dir: str, redis_pass: str, test_encryption_key: bytes): + def __init__(self, server_dir: str, redis_pass: str): self._redis_pass = redis_pass - self._test_encryption_key = test_encryption_key self.server_dir = server_dir self.host = "localhost" self.port = 6379 @@ -67,26 +66,6 @@ def initialize_server(self): if not os.path.exists(self.server_dir): raise ServerInitError("The merlin server was not initialized properly.") - def _create_fake_encryption_key(self): - """ - For testing we'll use a specific encryption key. We'll create a file for that and - save it to the app.yaml created for testing. - """ - # Create a fake encryption key file for testing purposes - encryption_file = f"{self.server_dir}/encrypt_data_key" - with open(encryption_file, "w") as key_file: - key_file.write(self._test_encryption_key.decode("utf-8")) - - # Load up the app.yaml that was created by starting the server - server_app_yaml = f"{self.server_dir}/app.yaml" - with open(server_app_yaml, "r") as app_yaml_file: - app_yaml = yaml.load(app_yaml_file, yaml.Loader) - - # Modify the path to the encryption key and then save it - app_yaml["results_backend"]["encryption_key"] = encryption_file - with open(server_app_yaml, "w") as app_yaml_file: - yaml.dump(app_yaml, app_yaml_file) - def start_server(self): """Attempt to start the local redis server.""" try: @@ -102,8 +81,6 @@ def start_server(self): if not redis_client.ping(): raise RedisServerError("The redis server could not be pinged. Check that the server is running with 'ps ux'.") - self._create_fake_encryption_key() - def stop_server(self): """Stop the server.""" # Attempt to stop the server gracefully with `merlin server` diff --git a/tests/unit/common/test_encryption.py b/tests/unit/common/test_encryption.py index 012c5c540..d65d201f2 100644 --- a/tests/unit/common/test_encryption.py +++ b/tests/unit/common/test_encryption.py @@ -1,6 +1,7 @@ """ Tests for the `encrypt.py` and `encrypt_backend_traffic.py` files. """ +import getpass import os import celery @@ -16,38 +17,38 @@ class TestEncryption: This class will house all tests necessary for our encryption modules. """ - def test_encrypt(self, config: "fixture"): # noqa: F821 + def test_encrypt(self, redis_config: "fixture"): # noqa: F821 """ Test that our encryption function is encrypting the bytes that we're passing to it. - :param config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ str_to_encrypt = b"super secret string shhh" encrypted_str = encrypt(str_to_encrypt) for word in str_to_encrypt.decode("utf-8").split(" "): assert word not in encrypted_str.decode("utf-8") - def test_decrypt(self, config: "fixture"): # noqa: F821 + def test_decrypt(self, redis_config: "fixture"): # noqa: F821 """ Test that our decryption function is decrypting the bytes that we're passing to it. - :param config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # This is the output of the bytes from the encrypt test str_to_decrypt = b"gAAAAABld6k-jEncgCW5AePgrwn-C30dhr7dzGVhqzcqskPqFyA2Hdg3VWmo0qQnLklccaUYzAGlB4PMxyp4T-1gAYlAOf_7sC_bJOEcYOIkhZFoH6cX4Uw=" decrypted_str = decrypt(str_to_decrypt) assert decrypted_str == b"super secret string shhh" - def test_get_key_path(self, config: "fixture"): # noqa: F821 + def test_get_key_path(self, redis_config: "fixture"): # noqa: F821 """ Test the `_get_key_path` function. - :param config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Test the default behavior (`_get_key_path` will pull from CONFIG.results_backend which # will be set to the temporary output path for our tests in the `use_fake_encrypt_data_key` fixture) - user = os.getlogin() + user = getpass.getuser() actual_default = _get_key_path() assert actual_default.startswith(f"/tmp/{user}/") and actual_default.endswith("/encrypt_data_key") @@ -88,13 +89,13 @@ def test_gen_key(self, temp_output_dir: str): key_gen_contents = key_gen_file.read() assert key_gen_contents != "" - def test_get_key(self, merlin_server_dir: str, test_encryption_key: bytes, config: "fixture"): # noqa: F821 + def test_get_key(self, merlin_server_dir: str, test_encryption_key: bytes, redis_config: "fixture"): # noqa: F821 """ Test the `_get_key` function. :param merlin_server_dir: The directory to the merlin test server configuration :param test_encryption_key: A fixture to establish a fixed encryption key for testing - :param config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Test the default functionality actual_default = _get_key() From 661ab71daa9ec365eb1f11b35389a86b0457395b Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 14 Dec 2023 14:24:52 -0800 Subject: [PATCH 12/44] split CONFIG fixtures into rabbit and redis configs, run fix-style --- merlin/config/__init__.py | 1 - tests/conftest.py | 114 ++++++++++++++++++--------- tests/unit/common/test_encryption.py | 4 +- 3 files changed, 77 insertions(+), 42 deletions(-) diff --git a/merlin/config/__init__.py b/merlin/config/__init__.py index c2dd4d12b..d0a0bf9c5 100644 --- a/merlin/config/__init__.py +++ b/merlin/config/__init__.py @@ -32,7 +32,6 @@ Used to store the application configuration. """ from copy import copy - from types import SimpleNamespace from typing import Dict, List, Optional diff --git a/tests/conftest.py b/tests/conftest.py index 4c970992c..3415385f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,12 +31,12 @@ This module contains pytest fixtures to be used throughout the entire test suite. """ import os -import yaml from copy import copy from time import sleep -from typing import Any, Dict +from typing import Dict import pytest +import yaml from _pytest.tmpdir import TempPathFactory from celery import Celery from celery.canvas import Signature @@ -45,6 +45,7 @@ from tests.context_managers.celery_workers_manager import CeleryWorkersManager from tests.context_managers.server_manager import RedisServerManager + SERVER_PASS = "merlin-test-server" @@ -84,7 +85,7 @@ def create_encryption_file(key_filepath: str, encryption_key: bytes, app_yaml_fi # Load up the app.yaml that was created by starting the server with open(app_yaml_filepath, "r") as app_yaml_file: app_yaml = yaml.load(app_yaml_file, yaml.Loader) - + # Modify the path to the encryption key and then save it app_yaml["results_backend"]["encryption_key"] = key_filepath with open(app_yaml_filepath, "w") as app_yaml_file: @@ -99,8 +100,6 @@ def set_config(broker: Dict[str, str], results_backend: Dict[str, str]): :param broker: A dict of the configuration settings for the broker :param results_backend: A dict of configuration settings for the results_backend """ - global CONFIG - # Set the broker configuration for testing CONFIG.broker.password = broker["password"] CONFIG.broker.port = broker["port"] @@ -144,7 +143,7 @@ def temp_output_dir(tmp_path_factory: TempPathFactory) -> str: @pytest.fixture(scope="session") -def merlin_server_dir(temp_output_dir: str) -> str: +def merlin_server_dir(temp_output_dir: str) -> str: # pylint: disable=redefined-outer-name """ The path to the merlin_server directory that will be created by the `redis_server` fixture. @@ -170,7 +169,9 @@ def redis_server(merlin_server_dir: str, test_encryption_key: bytes) -> str: # with RedisServerManager(merlin_server_dir, SERVER_PASS) as redis_server_manager: redis_server_manager.initialize_server() redis_server_manager.start_server() - create_encryption_file(f"{merlin_server_dir}/encrypt_data_key", test_encryption_key, app_yaml_filepath=f"{merlin_server_dir}/app.yaml") + create_encryption_file( + f"{merlin_server_dir}/encrypt_data_key", test_encryption_key, app_yaml_filepath=f"{merlin_server_dir}/app.yaml" + ) # Yield the redis_server uri to any fixtures/tests that may need it yield redis_server_manager.redis_server_uri # The server will be stopped once this context reaches the end of it's execution here @@ -248,49 +249,42 @@ def test_encryption_key() -> bytes: @pytest.fixture(scope="function") -def redis_config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disable=redefined-outer-name +def config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disable=redefined-outer-name """ - This fixture is intended to be used for testing any functionality in the codebase - that uses the CONFIG object with a Redis broker and results_backend. + DO NOT USE THIS FIXTURE IN A TEST, USE `redis_config` OR `rabbit_config` INSTEAD. + This fixture is intended to be used strictly by the `redis_config` and `rabbit_config` + fixtures. It sets up the CONFIG object but leaves certain broker settings unset. :param merlin_server_dir: The directory to the merlin test server configuration :param test_encryption_key: An encryption key to be used for testing """ - global CONFIG + # global CONFIG # Create a copy of the CONFIG option so we can reset it after the test orig_config = copy(CONFIG) - # Create a password file and encryption key file (if they don't already exist) - pass_file = f"{merlin_server_dir}/redis.pass" + # Create an encryption key file (if it doesn't already exist) key_file = f"{merlin_server_dir}/encrypt_data_key" - create_pass_file(pass_file) create_encryption_file(key_file, test_encryption_key) - # Create the broker and results_backend configuration to use - broker = { - "cert_reqs": "none", - "password": pass_file, - "port": 6379, - "server": "127.0.0.1", - "username": "default", - "vhost": "host4testing", - "name": "redis", - } - - results_backend = { - "cert_reqs": "none", - "db_num": 0, - "encryption_key": key_file, - "password": pass_file, - "port": 6379, - "server": "127.0.0.1", - "username": "default", - "name": "redis", - } - - # Set the configuration - set_config(broker, results_backend) + # Set the broker configuration for testing + CONFIG.broker.password = "password path not yet set" # This will be updated in `redis_config` or `rabbit_config` + CONFIG.broker.port = "port not yet set" # This will be updated in `redis_config` or `rabbit_config` + CONFIG.broker.name = "name not yet set" # This will be updated in `redis_config` or `rabbit_config` + CONFIG.broker.server = "127.0.0.1" + CONFIG.broker.username = "default" + CONFIG.broker.vhost = "host4testing" + CONFIG.broker.cert_reqs = "none" + + # Set the results_backend configuration for testing + CONFIG.results_backend.password = f"{merlin_server_dir}/redis.pass" + CONFIG.results_backend.port = 6379 + CONFIG.results_backend.server = "127.0.0.1" + CONFIG.results_backend.username = "default" + CONFIG.results_backend.cert_reqs = "none" + CONFIG.results_backend.encryption_key = key_file + CONFIG.results_backend.db_num = 0 + CONFIG.results_backend.name = "redis" # Go run the tests yield @@ -299,3 +293,47 @@ def redis_config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: CONFIG.celery = orig_config.celery CONFIG.broker = orig_config.broker CONFIG.results_backend = orig_config.results_backend + + +@pytest.fixture(scope="function") +def redis_config(merlin_server_dir: str, config: "fixture"): # noqa: F821 pylint: disable=redefined-outer-name,unused-argument + """ + This fixture is intended to be used for testing any functionality in the codebase + that uses the CONFIG object with a Redis broker and results_backend. + + :param merlin_server_dir: The directory to the merlin test server configuration + :param config: The fixture that sets up most of the CONFIG object for testing + """ + # global CONFIG + + pass_file = f"{merlin_server_dir}/redis.pass" + create_pass_file(pass_file) + + CONFIG.broker.password = pass_file + CONFIG.broker.port = 6379 + CONFIG.broker.name = "redis" + + yield + + +@pytest.fixture(scope="function") +def rabbit_config( + merlin_server_dir: str, config: "fixture" +): # noqa: F821 pylint: disable=redefined-outer-name,unused-argument + """ + This fixture is intended to be used for testing any functionality in the codebase + that uses the CONFIG object with a RabbitMQ broker and Redis results_backend. + + :param merlin_server_dir: The directory to the merlin test server configuration + :param config: The fixture that sets up most of the CONFIG object for testing + """ + # global CONFIG + + pass_file = f"{merlin_server_dir}/rabbit.pass" + create_pass_file(pass_file) + + CONFIG.broker.password = pass_file + CONFIG.broker.port = 5671 + CONFIG.broker.name = "rabbitmq" + + yield diff --git a/tests/unit/common/test_encryption.py b/tests/unit/common/test_encryption.py index d65d201f2..6392cf8da 100644 --- a/tests/unit/common/test_encryption.py +++ b/tests/unit/common/test_encryption.py @@ -1,7 +1,6 @@ """ Tests for the `encrypt.py` and `encrypt_backend_traffic.py` files. """ -import getpass import os import celery @@ -48,9 +47,8 @@ def test_get_key_path(self, redis_config: "fixture"): # noqa: F821 """ # Test the default behavior (`_get_key_path` will pull from CONFIG.results_backend which # will be set to the temporary output path for our tests in the `use_fake_encrypt_data_key` fixture) - user = getpass.getuser() actual_default = _get_key_path() - assert actual_default.startswith(f"/tmp/{user}/") and actual_default.endswith("/encrypt_data_key") + assert actual_default.startswith("/tmp/") and actual_default.endswith("/encrypt_data_key") # Test with having the encryption key set to None temp = CONFIG.results_backend.encryption_key From db1f20a073ce0cd580948f96974f37ef7db9d80f Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 14 Dec 2023 14:25:15 -0800 Subject: [PATCH 13/44] add unit tests for broker.py --- merlin/config/broker.py | 14 +- tests/unit/config/test_broker.py | 549 +++++++++++++++++++++++++++++++ 2 files changed, 553 insertions(+), 10 deletions(-) create mode 100644 tests/unit/config/test_broker.py diff --git a/merlin/config/broker.py b/merlin/config/broker.py index 385b8c1df..152c6a9b8 100644 --- a/merlin/config/broker.py +++ b/merlin/config/broker.py @@ -85,13 +85,13 @@ def get_rabbit_connection(include_password, conn="amqps"): password_filepath = CONFIG.broker.password LOG.debug(f"Broker: password filepath = {password_filepath}") password_filepath = os.path.abspath(expanduser(password_filepath)) - except KeyError as e: # pylint: disable=C0103 - raise ValueError("Broker: No password provided for RabbitMQ") from e + except (AttributeError, KeyError) as exc: + raise ValueError("Broker: No password provided for RabbitMQ") from exc try: password = read_file(password_filepath) - except IOError as e: # pylint: disable=C0103 - raise ValueError(f"Broker: RabbitMQ password file {password_filepath} does not exist") from e + except IOError as exc: + raise ValueError(f"Broker: RabbitMQ password file {password_filepath} does not exist") from exc try: port = CONFIG.broker.port @@ -205,12 +205,6 @@ def get_connection_string(include_password=True): except AttributeError: broker = "" - try: - config_path = CONFIG.celery.certs - config_path = os.path.abspath(os.path.expanduser(config_path)) - except AttributeError: - config_path = None - if broker not in BROKERS: raise ValueError(f"Error: {broker} is not a supported broker.") return _sort_valid_broker(broker, include_password) diff --git a/tests/unit/config/test_broker.py b/tests/unit/config/test_broker.py new file mode 100644 index 000000000..9d4760f3e --- /dev/null +++ b/tests/unit/config/test_broker.py @@ -0,0 +1,549 @@ +""" +Tests for the `broker.py` file. +""" +import os +from ssl import CERT_NONE +from typing import Any, Dict + +import pytest + +from merlin.config.broker import ( + RABBITMQ_CONNECTION, + REDISSOCK_CONNECTION, + get_connection_string, + get_rabbit_connection, + get_redis_connection, + get_redissock_connection, + get_ssl_config, + read_file, +) +from merlin.config.configfile import CONFIG +from tests.conftest import SERVER_PASS, create_pass_file + + +def test_read_file(merlin_server_dir: str): + """ + Test the `read_file` function. We'll start up our containerized redis server + so that we have a password file to read here. + + :param merlin_server_dir: The directory to the merlin test server configuration + """ + pass_file = f"{merlin_server_dir}/redis.pass" + create_pass_file(pass_file) + actual = read_file(pass_file) + assert actual == SERVER_PASS + + +def test_get_connection_string_invalid_broker(redis_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with an invalid broker (a broker that isn't one of: + ["rabbitmq", "redis", "rediss", "redis+socket", "amqps", "amqp"]). + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.broker.name = "invalid_broker" + with pytest.raises(ValueError): + get_connection_string() + + +def test_get_connection_string_no_broker(redis_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function without a broker name value in the CONFIG object. This + should raise a ValueError just like the `test_get_connection_string_invalid_broker` does. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.broker.name + with pytest.raises(ValueError): + get_connection_string() + + +def test_get_connection_string_simple(redis_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function in the simplest way that we can. This function + will automatically check for a broker url and if it finds one in the CONFIG object it will just + return the value it finds. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + test_url = "test_url" + CONFIG.broker.url = test_url + actual = get_connection_string() + assert actual == test_url + + +def test_get_ssl_config_no_broker(redis_config: "fixture"): # noqa: F821 + """ + Test the `get_ssl_config` function without a broker. This should return False. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.broker.name + assert not get_ssl_config() + + +class TestRabbitBroker: + """ + This class will house all tests necessary for our broker module when using a + rabbit broker. + """ + + def run_get_rabbit_connection(self, expected_vals: Dict[str, Any], include_password: bool, conn: str): + """ + Helper method to run the tests for the `get_rabbit_connection`. + + :param expected_vals: A dict of expected values for this test. Format: + {"conn": "", + "vhost": "host4testing", + "username": "default", + "password": "", + "server": "127.0.0.1", + "port": } + :param include_password: If True, include the password in the output. Otherwise don't. + :param conn: The connection type to pass in (either amqp or amqps) + """ + expected = RABBITMQ_CONNECTION.format(**expected_vals) + actual = get_rabbit_connection(include_password=include_password, conn=conn) + assert actual == expected + + def test_get_rabbit_connection(self, rabbit_config: "fixture"): # noqa: F821 + """ + Test the `get_rabbit_connection` function. + + :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + conn = "amqps" + expected_vals = { + "conn": conn, + "vhost": "host4testing", + "username": "default", + "password": SERVER_PASS, + "server": "127.0.0.1", + "port": 5671, + } + self.run_get_rabbit_connection(expected_vals=expected_vals, include_password=True, conn=conn) + + def test_get_rabbit_connection_dont_include_password(self, rabbit_config: "fixture"): # noqa: F821 + """ + Test the `get_rabbit_connection` function but set include_password to False. This should * out the + password + + :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + conn = "amqps" + expected_vals = { + "conn": conn, + "vhost": "host4testing", + "username": "default", + "password": "******", + "server": "127.0.0.1", + "port": 5671, + } + self.run_get_rabbit_connection(expected_vals=expected_vals, include_password=False, conn=conn) + + def test_get_rabbit_connection_no_port_amqp(self, rabbit_config: "fixture"): # noqa: F821 + """ + Test the `get_rabbit_connection` function with no port in the CONFIG object. This should use + 5672 as the port since we're using amqp as the connection. + + :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.broker.port + CONFIG.broker.name = "amqp" + conn = "amqp" + expected_vals = { + "conn": conn, + "vhost": "host4testing", + "username": "default", + "password": SERVER_PASS, + "server": "127.0.0.1", + "port": 5672, + } + self.run_get_rabbit_connection(expected_vals=expected_vals, include_password=True, conn=conn) + + def test_get_rabbit_connection_no_port_amqps(self, rabbit_config: "fixture"): # noqa: F821 + """ + Test the `get_rabbit_connection` function with no port in the CONFIG object. This should use + 5671 as the port since we're using amqps as the connection. + + :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.broker.port + conn = "amqps" + expected_vals = { + "conn": conn, + "vhost": "host4testing", + "username": "default", + "password": SERVER_PASS, + "server": "127.0.0.1", + "port": 5671, + } + self.run_get_rabbit_connection(expected_vals=expected_vals, include_password=True, conn=conn) + + def test_get_rabbit_connection_no_password(self, rabbit_config: "fixture"): # noqa: F821 + """ + Test the `get_rabbit_connection` function with no password file set. This should raise a ValueError. + + :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.broker.password + with pytest.raises(ValueError) as excinfo: + get_rabbit_connection(True) + assert "Broker: No password provided for RabbitMQ" in str(excinfo.value) + + def test_get_rabbit_connection_invalid_pass_filepath(self, rabbit_config: "fixture"): # noqa: F821 + """ + Test the `get_rabbit_connection` function with an invalid password filepath. + This should raise a ValueError. + + :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.broker.password = "invalid_filepath" + expanded_filepath = os.path.abspath(os.path.expanduser(CONFIG.broker.password)) + with pytest.raises(ValueError) as excinfo: + get_rabbit_connection(True) + assert f"Broker: RabbitMQ password file {expanded_filepath} does not exist" in str(excinfo.value) + + def run_get_connection_string(self, expected_vals: Dict[str, Any]): + """ + Helper method to run the tests for the `get_connection_string`. + + :param expected_vals: A dict of expected values for this test. Format: + {"conn": "", + "vhost": "host4testing", + "username": "default", + "password": "", + "server": "127.0.0.1", + "port": } + """ + expected = RABBITMQ_CONNECTION.format(**expected_vals) + actual = get_connection_string() + assert actual == expected + + def test_get_connection_string_rabbitmq(self, rabbit_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with rabbitmq as the broker. + + :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + expected_vals = { + "conn": "amqps", + "vhost": "host4testing", + "username": "default", + "password": SERVER_PASS, + "server": "127.0.0.1", + "port": 5671, + } + self.run_get_connection_string(expected_vals) + + def test_get_connection_string_amqp(self, rabbit_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with amqp as the broker. + + :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.broker.port + CONFIG.broker.name = "amqp" + expected_vals = { + "conn": "amqp", + "vhost": "host4testing", + "username": "default", + "password": SERVER_PASS, + "server": "127.0.0.1", + "port": 5672, + } + self.run_get_connection_string(expected_vals) + + +class TestRedisBroker: + """ + This class will house all tests necessary for our broker module when using a + redis broker. + """ + + def run_get_redissock_connection(self, expected_vals: Dict[str, str]): + """ + Helper method to run the tests for the `get_redissock_connection`. + + :param expected_vals: A dict of expected values for this test. Format: + {"db_num": "", "path": ""} + """ + expected = REDISSOCK_CONNECTION.format(**expected_vals) + actual = get_redissock_connection() + assert actual == expected + + def test_get_redissock_connection(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redissock_connection` function with both a db_num and a broker path set. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + # Create and store a fake path and db_num for testing + test_path = "/fake/path/to/broker" + test_db_num = "45" + CONFIG.broker.path = test_path + CONFIG.broker.db_num = test_db_num + + # Set up our expected vals and compare against the actual result + expected_vals = {"db_num": test_db_num, "path": test_path} + self.run_get_redissock_connection(expected_vals) + + def test_get_redissock_connection_no_db(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redissock_connection` function with a broker path set but no db num. + This should default the db_num to 0. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + # Create and store a fake path for testing + test_path = "/fake/path/to/broker" + CONFIG.broker.path = test_path + + # Set up our expected vals and compare against the actual result + expected_vals = {"db_num": 0, "path": test_path} + self.run_get_redissock_connection(expected_vals) + + def test_get_redissock_connection_no_path(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redissock_connection` function with a db num set but no broker path. + This should raise an AttributeError since there will be no path value to read from + in `CONFIG.broker`. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.broker.db_num = "45" + with pytest.raises(AttributeError): + get_redissock_connection() + + def test_get_redissock_connection_no_path_nor_db(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redissock_connection` function with neither a broker path nor a db num set. + This should raise an AttributeError since there will be no path value to read from + in `CONFIG.broker`. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + with pytest.raises(AttributeError): + get_redissock_connection() + + def run_get_redis_connection(self, expected_vals: Dict[str, Any], include_password: bool, use_ssl: bool): + """ + Helper method to run the tests for the `get_redis_connection`. + + :param expected_vals: A dict of expected values for this test. Format: + {"urlbase": "", "spass": "", "server": "127.0.0.1", "port": , "db_num": } + :param include_password: If True, include the password in the output. Otherwise don't. + :param use_ssl: If True, use ssl for the connection. Otherwise don't. + """ + expected = "{urlbase}://{spass}{server}:{port}/{db_num}".format(**expected_vals) + actual = get_redis_connection(include_password=include_password, use_ssl=use_ssl) + assert expected == actual + + def test_get_redis_connection(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redis_connection` function with default functionality (including password and not using ssl). + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + expected_vals = { + "urlbase": "redis", + "spass": "default:merlin-test-server@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) + + def test_get_redis_connection_no_port(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redis_connection` function with default functionality (including password and not using ssl). + We'll run this after deleting the port setting from the CONFIG object. This should still run and give us + port = 6379. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.broker.port + expected_vals = { + "urlbase": "redis", + "spass": "default:merlin-test-server@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) + + def test_get_redis_connection_with_db(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redis_connection` function with default functionality (including password and not using ssl). + We'll run this after adding the db_num setting to the CONFIG object. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + test_db_num = "45" + CONFIG.broker.db_num = test_db_num + expected_vals = { + "urlbase": "redis", + "spass": "default:merlin-test-server@", + "server": "127.0.0.1", + "port": 6379, + "db_num": test_db_num, + } + self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) + + def test_get_redis_connection_no_username(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redis_connection` function with default functionality (including password and not using ssl). + We'll run this after deleting the username setting from the CONFIG object. This should still run and give us + username = ''. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.broker.username + expected_vals = {"urlbase": "redis", "spass": ":merlin-test-server@", "server": "127.0.0.1", "port": 6379, "db_num": 0} + self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) + + def test_get_redis_connection_invalid_pass_file(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redis_connection` function with default functionality (including password and not using ssl). + We'll run this after changing the permissions of the password file so it can't be opened. This should still + run and give us password = CONFIG.broker.password. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + # Capture the initial permissions of the password file so we can reset them + orig_file_permissions = os.stat(CONFIG.broker.password).st_mode + + # Change the permissions of the password file so it can't be read + os.chmod(CONFIG.broker.password, 0o222) + + try: + # Run the test + expected_vals = { + "urlbase": "redis", + "spass": f"default:{CONFIG.broker.password}@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) + except AssertionError as exc: + # If this test failed, make sure to reset the permissions in case other tests need to read this file + os.chmod(CONFIG.broker.password, orig_file_permissions) + raise AssertionError from exc + + os.chmod(CONFIG.broker.password, orig_file_permissions) + + def test_get_redis_connection_dont_include_password(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redis_connection` function without including the password. This should place 6 *s + where the password would normally be placed in spass. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + expected_vals = {"urlbase": "redis", "spass": "default:******@", "server": "127.0.0.1", "port": 6379, "db_num": 0} + self.run_get_redis_connection(expected_vals=expected_vals, include_password=False, use_ssl=False) + + def test_get_redis_connection_use_ssl(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redis_connection` function with using ssl. This should change the urlbase to rediss (with two 's'). + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + expected_vals = { + "urlbase": "rediss", + "spass": "default:merlin-test-server@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=True) + + def test_get_redis_connection_no_password(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_redis_connection` function with default functionality (including password and not using ssl). + We'll run this after deleting the password setting from the CONFIG object. This should still run and give us + spass = ''. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.broker.password + expected_vals = {"urlbase": "redis", "spass": "", "server": "127.0.0.1", "port": 6379, "db_num": 0} + self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) + + def test_get_connection_string_redis(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with redis as the broker (this is what our CONFIG + is set to by default with the redis_config fixture). + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + expected_vals = { + "urlbase": "redis", + "spass": "default:merlin-test-server@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + expected = "{urlbase}://{spass}{server}:{port}/{db_num}".format(**expected_vals) + actual = get_connection_string() + assert expected == actual + + def test_get_connection_string_rediss(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with rediss (with two 's') as the broker. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.broker.name = "rediss" + expected_vals = { + "urlbase": "rediss", + "spass": "default:merlin-test-server@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + expected = "{urlbase}://{spass}{server}:{port}/{db_num}".format(**expected_vals) + actual = get_connection_string() + assert expected == actual + + def test_get_connection_string_redis_socket(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with redis+socket as the broker. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + # Change our broker + CONFIG.broker.name = "redis+socket" + + # Create and store a fake path and db_num for testing + test_path = "/fake/path/to/broker" + test_db_num = "45" + CONFIG.broker.path = test_path + CONFIG.broker.db_num = test_db_num + + # Set up our expected vals and compare against the actual result + expected_vals = {"db_num": test_db_num, "path": test_path} + expected = REDISSOCK_CONNECTION.format(**expected_vals) + actual = get_connection_string() + assert actual == expected + + def test_get_ssl_config_redis(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_ssl_config` function with redis as the broker (this is the default in our tests). + This should return False. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + assert not get_ssl_config() + + def test_get_ssl_config_rediss(self, redis_config: "fixture"): # noqa: F821 + """ + Test the `get_ssl_config` function with rediss (with two 's') as the broker. + This should return a dict of cert reqs with ssl.CERT_NONE as the value. + + :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.broker.name = "rediss" + expected = {"ssl_cert_reqs": CERT_NONE} + actual = get_ssl_config() + assert actual == expected From 896898e89c432221d708f2d7ca83986ac59a9b04 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 14 Dec 2023 15:31:57 -0800 Subject: [PATCH 14/44] add unit tests for the Config object --- merlin/config/__init__.py | 4 +- setup.cfg | 5 + tests/conftest.py | 4 +- tests/unit/config/test_config_object.py | 149 ++++++++++++++++++++++++ 4 files changed, 158 insertions(+), 4 deletions(-) create mode 100644 tests/unit/config/test_config_object.py diff --git a/merlin/config/__init__.py b/merlin/config/__init__.py index d0a0bf9c5..af1562ae4 100644 --- a/merlin/config/__init__.py +++ b/merlin/config/__init__.py @@ -80,9 +80,9 @@ def __str__(self): if attr is not None: items = (f" {k}: {v!r}" for k, v in attr.__dict__.items()) joined_items = "\n".join(items) - formatted_str += f"\n {name}: \n{joined_items}" + formatted_str += f"\n {name}:\n{joined_items}" else: - formatted_str += f"\n {name}: \n None" + formatted_str += f"\n {name}:\n None" return formatted_str def load_app_into_namespaces(self, app_dict: Dict) -> None: diff --git a/setup.cfg b/setup.cfg index a000df59a..0eaa116ea 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,3 +26,8 @@ max-line-length = 127 files=best_practices,test ignore_missing_imports=true + +[coverage:run] +omit = + merlin/ascii.py + merlin/config/celeryconfig.py diff --git a/tests/conftest.py b/tests/conftest.py index 3415385f9..79ad427bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -318,8 +318,8 @@ def redis_config(merlin_server_dir: str, config: "fixture"): # noqa: F821 pylin @pytest.fixture(scope="function") def rabbit_config( - merlin_server_dir: str, config: "fixture" -): # noqa: F821 pylint: disable=redefined-outer-name,unused-argument + merlin_server_dir: str, config: "fixture" # noqa: F821 pylint: disable=redefined-outer-name,unused-argument +): """ This fixture is intended to be used for testing any functionality in the codebase that uses the CONFIG object with a RabbitMQ broker and Redis results_backend. diff --git a/tests/unit/config/test_config_object.py b/tests/unit/config/test_config_object.py new file mode 100644 index 000000000..bd658bc66 --- /dev/null +++ b/tests/unit/config/test_config_object.py @@ -0,0 +1,149 @@ +""" +Test the functionality of the Config object. +""" +from copy import copy, deepcopy +from types import SimpleNamespace + +from merlin.config import Config + + +class TestConfig: + """ + Class for testing the Config object. We'll store a valid `app_dict` + as an attribute here so that each test doesn't have to redefine it + each time. + """ + + app_dict = { + "celery": {"override": {"visibility_timeout": 86400}}, + "broker": { + "cert_reqs": "none", + "name": "rabbitmq", + "password": "/path/to/pass_file", + "port": 5671, + "server": "127.0.0.1", + "username": "default", + "vhost": "host4testing", + }, + "results_backend": { + "cert_reqs": "none", + "db_num": 0, + "name": "rediss", + "password": "/path/to/pass_file", + "port": 6379, + "server": "127.0.0.1", + "username": "default", + "vhost": "host4testing", + "encryption_key": "/path/to/encryption_key", + }, + } + + def test_config_creation(self): + """ + Test the creation of the Config object. This should create nested namespaces + for each key in the `app_dict` variable and save them to their respective + attributes in the object. + """ + config = Config(self.app_dict) + + # Create the nested namespace for celery and compare result + override_namespace = SimpleNamespace(**self.app_dict["celery"]["override"]) + updated_celery_dict = deepcopy(self.app_dict) + updated_celery_dict["celery"]["override"] = override_namespace + celery_namespace = SimpleNamespace(**updated_celery_dict["celery"]) + assert config.celery == celery_namespace + + # Broker and Results Backend are easier since there's no nested namespace here + assert config.broker == SimpleNamespace(**self.app_dict["broker"]) + assert config.results_backend == SimpleNamespace(**self.app_dict["results_backend"]) + + def test_config_creation_no_celery(self): + """ + Test the creation of the Config object without the celery key. This should still + work and just not set anything for the celery attribute. + """ + + # Copy the celery section so we can restore it later and then delete it + celery_section = copy(self.app_dict["celery"]) + del self.app_dict["celery"] + config = Config(self.app_dict) + + # Broker and Results Backend are the only things loaded here + assert config.broker == SimpleNamespace(**self.app_dict["broker"]) + assert config.results_backend == SimpleNamespace(**self.app_dict["results_backend"]) + + # Ensure the celery attribute is not loaded + assert "celery" not in dir(config) + + # Reset celery section in case other tests use it after this + self.app_dict["celery"] = celery_section + + def test_config_copy(self): + """ + Test the `__copy__` magic method of the Config object. Here we'll make sure + each attribute was copied properly but the ids should be different. + """ + orig_config = Config(self.app_dict) + copied_config = copy(orig_config) + + assert orig_config.celery == copied_config.celery + assert orig_config.broker == copied_config.broker + assert orig_config.results_backend == copied_config.results_backend + + assert id(orig_config) != id(copied_config) + + def test_config_str(self): + """ + Test the `__str__` magic method of the Config object. This should just give us + a formatted string of the attributes in the object. + """ + config = Config(self.app_dict) + + # Test normal printing + actual = config.__str__() + expected = ( + "config:\n" + " celery:\n" + " override: namespace(visibility_timeout=86400)\n" + " broker:\n" + " cert_reqs: 'none'\n" + " name: 'rabbitmq'\n" + " password: '/path/to/pass_file'\n" + " port: 5671\n" + " server: '127.0.0.1'\n" + " username: 'default'\n" + " vhost: 'host4testing'\n" + " results_backend:\n" + " cert_reqs: 'none'\n" + " db_num: 0\n" + " name: 'rediss'\n" + " password: '/path/to/pass_file'\n" + " port: 6379\n" + " server: '127.0.0.1'\n" + " username: 'default'\n" + " vhost: 'host4testing'\n" + " encryption_key: '/path/to/encryption_key'" + ) + + assert actual == expected + + # Test printing with one section set to None + config.results_backend = None + actual_with_none = config.__str__() + expected_with_none = ( + "config:\n" + " celery:\n" + " override: namespace(visibility_timeout=86400)\n" + " broker:\n" + " cert_reqs: 'none'\n" + " name: 'rabbitmq'\n" + " password: '/path/to/pass_file'\n" + " port: 5671\n" + " server: '127.0.0.1'\n" + " username: 'default'\n" + " vhost: 'host4testing'\n" + " results_backend:\n" + " None" + ) + + assert actual_with_none == expected_with_none From 2a5148c9b233a95bfe45762ce27dd0072be61e9a Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 14 Dec 2023 15:32:08 -0800 Subject: [PATCH 15/44] update CHANGELOG --- CHANGELOG.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf0074e9d..3d0bea05d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added -- Pytest fixtures in the `conftest.py` file of the integration test suite +- Pytest fixtures in the `conftest.py` file of the test suite - NOTE: an export command `export LC_ALL='C'` had to be added to fix a bug in the WEAVE CI. This can be removed when we resolve this issue for the `merlin server` command -- Tests for the `celeryadapter.py` module -- New CeleryTestWorkersManager context to help with starting/stopping workers for tests +- Coverage to the test suite. This includes adding tests for: + - `merlin/common/` + - `merlin/config/` + - `celeryadapter.py` +- Context managers for the `conftest.py` file to ensure safe spin up and shutdown of fixtures + - RedisServerManager: context to help with starting/stopping a redis server for tests + - CeleryWorkersManager: context to help with starting/stopping workers for tests +- Ability to copy and print the `Config` object from `merlin/config/__init__.py` ### Fixed - The `merlin status` command so that it's consistent in its output whether using redis or rabbitmq as the broker From 3d7228e1b981962d873520bbbd65ea78b9b76dd0 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 18 Dec 2023 08:54:15 -0800 Subject: [PATCH 16/44] make CONFIG fixtures more flexible for tests --- tests/conftest.py | 79 ++++++++++++++---- tests/unit/common/test_encryption.py | 16 ++-- tests/unit/config/test_broker.py | 118 +++++++++++++-------------- 3 files changed, 132 insertions(+), 81 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 79ad427bb..b681ade35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -248,6 +248,21 @@ def test_encryption_key() -> bytes: return b"Q3vLp07Ljm60ahfU9HwOOnfgGY91lSrUmqcTiP0v9i0=" +####################################### +########### CONFIG Fixtures ########### +####################################### +# These are intended to be used # +# either by themselves or together # +# For example, you can use a rabbit # +# broker config and a redis results # +# backend config together # +####################################### +############ !!!WARNING!!! ############ +# DO NOT USE THE `config` FIXTURE # +# IN A TEST; IT HAS UNSET VALUES # +####################################### + + @pytest.fixture(scope="function") def config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disable=redefined-outer-name """ @@ -258,7 +273,6 @@ def config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disab :param merlin_server_dir: The directory to the merlin test server configuration :param test_encryption_key: An encryption key to be used for testing """ - # global CONFIG # Create a copy of the CONFIG option so we can reset it after the test orig_config = copy(CONFIG) @@ -268,23 +282,24 @@ def config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disab create_encryption_file(key_file, test_encryption_key) # Set the broker configuration for testing - CONFIG.broker.password = "password path not yet set" # This will be updated in `redis_config` or `rabbit_config` - CONFIG.broker.port = "port not yet set" # This will be updated in `redis_config` or `rabbit_config` - CONFIG.broker.name = "name not yet set" # This will be updated in `redis_config` or `rabbit_config` + CONFIG.broker.password = None # This will be updated in `redis_broker_config` or `rabbit_broker_config` + CONFIG.broker.port = None # This will be updated in `redis_broker_config` or `rabbit_broker_config` + CONFIG.broker.name = None # This will be updated in `redis_broker_config` or `rabbit_broker_config` CONFIG.broker.server = "127.0.0.1" CONFIG.broker.username = "default" CONFIG.broker.vhost = "host4testing" CONFIG.broker.cert_reqs = "none" # Set the results_backend configuration for testing - CONFIG.results_backend.password = f"{merlin_server_dir}/redis.pass" - CONFIG.results_backend.port = 6379 + CONFIG.results_backend.password = None # This will be updated in `redis_results_backend_config` or `mysql_results_backend_config` + CONFIG.results_backend.port = None # This will be updated in `redis_results_backend_config` + CONFIG.results_backend.name = None # This will be updated in `redis_results_backend_config` or `mysql_results_backend_config` + CONFIG.results_backend.dbname = None # This will be updated in `mysql_results_backend_config` CONFIG.results_backend.server = "127.0.0.1" CONFIG.results_backend.username = "default" CONFIG.results_backend.cert_reqs = "none" CONFIG.results_backend.encryption_key = key_file CONFIG.results_backend.db_num = 0 - CONFIG.results_backend.name = "redis" # Go run the tests yield @@ -296,7 +311,7 @@ def config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disab @pytest.fixture(scope="function") -def redis_config(merlin_server_dir: str, config: "fixture"): # noqa: F821 pylint: disable=redefined-outer-name,unused-argument +def redis_broker_config(merlin_server_dir: str, config: "fixture"): # noqa: F821 pylint: disable=redefined-outer-name,unused-argument """ This fixture is intended to be used for testing any functionality in the codebase that uses the CONFIG object with a Redis broker and results_backend. @@ -304,8 +319,6 @@ def redis_config(merlin_server_dir: str, config: "fixture"): # noqa: F821 pylin :param merlin_server_dir: The directory to the merlin test server configuration :param config: The fixture that sets up most of the CONFIG object for testing """ - # global CONFIG - pass_file = f"{merlin_server_dir}/redis.pass" create_pass_file(pass_file) @@ -317,18 +330,35 @@ def redis_config(merlin_server_dir: str, config: "fixture"): # noqa: F821 pylin @pytest.fixture(scope="function") -def rabbit_config( +def redis_results_backend_config(merlin_server_dir: str, config: "fixture"): # noqa: F821 pylint: disable=redefined-outer-name,unused-argument + """ + This fixture is intended to be used for testing any functionality in the codebase + that uses the CONFIG object with a Redis results_backend. + + :param merlin_server_dir: The directory to the merlin test server configuration + :param config: The fixture that sets up most of the CONFIG object for testing + """ + pass_file = f"{merlin_server_dir}/redis.pass" + create_pass_file(pass_file) + + CONFIG.results_backend.password = pass_file + CONFIG.results_backend.port = 6379 + CONFIG.results_backend.name = "redis" + + yield + + +@pytest.fixture(scope="function") +def rabbit_broker_config( merlin_server_dir: str, config: "fixture" # noqa: F821 pylint: disable=redefined-outer-name,unused-argument ): """ This fixture is intended to be used for testing any functionality in the codebase - that uses the CONFIG object with a RabbitMQ broker and Redis results_backend. + that uses the CONFIG object with a RabbitMQ broker. :param merlin_server_dir: The directory to the merlin test server configuration :param config: The fixture that sets up most of the CONFIG object for testing """ - # global CONFIG - pass_file = f"{merlin_server_dir}/rabbit.pass" create_pass_file(pass_file) @@ -337,3 +367,24 @@ def rabbit_config( CONFIG.broker.name = "rabbitmq" yield + + +@pytest.fixture(scope="function") +def mysql_results_backend_config( + merlin_server_dir: str, config: "fixture" # noqa: F821 pylint: disable=redefined-outer-name,unused-argument +): + """ + This fixture is intended to be used for testing any functionality in the codebase + that uses the CONFIG object with a MySQL results_backend. + + :param merlin_server_dir: The directory to the merlin test server configuration + :param config: The fixture that sets up most of the CONFIG object for testing + """ + pass_file = f"{merlin_server_dir}/mysql.pass" + create_pass_file(pass_file) + + CONFIG.results_backend.password = pass_file + CONFIG.results_backend.name = "mysql" + CONFIG.results_backend.dbname = "test_mysql_db" + + yield diff --git a/tests/unit/common/test_encryption.py b/tests/unit/common/test_encryption.py index 6392cf8da..d0069f09e 100644 --- a/tests/unit/common/test_encryption.py +++ b/tests/unit/common/test_encryption.py @@ -16,34 +16,34 @@ class TestEncryption: This class will house all tests necessary for our encryption modules. """ - def test_encrypt(self, redis_config: "fixture"): # noqa: F821 + def test_encrypt(self, redis_results_backend_config: "fixture"): # noqa: F821 """ Test that our encryption function is encrypting the bytes that we're passing to it. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ str_to_encrypt = b"super secret string shhh" encrypted_str = encrypt(str_to_encrypt) for word in str_to_encrypt.decode("utf-8").split(" "): assert word not in encrypted_str.decode("utf-8") - def test_decrypt(self, redis_config: "fixture"): # noqa: F821 + def test_decrypt(self, redis_results_backend_config: "fixture"): # noqa: F821 """ Test that our decryption function is decrypting the bytes that we're passing to it. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # This is the output of the bytes from the encrypt test str_to_decrypt = b"gAAAAABld6k-jEncgCW5AePgrwn-C30dhr7dzGVhqzcqskPqFyA2Hdg3VWmo0qQnLklccaUYzAGlB4PMxyp4T-1gAYlAOf_7sC_bJOEcYOIkhZFoH6cX4Uw=" decrypted_str = decrypt(str_to_decrypt) assert decrypted_str == b"super secret string shhh" - def test_get_key_path(self, redis_config: "fixture"): # noqa: F821 + def test_get_key_path(self, redis_results_backend_config: "fixture"): # noqa: F821 """ Test the `_get_key_path` function. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Test the default behavior (`_get_key_path` will pull from CONFIG.results_backend which # will be set to the temporary output path for our tests in the `use_fake_encrypt_data_key` fixture) @@ -87,13 +87,13 @@ def test_gen_key(self, temp_output_dir: str): key_gen_contents = key_gen_file.read() assert key_gen_contents != "" - def test_get_key(self, merlin_server_dir: str, test_encryption_key: bytes, redis_config: "fixture"): # noqa: F821 + def test_get_key(self, merlin_server_dir: str, test_encryption_key: bytes, redis_results_backend_config: "fixture"): # noqa: F821 """ Test the `_get_key` function. :param merlin_server_dir: The directory to the merlin test server configuration :param test_encryption_key: A fixture to establish a fixed encryption key for testing - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Test the default functionality actual_default = _get_key() diff --git a/tests/unit/config/test_broker.py b/tests/unit/config/test_broker.py index 9d4760f3e..490b47649 100644 --- a/tests/unit/config/test_broker.py +++ b/tests/unit/config/test_broker.py @@ -34,37 +34,37 @@ def test_read_file(merlin_server_dir: str): assert actual == SERVER_PASS -def test_get_connection_string_invalid_broker(redis_config: "fixture"): # noqa: F821 +def test_get_connection_string_invalid_broker(redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_connection_string` function with an invalid broker (a broker that isn't one of: ["rabbitmq", "redis", "rediss", "redis+socket", "amqps", "amqp"]). - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ CONFIG.broker.name = "invalid_broker" with pytest.raises(ValueError): get_connection_string() -def test_get_connection_string_no_broker(redis_config: "fixture"): # noqa: F821 +def test_get_connection_string_no_broker(redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_connection_string` function without a broker name value in the CONFIG object. This should raise a ValueError just like the `test_get_connection_string_invalid_broker` does. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ del CONFIG.broker.name with pytest.raises(ValueError): get_connection_string() -def test_get_connection_string_simple(redis_config: "fixture"): # noqa: F821 +def test_get_connection_string_simple(redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_connection_string` function in the simplest way that we can. This function will automatically check for a broker url and if it finds one in the CONFIG object it will just return the value it finds. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ test_url = "test_url" CONFIG.broker.url = test_url @@ -72,11 +72,11 @@ def test_get_connection_string_simple(redis_config: "fixture"): # noqa: F821 assert actual == test_url -def test_get_ssl_config_no_broker(redis_config: "fixture"): # noqa: F821 +def test_get_ssl_config_no_broker(redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_ssl_config` function without a broker. This should return False. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ del CONFIG.broker.name assert not get_ssl_config() @@ -106,11 +106,11 @@ def run_get_rabbit_connection(self, expected_vals: Dict[str, Any], include_passw actual = get_rabbit_connection(include_password=include_password, conn=conn) assert actual == expected - def test_get_rabbit_connection(self, rabbit_config: "fixture"): # noqa: F821 + def test_get_rabbit_connection(self, rabbit_broker_config: "fixture"): # noqa: F821 """ Test the `get_rabbit_connection` function. - :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ conn = "amqps" expected_vals = { @@ -123,12 +123,12 @@ def test_get_rabbit_connection(self, rabbit_config: "fixture"): # noqa: F821 } self.run_get_rabbit_connection(expected_vals=expected_vals, include_password=True, conn=conn) - def test_get_rabbit_connection_dont_include_password(self, rabbit_config: "fixture"): # noqa: F821 + def test_get_rabbit_connection_dont_include_password(self, rabbit_broker_config: "fixture"): # noqa: F821 """ Test the `get_rabbit_connection` function but set include_password to False. This should * out the password - :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ conn = "amqps" expected_vals = { @@ -141,12 +141,12 @@ def test_get_rabbit_connection_dont_include_password(self, rabbit_config: "fixtu } self.run_get_rabbit_connection(expected_vals=expected_vals, include_password=False, conn=conn) - def test_get_rabbit_connection_no_port_amqp(self, rabbit_config: "fixture"): # noqa: F821 + def test_get_rabbit_connection_no_port_amqp(self, rabbit_broker_config: "fixture"): # noqa: F821 """ Test the `get_rabbit_connection` function with no port in the CONFIG object. This should use 5672 as the port since we're using amqp as the connection. - :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ del CONFIG.broker.port CONFIG.broker.name = "amqp" @@ -161,12 +161,12 @@ def test_get_rabbit_connection_no_port_amqp(self, rabbit_config: "fixture"): # } self.run_get_rabbit_connection(expected_vals=expected_vals, include_password=True, conn=conn) - def test_get_rabbit_connection_no_port_amqps(self, rabbit_config: "fixture"): # noqa: F821 + def test_get_rabbit_connection_no_port_amqps(self, rabbit_broker_config: "fixture"): # noqa: F821 """ Test the `get_rabbit_connection` function with no port in the CONFIG object. This should use 5671 as the port since we're using amqps as the connection. - :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ del CONFIG.broker.port conn = "amqps" @@ -180,23 +180,23 @@ def test_get_rabbit_connection_no_port_amqps(self, rabbit_config: "fixture"): # } self.run_get_rabbit_connection(expected_vals=expected_vals, include_password=True, conn=conn) - def test_get_rabbit_connection_no_password(self, rabbit_config: "fixture"): # noqa: F821 + def test_get_rabbit_connection_no_password(self, rabbit_broker_config: "fixture"): # noqa: F821 """ Test the `get_rabbit_connection` function with no password file set. This should raise a ValueError. - :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ del CONFIG.broker.password with pytest.raises(ValueError) as excinfo: get_rabbit_connection(True) assert "Broker: No password provided for RabbitMQ" in str(excinfo.value) - def test_get_rabbit_connection_invalid_pass_filepath(self, rabbit_config: "fixture"): # noqa: F821 + def test_get_rabbit_connection_invalid_pass_filepath(self, rabbit_broker_config: "fixture"): # noqa: F821 """ Test the `get_rabbit_connection` function with an invalid password filepath. This should raise a ValueError. - :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ CONFIG.broker.password = "invalid_filepath" expanded_filepath = os.path.abspath(os.path.expanduser(CONFIG.broker.password)) @@ -220,11 +220,11 @@ def run_get_connection_string(self, expected_vals: Dict[str, Any]): actual = get_connection_string() assert actual == expected - def test_get_connection_string_rabbitmq(self, rabbit_config: "fixture"): # noqa: F821 + def test_get_connection_string_rabbitmq(self, rabbit_broker_config: "fixture"): # noqa: F821 """ Test the `get_connection_string` function with rabbitmq as the broker. - :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ expected_vals = { "conn": "amqps", @@ -236,11 +236,11 @@ def test_get_connection_string_rabbitmq(self, rabbit_config: "fixture"): # noqa } self.run_get_connection_string(expected_vals) - def test_get_connection_string_amqp(self, rabbit_config: "fixture"): # noqa: F821 + def test_get_connection_string_amqp(self, rabbit_broker_config: "fixture"): # noqa: F821 """ Test the `get_connection_string` function with amqp as the broker. - :param rabbit_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ del CONFIG.broker.port CONFIG.broker.name = "amqp" @@ -272,11 +272,11 @@ def run_get_redissock_connection(self, expected_vals: Dict[str, str]): actual = get_redissock_connection() assert actual == expected - def test_get_redissock_connection(self, redis_config: "fixture"): # noqa: F821 + def test_get_redissock_connection(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redissock_connection` function with both a db_num and a broker path set. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Create and store a fake path and db_num for testing test_path = "/fake/path/to/broker" @@ -288,12 +288,12 @@ def test_get_redissock_connection(self, redis_config: "fixture"): # noqa: F821 expected_vals = {"db_num": test_db_num, "path": test_path} self.run_get_redissock_connection(expected_vals) - def test_get_redissock_connection_no_db(self, redis_config: "fixture"): # noqa: F821 + def test_get_redissock_connection_no_db(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redissock_connection` function with a broker path set but no db num. This should default the db_num to 0. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Create and store a fake path for testing test_path = "/fake/path/to/broker" @@ -303,25 +303,25 @@ def test_get_redissock_connection_no_db(self, redis_config: "fixture"): # noqa: expected_vals = {"db_num": 0, "path": test_path} self.run_get_redissock_connection(expected_vals) - def test_get_redissock_connection_no_path(self, redis_config: "fixture"): # noqa: F821 + def test_get_redissock_connection_no_path(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redissock_connection` function with a db num set but no broker path. This should raise an AttributeError since there will be no path value to read from in `CONFIG.broker`. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ CONFIG.broker.db_num = "45" with pytest.raises(AttributeError): get_redissock_connection() - def test_get_redissock_connection_no_path_nor_db(self, redis_config: "fixture"): # noqa: F821 + def test_get_redissock_connection_no_path_nor_db(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redissock_connection` function with neither a broker path nor a db num set. This should raise an AttributeError since there will be no path value to read from in `CONFIG.broker`. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ with pytest.raises(AttributeError): get_redissock_connection() @@ -339,11 +339,11 @@ def run_get_redis_connection(self, expected_vals: Dict[str, Any], include_passwo actual = get_redis_connection(include_password=include_password, use_ssl=use_ssl) assert expected == actual - def test_get_redis_connection(self, redis_config: "fixture"): # noqa: F821 + def test_get_redis_connection(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redis_connection` function with default functionality (including password and not using ssl). - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ expected_vals = { "urlbase": "redis", @@ -354,13 +354,13 @@ def test_get_redis_connection(self, redis_config: "fixture"): # noqa: F821 } self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) - def test_get_redis_connection_no_port(self, redis_config: "fixture"): # noqa: F821 + def test_get_redis_connection_no_port(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redis_connection` function with default functionality (including password and not using ssl). We'll run this after deleting the port setting from the CONFIG object. This should still run and give us port = 6379. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ del CONFIG.broker.port expected_vals = { @@ -372,12 +372,12 @@ def test_get_redis_connection_no_port(self, redis_config: "fixture"): # noqa: F } self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) - def test_get_redis_connection_with_db(self, redis_config: "fixture"): # noqa: F821 + def test_get_redis_connection_with_db(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redis_connection` function with default functionality (including password and not using ssl). We'll run this after adding the db_num setting to the CONFIG object. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ test_db_num = "45" CONFIG.broker.db_num = test_db_num @@ -390,25 +390,25 @@ def test_get_redis_connection_with_db(self, redis_config: "fixture"): # noqa: F } self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) - def test_get_redis_connection_no_username(self, redis_config: "fixture"): # noqa: F821 + def test_get_redis_connection_no_username(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redis_connection` function with default functionality (including password and not using ssl). We'll run this after deleting the username setting from the CONFIG object. This should still run and give us username = ''. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ del CONFIG.broker.username expected_vals = {"urlbase": "redis", "spass": ":merlin-test-server@", "server": "127.0.0.1", "port": 6379, "db_num": 0} self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) - def test_get_redis_connection_invalid_pass_file(self, redis_config: "fixture"): # noqa: F821 + def test_get_redis_connection_invalid_pass_file(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redis_connection` function with default functionality (including password and not using ssl). We'll run this after changing the permissions of the password file so it can't be opened. This should still run and give us password = CONFIG.broker.password. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Capture the initial permissions of the password file so we can reset them orig_file_permissions = os.stat(CONFIG.broker.password).st_mode @@ -433,21 +433,21 @@ def test_get_redis_connection_invalid_pass_file(self, redis_config: "fixture"): os.chmod(CONFIG.broker.password, orig_file_permissions) - def test_get_redis_connection_dont_include_password(self, redis_config: "fixture"): # noqa: F821 + def test_get_redis_connection_dont_include_password(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redis_connection` function without including the password. This should place 6 *s where the password would normally be placed in spass. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ expected_vals = {"urlbase": "redis", "spass": "default:******@", "server": "127.0.0.1", "port": 6379, "db_num": 0} self.run_get_redis_connection(expected_vals=expected_vals, include_password=False, use_ssl=False) - def test_get_redis_connection_use_ssl(self, redis_config: "fixture"): # noqa: F821 + def test_get_redis_connection_use_ssl(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redis_connection` function with using ssl. This should change the urlbase to rediss (with two 's'). - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ expected_vals = { "urlbase": "rediss", @@ -458,24 +458,24 @@ def test_get_redis_connection_use_ssl(self, redis_config: "fixture"): # noqa: F } self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=True) - def test_get_redis_connection_no_password(self, redis_config: "fixture"): # noqa: F821 + def test_get_redis_connection_no_password(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_redis_connection` function with default functionality (including password and not using ssl). We'll run this after deleting the password setting from the CONFIG object. This should still run and give us spass = ''. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ del CONFIG.broker.password expected_vals = {"urlbase": "redis", "spass": "", "server": "127.0.0.1", "port": 6379, "db_num": 0} self.run_get_redis_connection(expected_vals=expected_vals, include_password=True, use_ssl=False) - def test_get_connection_string_redis(self, redis_config: "fixture"): # noqa: F821 + def test_get_connection_string_redis(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_connection_string` function with redis as the broker (this is what our CONFIG - is set to by default with the redis_config fixture). + is set to by default with the redis_broker_config fixture). - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ expected_vals = { "urlbase": "redis", @@ -488,11 +488,11 @@ def test_get_connection_string_redis(self, redis_config: "fixture"): # noqa: F8 actual = get_connection_string() assert expected == actual - def test_get_connection_string_rediss(self, redis_config: "fixture"): # noqa: F821 + def test_get_connection_string_rediss(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_connection_string` function with rediss (with two 's') as the broker. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ CONFIG.broker.name = "rediss" expected_vals = { @@ -506,11 +506,11 @@ def test_get_connection_string_rediss(self, redis_config: "fixture"): # noqa: F actual = get_connection_string() assert expected == actual - def test_get_connection_string_redis_socket(self, redis_config: "fixture"): # noqa: F821 + def test_get_connection_string_redis_socket(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_connection_string` function with redis+socket as the broker. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ # Change our broker CONFIG.broker.name = "redis+socket" @@ -527,21 +527,21 @@ def test_get_connection_string_redis_socket(self, redis_config: "fixture"): # n actual = get_connection_string() assert actual == expected - def test_get_ssl_config_redis(self, redis_config: "fixture"): # noqa: F821 + def test_get_ssl_config_redis(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_ssl_config` function with redis as the broker (this is the default in our tests). This should return False. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ assert not get_ssl_config() - def test_get_ssl_config_rediss(self, redis_config: "fixture"): # noqa: F821 + def test_get_ssl_config_rediss(self, redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_ssl_config` function with rediss (with two 's') as the broker. This should return a dict of cert reqs with ssl.CERT_NONE as the value. - :param redis_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ CONFIG.broker.name = "rediss" expected = {"ssl_cert_reqs": CERT_NONE} From 525e4032baf54f1cca811d0ffb5c96af688371d3 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 18 Dec 2023 12:55:05 -0800 Subject: [PATCH 17/44] add tests for results_backend.py --- merlin/config/results_backend.py | 6 + tests/conftest.py | 21 + tests/unit/config/test_results_backend.py | 570 ++++++++++++++++++++++ 3 files changed, 597 insertions(+) create mode 100644 tests/unit/config/test_results_backend.py diff --git a/merlin/config/results_backend.py b/merlin/config/results_backend.py index b88655399..6775cb854 100644 --- a/merlin/config/results_backend.py +++ b/merlin/config/results_backend.py @@ -236,6 +236,12 @@ def get_mysql(certs_path=None, mysql_certs=None, include_password=True): mysql_config["password"] = "******" mysql_config["server"] = server + # Ensure the ssl_key, ssl_ca, and ssl_cert keys are all set + if mysql_certs == MYSQL_CONFIG_FILENAMES: + for key, cert_file in mysql_certs.items(): + if key not in mysql_config: + mysql_config[key] = os.path.join(certs_path, cert_file) + return MYSQL_CONNECTION_STRING.format(**mysql_config) diff --git a/tests/conftest.py b/tests/conftest.py index b681ade35..88eeaddb0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,6 +47,11 @@ SERVER_PASS = "merlin-test-server" +CERT_FILES = { + "ssl_cert": "test-rabbit-client-cert.pem", + "ssl_ca": "test-mysql-ca-cert.pem", + "ssl_key": "test-rabbit-client-key.pem", +} ####################################### @@ -92,6 +97,20 @@ def create_encryption_file(key_filepath: str, encryption_key: bytes, app_yaml_fi yaml.dump(app_yaml, app_yaml_file) +def create_cert_files(cert_filepath: str, cert_files: Dict[str, str]): + """ + Check if cert files already exist and if they don't then create them. + + :param cert_filepath: The path to the cert files + :param cert_files: A dict of certification files to create + """ + for cert_file in cert_files.values(): + full_cert_filepath = f"{cert_filepath}/{cert_file}" + if not os.path.exists(full_cert_filepath): + with open(full_cert_filepath, "w"): + pass + + def set_config(broker: Dict[str, str], results_backend: Dict[str, str]): """ Given configuration options for the broker and results_backend, update @@ -383,6 +402,8 @@ def mysql_results_backend_config( pass_file = f"{merlin_server_dir}/mysql.pass" create_pass_file(pass_file) + create_cert_files(merlin_server_dir, CERT_FILES) + CONFIG.results_backend.password = pass_file CONFIG.results_backend.name = "mysql" CONFIG.results_backend.dbname = "test_mysql_db" diff --git a/tests/unit/config/test_results_backend.py b/tests/unit/config/test_results_backend.py new file mode 100644 index 000000000..3531a83a2 --- /dev/null +++ b/tests/unit/config/test_results_backend.py @@ -0,0 +1,570 @@ +""" +Tests for the `results_backend.py` file. +""" +import os +import pytest +from ssl import CERT_NONE +from typing import Any, Dict + +from merlin.config.configfile import CONFIG +from merlin.config.results_backend import ( + MYSQL_CONFIG_FILENAMES, + MYSQL_CONNECTION_STRING, + SQLITE_CONNECTION_STRING, + get_backend_password, + get_connection_string, + get_mysql, + get_mysql_config, + get_redis, + get_ssl_config +) +from tests.conftest import CERT_FILES, SERVER_PASS, create_cert_files, create_pass_file + +RESULTS_BACKEND_DIR = "{temp_output_dir}/test_results_backend" + + +def test_get_backend_password_pass_file_in_merlin(): + """ + Test the `get_backend_password` function with the password file in the ~/.merlin/ + directory. We'll create a dummy file in this directory and delete it once the test + is done. + """ + + # Check if the .merlin directory exists and create it if it doesn't + remove_merlin_dir_after_test = False + path_to_merlin_dir = os.path.expanduser("~/.merlin") + if not os.path.exists(path_to_merlin_dir): + remove_merlin_dir_after_test = True + os.mkdir(path_to_merlin_dir) + + # Create the test password file + pass_filename = "test.pass" + full_pass_filepath = f"{path_to_merlin_dir}/{pass_filename}" + create_pass_file(full_pass_filepath) + + try: + # Run the test + assert get_backend_password(pass_filename) == SERVER_PASS + # Cleanup + os.remove(full_pass_filepath) + if remove_merlin_dir_after_test: + os.rmdir(path_to_merlin_dir) + except AssertionError as exc: + # If the test fails, make sure we clean up the files/dirs created + os.remove(full_pass_filepath) + if remove_merlin_dir_after_test: + os.rmdir(path_to_merlin_dir) + raise AssertionError from exc + + +def test_get_backend_password_pass_file_not_in_merlin(temp_output_dir: str): + """ + Test the `get_backend_password` function with the password file not in the ~/.merlin/ + directory. By using the `temp_output_dir` fixture, our cwd will be the temporary directory. + We'll create a password file in the this directory for this test and have `get_backend_password` + read from that. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + pass_file = "test.pass" + create_pass_file(pass_file) + + assert get_backend_password(pass_file) == SERVER_PASS + + +def test_get_backend_password_directly_pass_password(): + """ + Test the `get_backend_password` function by passing the password directly to this + function instead of a password file. + """ + assert get_backend_password(SERVER_PASS) == SERVER_PASS + + +def test_get_backend_password_using_certs_path(temp_output_dir: str): + """ + Test the `get_backend_password` function with certs_path set to our temporary testing path. + We'll create a password file in the temporary directory for this test and have `get_backend_password` + read from that. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + pass_filename = "test_certs.pass" + test_dir = RESULTS_BACKEND_DIR.format(temp_output_dir=temp_output_dir) + if not os.path.exists(test_dir): + os.mkdir(test_dir) + full_pass_filepath = f"{test_dir}/{pass_filename}" + create_pass_file(full_pass_filepath) + + assert get_backend_password(pass_filename, certs_path=test_dir) == SERVER_PASS + + +def test_get_ssl_config_no_results_backend(config: "fixture"): # noqa: F821 + """ + Test the `get_ssl_config` function with no results_backend set. This should return False. + NOTE: we're using the config fixture here to make sure values are reset after this test finishes. + We won't actually use anything from the config fixture. + + :param config: A fixture to set up the CONFIG object for us + """ + del CONFIG.results_backend.name + assert get_ssl_config() is False + + +def test_get_connection_string_no_results_backend(config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with no results_backend set. + This should raise a ValueError. + NOTE: we're using the config fixture here to make sure values are reset after this test finishes. + We won't actually use anything from the config fixture. + + :param config: A fixture to set up the CONFIG object for us + """ + del CONFIG.results_backend.name + with pytest.raises(ValueError) as excinfo: + get_connection_string() + + assert "'' is not a supported results backend" in str(excinfo.value) + + +class TestRedisResultsBackend: + """ + This class will house all tests necessary for our results_backend module when using a + redis results_backend. + """ + + def run_get_redis( + self, + expected_vals: Dict[str, Any], + certs_path: str = None, + include_password: bool = True, + ssl: bool = False, + ): + """ + Helper method for running tests for the `get_redis` function. + + :param expected_vals: A dict of expected values for this test. Format: + {"urlbase": "redis", + "spass": "", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0} + :param certs_path: A string denoting the path to the certification files + :param include_password: If True, include the password in the output. Otherwise don't. + :param ssl: If True, use ssl. Otherwise, don't. + """ + expected = "{urlbase}://{spass}{server}:{port}/{db_num}".format(**expected_vals) + actual = get_redis(certs_path=certs_path, include_password=include_password, ssl=ssl) + assert actual == expected + + def test_get_redis(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_redis` function with default functionality. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + expected_vals = { + "urlbase": "redis", + "spass": f"default:{SERVER_PASS}@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis(expected_vals=expected_vals, certs_path=None, include_password=True, ssl=False) + + def test_get_redis_dont_include_password(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_redis` function with the password hidden. This should * out the password. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + expected_vals = { + "urlbase": "redis", + "spass": f"default:******@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis(expected_vals=expected_vals, certs_path=None, include_password=False, ssl=False) + + def test_get_redis_using_ssl(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_redis` function with ssl enabled. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + expected_vals = { + "urlbase": "rediss", + "spass": f"default:{SERVER_PASS}@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis(expected_vals=expected_vals, certs_path=None, include_password=True, ssl=True) + + def test_get_redis_no_port(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_redis` function with no port in our CONFIG object. This should default to port=6379. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.results_backend.port + expected_vals = { + "urlbase": "redis", + "spass": f"default:{SERVER_PASS}@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis(expected_vals=expected_vals, certs_path=None, include_password=True, ssl=False) + + def test_get_redis_no_db_num(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_redis` function with no db_num in our CONFIG object. This should default to db_num=0. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.results_backend.db_num + expected_vals = { + "urlbase": "redis", + "spass": f"default:{SERVER_PASS}@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis(expected_vals=expected_vals, certs_path=None, include_password=True, ssl=False) + + def test_get_redis_no_username(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_redis` function with no username in our CONFIG object. This should default to username=''. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.results_backend.username + expected_vals = { + "urlbase": "redis", + "spass": f":{SERVER_PASS}@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis(expected_vals=expected_vals, certs_path=None, include_password=True, ssl=False) + + def test_get_redis_no_password_file(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_redis` function with no password filepath in our CONFIG object. This should default to spass=''. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.results_backend.password + expected_vals = { + "urlbase": "redis", + "spass": "", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis(expected_vals=expected_vals, certs_path=None, include_password=True, ssl=False) + + def test_get_redis_invalid_pass_file(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_redis` function. We'll run this after changing the permissions of the password file so it + can't be opened. This should still run and give us password=CONFIG.results_backend.password. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + + # Capture the initial permissions of the password file so we can reset them + orig_file_permissions = os.stat(CONFIG.results_backend.password).st_mode + + # Change the permissions of the password file so it can't be read + os.chmod(CONFIG.results_backend.password, 0o222) + + try: + # Run the test + expected_vals = { + "urlbase": "redis", + "spass": f"default:{CONFIG.results_backend.password}@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + self.run_get_redis(expected_vals=expected_vals, certs_path=None, include_password=True, ssl=False) + os.chmod(CONFIG.results_backend.password, orig_file_permissions) + except AssertionError as exc: + # If this test failed, make sure to reset the permissions in case other tests need to read this file + os.chmod(CONFIG.results_backend.password, orig_file_permissions) + raise AssertionError from exc + + def test_get_ssl_config_redis(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_ssl_config` function with redis as the results_backend. This should return False since + ssl requires using rediss (with two 's'). + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + assert get_ssl_config() is False + + def test_get_ssl_config_rediss(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_ssl_config` function with rediss as the results_backend. + This should return a dict of cert reqs with ssl.CERT_NONE as the value. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.results_backend.name = "rediss" + assert get_ssl_config() == {"ssl_cert_reqs": CERT_NONE} + + def test_get_ssl_config_rediss_no_cert_reqs(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_ssl_config` function with rediss as the results_backend and no cert_reqs set. + This should return True. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + del CONFIG.results_backend.cert_reqs + CONFIG.results_backend.name = "rediss" + assert get_ssl_config() is True + + def test_get_connection_string_redis(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with redis as the results_backend. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + expected_vals = { + "urlbase": "redis", + "spass": f"default:{SERVER_PASS}@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + expected = "{urlbase}://{spass}{server}:{port}/{db_num}".format(**expected_vals) + actual = get_connection_string() + assert actual == expected + + def test_get_connection_string_rediss(self, redis_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with rediss as the results_backend. + + :param redis_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.results_backend.name = "rediss" + expected_vals = { + "urlbase": "rediss", + "spass": f"default:{SERVER_PASS}@", + "server": "127.0.0.1", + "port": 6379, + "db_num": 0, + } + expected = "{urlbase}://{spass}{server}:{port}/{db_num}".format(**expected_vals) + actual = get_connection_string() + assert actual == expected + + +class TestMySQLResultsBackend: + """ + This class will house all tests necessary for our results_backend module when using a + MySQL results_backend. + NOTE: You'll notice a lot of these tests are setting CONFIG.results_backend.name to be + "invalid". This is so that we can get by the first if statement in the `get_mysql_config` + function. + """ + + def test_get_mysql_config_certs_set(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + """ + Test the `get_mysql_config` function with the certs dict getting set and returned. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param merlin_server_dir: The directory that has the test certification files + """ + CONFIG.results_backend.name = "invalid" + expected = {} + for key, cert_file in CERT_FILES.items(): + expected[key] = f"{merlin_server_dir}/{cert_file}" + actual = get_mysql_config(merlin_server_dir, CERT_FILES) + assert actual == expected + + def test_get_mysql_config_ssl_exists(self, mysql_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_mysql_config` function with mysql_ssl being found. This should just return the ssl value that's found. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + assert get_mysql_config(None, None) == {"cert_reqs": CERT_NONE} + + def test_get_mysql_config_no_mysql_certs(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + """ + Test the `get_mysql_config` function with no mysql certs dict. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param merlin_server_dir: The directory that has the test certification files + """ + CONFIG.results_backend.name = "invalid" + assert get_mysql_config(merlin_server_dir, {}) == {} + + def test_get_mysql_config_invalid_certs_path(self, mysql_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_mysql_config` function with an invalid certs path. This should return False. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.results_backend.name = "invalid" + assert get_mysql_config("invalid/path", CERT_FILES) is False + + def run_get_mysql(self, expected_vals: Dict[str, Any], certs_path: str, mysql_certs: Dict[str, str], include_password: bool): + """ + Helper method for running tests for the `get_mysql` function. + + :param expected_vals: A dict of expected values for this test. Format: + {"cert_reqs": cert reqs dict, + "user": "default", + "password": "", + "server": "127.0.0.1", + "ssl_cert": "test-rabbit-client-cert.pem", + "ssl_ca": "test-mysql-ca-cert.pem", + "ssl_key": "test-rabbit-client-key.pem"} + :param certs_path: A string denoting the path to the certification files + :param mysql_certs: A dict of cert files + :param include_password: If True, include the password in the output. Otherwise don't. + """ + expected = MYSQL_CONNECTION_STRING.format(**expected_vals) + actual = get_mysql(certs_path=certs_path, mysql_certs=mysql_certs, include_password=include_password) + assert actual == expected + + def test_get_mysql(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + """ + Test the `get_mysql` function with default behavior. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param merlin_server_dir: The directory that has the test certification files + """ + CONFIG.results_backend.name = "invalid" + expected_vals = { + "cert_reqs": CERT_NONE, + "user": "default", + "password": SERVER_PASS, + "server": "127.0.0.1", + } + for key, cert_file in CERT_FILES.items(): + expected_vals[key] = f"{merlin_server_dir}/{cert_file}" + self.run_get_mysql(expected_vals=expected_vals, certs_path=merlin_server_dir, mysql_certs=CERT_FILES, include_password=True) + + def test_get_mysql_dont_include_password(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + """ + Test the `get_mysql` function but set include_password to False. This should * out the password. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param merlin_server_dir: The directory that has the test certification files + """ + CONFIG.results_backend.name = "invalid" + expected_vals = { + "cert_reqs": CERT_NONE, + "user": "default", + "password": "******", + "server": "127.0.0.1", + } + for key, cert_file in CERT_FILES.items(): + expected_vals[key] = f"{merlin_server_dir}/{cert_file}" + self.run_get_mysql(expected_vals=expected_vals, certs_path=merlin_server_dir, mysql_certs=CERT_FILES, include_password=False) + + def test_get_mysql_no_mysql_certs(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + """ + Test the `get_mysql` function with no mysql_certs passed in. This should use default config filenames so we'll + have to create these default files. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param merlin_server_dir: The directory that has the test certification files + """ + CONFIG.results_backend.name = "invalid" + expected_vals = { + "cert_reqs": CERT_NONE, + "user": "default", + "password": SERVER_PASS, + "server": "127.0.0.1", + } + + create_cert_files(merlin_server_dir, MYSQL_CONFIG_FILENAMES) + + for key, cert_file in MYSQL_CONFIG_FILENAMES.items(): + # Password file is already is already set in expected_vals dict + if key == "password": + continue + expected_vals[key] = f"{merlin_server_dir}/{cert_file}" + + self.run_get_mysql(expected_vals=expected_vals, certs_path=merlin_server_dir, mysql_certs=None, include_password=True) + + def test_get_mysql_no_server(self, mysql_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_mysql` function with no server set. This should raise a TypeError. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.results_backend.server = False + with pytest.raises(TypeError) as excinfo: + get_mysql() + assert f"Results backend: server False does not have a configuration" in str(excinfo.value) + + def test_get_mysql_invalid_certs_path(self, mysql_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_mysql` function with an invalid certs_path. This should raise a TypeError. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.results_backend.name = "invalid" + with pytest.raises(TypeError) as excinfo: + get_mysql(certs_path="invalid_path", mysql_certs=CERT_FILES) + err_msg = f"""The connection information for MySQL could not be set, cannot find:\n + {CERT_FILES}\ncheck the celery/certs path or set the ssl information in the app.yaml file.""" + assert err_msg in str(excinfo.value) + + def test_get_ssl_config_mysql(self, mysql_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_ssl_config` function with mysql as the results_backend. + This should return a dict of cert reqs with ssl.CERT_NONE as the value. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + assert get_ssl_config() == {"cert_reqs": CERT_NONE} + + def test_get_ssl_config_mysql_celery_check(self, mysql_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_ssl_config` function with mysql as the results_backend and celery_check set. + This should return False. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + assert get_ssl_config(celery_check=True) is False + + def test_get_connection_string_mysql(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + """ + Test the `get_connection_string` function with MySQL as the results_backend. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param merlin_server_dir: The directory that has the test certification files + """ + CONFIG.celery.certs = merlin_server_dir + + create_cert_files(merlin_server_dir, MYSQL_CONFIG_FILENAMES) + + expected_vals = { + "cert_reqs": CERT_NONE, + "user": "default", + "password": SERVER_PASS, + "server": "127.0.0.1", + } + for key, cert_file in MYSQL_CONFIG_FILENAMES.items(): + # Password file is already is already set in expected_vals dict + if key == "password": + continue + expected_vals[key] = f"{merlin_server_dir}/{cert_file}" + + assert MYSQL_CONNECTION_STRING.format(**expected_vals) == get_connection_string(include_password=True) + + def test_get_connection_string_sqlite(self, mysql_results_backend_config: "fixture"): # noqa: F821 + """ + Test the `get_connection_string` function with sqlite as the results_backend. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.results_backend.name = "sqlite" + assert get_connection_string() == SQLITE_CONNECTION_STRING From 47a0b4e5f834d120cfe54ca095264f2ac622240b Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 18 Dec 2023 13:09:17 -0800 Subject: [PATCH 18/44] fix lint issues for most recent changes --- tests/conftest.py | 18 ++++++++---- tests/unit/common/test_encryption.py | 4 ++- tests/unit/config/test_results_backend.py | 36 +++++++++++++++-------- 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 88eeaddb0..56ba762ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -101,7 +101,7 @@ def create_cert_files(cert_filepath: str, cert_files: Dict[str, str]): """ Check if cert files already exist and if they don't then create them. - :param cert_filepath: The path to the cert files + :param cert_filepath: The path to the cert files :param cert_files: A dict of certification files to create """ for cert_file in cert_files.values(): @@ -310,9 +310,13 @@ def config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disab CONFIG.broker.cert_reqs = "none" # Set the results_backend configuration for testing - CONFIG.results_backend.password = None # This will be updated in `redis_results_backend_config` or `mysql_results_backend_config` + CONFIG.results_backend.password = ( + None # This will be updated in `redis_results_backend_config` or `mysql_results_backend_config` + ) CONFIG.results_backend.port = None # This will be updated in `redis_results_backend_config` - CONFIG.results_backend.name = None # This will be updated in `redis_results_backend_config` or `mysql_results_backend_config` + CONFIG.results_backend.name = ( + None # This will be updated in `redis_results_backend_config` or `mysql_results_backend_config` + ) CONFIG.results_backend.dbname = None # This will be updated in `mysql_results_backend_config` CONFIG.results_backend.server = "127.0.0.1" CONFIG.results_backend.username = "default" @@ -330,7 +334,9 @@ def config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disab @pytest.fixture(scope="function") -def redis_broker_config(merlin_server_dir: str, config: "fixture"): # noqa: F821 pylint: disable=redefined-outer-name,unused-argument +def redis_broker_config( + merlin_server_dir: str, config: "fixture" # noqa: F821 pylint: disable=redefined-outer-name,unused-argument +): """ This fixture is intended to be used for testing any functionality in the codebase that uses the CONFIG object with a Redis broker and results_backend. @@ -349,7 +355,9 @@ def redis_broker_config(merlin_server_dir: str, config: "fixture"): # noqa: F82 @pytest.fixture(scope="function") -def redis_results_backend_config(merlin_server_dir: str, config: "fixture"): # noqa: F821 pylint: disable=redefined-outer-name,unused-argument +def redis_results_backend_config( + merlin_server_dir: str, config: "fixture" # noqa: F821 pylint: disable=redefined-outer-name,unused-argument +): """ This fixture is intended to be used for testing any functionality in the codebase that uses the CONFIG object with a Redis results_backend. diff --git a/tests/unit/common/test_encryption.py b/tests/unit/common/test_encryption.py index d0069f09e..d797f68c0 100644 --- a/tests/unit/common/test_encryption.py +++ b/tests/unit/common/test_encryption.py @@ -87,7 +87,9 @@ def test_gen_key(self, temp_output_dir: str): key_gen_contents = key_gen_file.read() assert key_gen_contents != "" - def test_get_key(self, merlin_server_dir: str, test_encryption_key: bytes, redis_results_backend_config: "fixture"): # noqa: F821 + def test_get_key( + self, merlin_server_dir: str, test_encryption_key: bytes, redis_results_backend_config: "fixture" # noqa: F821 + ): """ Test the `_get_key` function. diff --git a/tests/unit/config/test_results_backend.py b/tests/unit/config/test_results_backend.py index 3531a83a2..80ce05657 100644 --- a/tests/unit/config/test_results_backend.py +++ b/tests/unit/config/test_results_backend.py @@ -2,10 +2,11 @@ Tests for the `results_backend.py` file. """ import os -import pytest from ssl import CERT_NONE from typing import Any, Dict +import pytest + from merlin.config.configfile import CONFIG from merlin.config.results_backend import ( MYSQL_CONFIG_FILENAMES, @@ -16,10 +17,11 @@ get_mysql, get_mysql_config, get_redis, - get_ssl_config + get_ssl_config, ) from tests.conftest import CERT_FILES, SERVER_PASS, create_cert_files, create_pass_file + RESULTS_BACKEND_DIR = "{temp_output_dir}/test_results_backend" @@ -36,7 +38,7 @@ def test_get_backend_password_pass_file_in_merlin(): if not os.path.exists(path_to_merlin_dir): remove_merlin_dir_after_test = True os.mkdir(path_to_merlin_dir) - + # Create the test password file pass_filename = "test.pass" full_pass_filepath = f"{path_to_merlin_dir}/{pass_filename}" @@ -179,7 +181,7 @@ def test_get_redis_dont_include_password(self, redis_results_backend_config: "fi """ expected_vals = { "urlbase": "redis", - "spass": f"default:******@", + "spass": "default:******@", "server": "127.0.0.1", "port": 6379, "db_num": 0, @@ -372,7 +374,7 @@ class TestMySQLResultsBackend: def test_get_mysql_config_certs_set(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 """ - Test the `get_mysql_config` function with the certs dict getting set and returned. + Test the `get_mysql_config` function with the certs dict getting set and returned. :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here :param merlin_server_dir: The directory that has the test certification files @@ -392,9 +394,11 @@ def test_get_mysql_config_ssl_exists(self, mysql_results_backend_config: "fixtur """ assert get_mysql_config(None, None) == {"cert_reqs": CERT_NONE} - def test_get_mysql_config_no_mysql_certs(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + def test_get_mysql_config_no_mysql_certs( + self, mysql_results_backend_config: "fixture", merlin_server_dir: str # noqa: F821 + ): """ - Test the `get_mysql_config` function with no mysql certs dict. + Test the `get_mysql_config` function with no mysql certs dict. :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here :param merlin_server_dir: The directory that has the test certification files @@ -411,7 +415,9 @@ def test_get_mysql_config_invalid_certs_path(self, mysql_results_backend_config: CONFIG.results_backend.name = "invalid" assert get_mysql_config("invalid/path", CERT_FILES) is False - def run_get_mysql(self, expected_vals: Dict[str, Any], certs_path: str, mysql_certs: Dict[str, str], include_password: bool): + def run_get_mysql( + self, expected_vals: Dict[str, Any], certs_path: str, mysql_certs: Dict[str, str], include_password: bool + ): """ Helper method for running tests for the `get_mysql` function. @@ -447,9 +453,13 @@ def test_get_mysql(self, mysql_results_backend_config: "fixture", merlin_server_ } for key, cert_file in CERT_FILES.items(): expected_vals[key] = f"{merlin_server_dir}/{cert_file}" - self.run_get_mysql(expected_vals=expected_vals, certs_path=merlin_server_dir, mysql_certs=CERT_FILES, include_password=True) + self.run_get_mysql( + expected_vals=expected_vals, certs_path=merlin_server_dir, mysql_certs=CERT_FILES, include_password=True + ) - def test_get_mysql_dont_include_password(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + def test_get_mysql_dont_include_password( + self, mysql_results_backend_config: "fixture", merlin_server_dir: str # noqa: F821 + ): """ Test the `get_mysql` function but set include_password to False. This should * out the password. @@ -465,7 +475,9 @@ def test_get_mysql_dont_include_password(self, mysql_results_backend_config: "fi } for key, cert_file in CERT_FILES.items(): expected_vals[key] = f"{merlin_server_dir}/{cert_file}" - self.run_get_mysql(expected_vals=expected_vals, certs_path=merlin_server_dir, mysql_certs=CERT_FILES, include_password=False) + self.run_get_mysql( + expected_vals=expected_vals, certs_path=merlin_server_dir, mysql_certs=CERT_FILES, include_password=False + ) def test_get_mysql_no_mysql_certs(self, mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 """ @@ -502,7 +514,7 @@ def test_get_mysql_no_server(self, mysql_results_backend_config: "fixture"): # CONFIG.results_backend.server = False with pytest.raises(TypeError) as excinfo: get_mysql() - assert f"Results backend: server False does not have a configuration" in str(excinfo.value) + assert "Results backend: server False does not have a configuration" in str(excinfo.value) def test_get_mysql_invalid_certs_path(self, mysql_results_backend_config: "fixture"): # noqa: F821 """ From 91a3f2f2cd3babaa42003c396293b03322a25d05 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 18 Dec 2023 16:31:19 -0800 Subject: [PATCH 19/44] fix filename issue in setup.cfg and move celeryadapter tests to integration suite --- setup.cfg | 2 +- tests/{unit/study => integration}/test_celeryadapter.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/{unit/study => integration}/test_celeryadapter.py (100%) diff --git a/setup.cfg b/setup.cfg index 0eaa116ea..6b4278799 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,5 +29,5 @@ ignore_missing_imports=true [coverage:run] omit = - merlin/ascii.py + merlin/ascii_art.py merlin/config/celeryconfig.py diff --git a/tests/unit/study/test_celeryadapter.py b/tests/integration/test_celeryadapter.py similarity index 100% rename from tests/unit/study/test_celeryadapter.py rename to tests/integration/test_celeryadapter.py From 78b019b67a8ad426176168dab1a7479e71bb4090 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 18 Dec 2023 16:32:53 -0800 Subject: [PATCH 20/44] add ssl filepaths to mysql config object --- tests/conftest.py | 15 +++++++++++++++ tests/unit/config/test_results_backend.py | 17 +++++++++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 56ba762ff..e9d6027fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,7 @@ This module contains pytest fixtures to be used throughout the entire test suite. """ import os +import shutil from copy import copy from time import sleep from typing import Dict @@ -111,6 +112,17 @@ def create_cert_files(cert_filepath: str, cert_files: Dict[str, str]): pass +def create_app_yaml(app_yaml_filepath: str): + """ + Create a dummy app.yaml file at `app_yaml_filepath`. + + :param app_yaml_filepath: The location to create an app.yaml file at + """ + full_app_yaml_filepath = f"{app_yaml_filepath}/app.yaml" + if not os.path.exists(full_app_yaml_filepath): + shutil.copy(f"{os.path.dirname(__file__)}/dummy_app.yaml", full_app_yaml_filepath) + + def set_config(broker: Dict[str, str], results_backend: Dict[str, str]): """ Given configuration options for the broker and results_backend, update @@ -415,5 +427,8 @@ def mysql_results_backend_config( CONFIG.results_backend.password = pass_file CONFIG.results_backend.name = "mysql" CONFIG.results_backend.dbname = "test_mysql_db" + CONFIG.results_backend.keyfile = CERT_FILES["ssl_key"] + CONFIG.results_backend.certfile = CERT_FILES["ssl_cert"] + CONFIG.results_backend.ca_certs = CERT_FILES["ssl_ca"] yield diff --git a/tests/unit/config/test_results_backend.py b/tests/unit/config/test_results_backend.py index 80ce05657..59e53a5ae 100644 --- a/tests/unit/config/test_results_backend.py +++ b/tests/unit/config/test_results_backend.py @@ -386,13 +386,16 @@ def test_get_mysql_config_certs_set(self, mysql_results_backend_config: "fixture actual = get_mysql_config(merlin_server_dir, CERT_FILES) assert actual == expected - def test_get_mysql_config_ssl_exists(self, mysql_results_backend_config: "fixture"): # noqa: F821 + def test_get_mysql_config_ssl_exists(self, mysql_results_backend_config: "fixture", temp_output_dir: str): # noqa: F821 """ Test the `get_mysql_config` function with mysql_ssl being found. This should just return the ssl value that's found. :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run """ - assert get_mysql_config(None, None) == {"cert_reqs": CERT_NONE} + expected = {key: f"{temp_output_dir}/{cert_file}" for key, cert_file in CERT_FILES.items()} + expected["cert_reqs"] = CERT_NONE + assert get_mysql_config(None, None) == expected def test_get_mysql_config_no_mysql_certs( self, mysql_results_backend_config: "fixture", merlin_server_dir: str # noqa: F821 @@ -529,14 +532,17 @@ def test_get_mysql_invalid_certs_path(self, mysql_results_backend_config: "fixtu {CERT_FILES}\ncheck the celery/certs path or set the ssl information in the app.yaml file.""" assert err_msg in str(excinfo.value) - def test_get_ssl_config_mysql(self, mysql_results_backend_config: "fixture"): # noqa: F821 + def test_get_ssl_config_mysql(self, mysql_results_backend_config: "fixture", temp_output_dir: str): # noqa: F821 """ Test the `get_ssl_config` function with mysql as the results_backend. This should return a dict of cert reqs with ssl.CERT_NONE as the value. :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run """ - assert get_ssl_config() == {"cert_reqs": CERT_NONE} + expected = {key: f"{temp_output_dir}/{cert_file}" for key, cert_file in CERT_FILES.items()} + expected["cert_reqs"] = CERT_NONE + assert get_ssl_config() == expected def test_get_ssl_config_mysql_celery_check(self, mysql_results_backend_config: "fixture"): # noqa: F821 """ @@ -557,6 +563,9 @@ def test_get_connection_string_mysql(self, mysql_results_backend_config: "fixtur CONFIG.celery.certs = merlin_server_dir create_cert_files(merlin_server_dir, MYSQL_CONFIG_FILENAMES) + CONFIG.results_backend.keyfile = MYSQL_CONFIG_FILENAMES["ssl_key"] + CONFIG.results_backend.certfile = MYSQL_CONFIG_FILENAMES["ssl_cert"] + CONFIG.results_backend.ca_certs = MYSQL_CONFIG_FILENAMES["ssl_ca"] expected_vals = { "cert_reqs": CERT_NONE, From 275fbd474cc260ca755785b102969a7b5615515e Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 19 Dec 2023 09:24:17 -0800 Subject: [PATCH 21/44] add unit tests for configfile.py --- tests/conftest.py | 12 - tests/unit/config/dummy_app.yaml | 33 + tests/unit/config/old_test_configfile.py | 96 --- tests/unit/config/old_test_results_backend.py | 66 -- tests/unit/config/test_configfile.py | 705 ++++++++++++++++++ 5 files changed, 738 insertions(+), 174 deletions(-) create mode 100644 tests/unit/config/dummy_app.yaml delete mode 100644 tests/unit/config/old_test_configfile.py delete mode 100644 tests/unit/config/old_test_results_backend.py create mode 100644 tests/unit/config/test_configfile.py diff --git a/tests/conftest.py b/tests/conftest.py index e9d6027fd..446e3118b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,6 @@ This module contains pytest fixtures to be used throughout the entire test suite. """ import os -import shutil from copy import copy from time import sleep from typing import Dict @@ -112,17 +111,6 @@ def create_cert_files(cert_filepath: str, cert_files: Dict[str, str]): pass -def create_app_yaml(app_yaml_filepath: str): - """ - Create a dummy app.yaml file at `app_yaml_filepath`. - - :param app_yaml_filepath: The location to create an app.yaml file at - """ - full_app_yaml_filepath = f"{app_yaml_filepath}/app.yaml" - if not os.path.exists(full_app_yaml_filepath): - shutil.copy(f"{os.path.dirname(__file__)}/dummy_app.yaml", full_app_yaml_filepath) - - def set_config(broker: Dict[str, str], results_backend: Dict[str, str]): """ Given configuration options for the broker and results_backend, update diff --git a/tests/unit/config/dummy_app.yaml b/tests/unit/config/dummy_app.yaml new file mode 100644 index 000000000..966156566 --- /dev/null +++ b/tests/unit/config/dummy_app.yaml @@ -0,0 +1,33 @@ +broker: + cert_reqs: none + name: redis + password: redis.pass + port: '6379' + server: 127.0.0.1 + username: default + vhost: host4gunny +celery: + override: + visibility_timeout: 86400 +container: + config: redis.conf + config_dir: ./merlin_server/ + format: singularity + image: redis_latest.sif + image_type: redis + pass_file: redis.pass + pfile: merlin_server.pf + url: docker://redis + user_file: redis.users +process: + kill: kill {pid} + status: pgrep -P {pid} +results_backend: + cert_reqs: none + db_num: 0 + encryption_key: encrypt_data_key + name: redis + password: redis.pass + port: '6379' + server: 127.0.0.1 + username: default \ No newline at end of file diff --git a/tests/unit/config/old_test_configfile.py b/tests/unit/config/old_test_configfile.py deleted file mode 100644 index 1ee970531..000000000 --- a/tests/unit/config/old_test_configfile.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Tests for the configfile module.""" -import os -import shutil -import tempfile -import unittest -from getpass import getuser - -from merlin.config import configfile - -from .utils import mkfile - - -CONFIG_FILE_CONTENTS = """ -celery: - certs: path/to/celery/config/files - -broker: - name: rabbitmq - username: testuser - password: rabbit.password # The filename that contains the password. - server: jackalope.llnl.gov - -results_backend: - name: mysql - dbname: testuser - username: mlsi - password: mysql.password # The filename that contains the password. - server: rabbit.llnl.gov - -""" - - -class TestFindConfigFile(unittest.TestCase): - def setUp(self): - self.tmpdir = tempfile.mkdtemp() - self.appfile = mkfile(self.tmpdir, "app.yaml") - - def tearDown(self): - shutil.rmtree(self.tmpdir, ignore_errors=True) - - def test_tempdir(self): - self.assertTrue(os.path.isdir(self.tmpdir)) - - def test_find_config_file(self): - """ - Given the path to a vaild config file, find and return the full - filepath. - """ - path = configfile.find_config_file(path=self.tmpdir) - expected = os.path.join(self.tmpdir, self.appfile) - self.assertEqual(path, expected) - - def test_find_config_file_error(self): - """Given an invalid path, return None.""" - invalid = "invalid/path" - expected = None - - path = configfile.find_config_file(path=invalid) - self.assertEqual(path, expected) - - -class TestConfigFile(unittest.TestCase): - """Unit tests for loading the config file.""" - - def setUp(self): - self.tmpdir = tempfile.mkdtemp() - self.configfile = mkfile(self.tmpdir, "app.yaml", content=CONFIG_FILE_CONTENTS) - - def tearDown(self): - shutil.rmtree(self.tmpdir, ignore_errors=True) - - def test_get_config(self): - """ - Given the directory path to a valid merlin config file, then - `get_config` should find the merlin config file and load the YAML - contents to a dictionary. - """ - expected = { - "broker": { - "name": "rabbitmq", - "password": "rabbit.password", - "server": "jackalope.llnl.gov", - "username": "testuser", - "vhost": getuser(), - }, - "celery": {"certs": "path/to/celery/config/files"}, - "results_backend": { - "dbname": "testuser", - "name": "mysql", - "password": "mysql.password", - "server": "rabbit.llnl.gov", - "username": "mlsi", - }, - } - - self.assertDictEqual(configfile.get_config(self.tmpdir), expected) diff --git a/tests/unit/config/old_test_results_backend.py b/tests/unit/config/old_test_results_backend.py deleted file mode 100644 index d6e0c2f22..000000000 --- a/tests/unit/config/old_test_results_backend.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Tests for the results_backend module.""" -import os -import shutil -import tempfile -import unittest - -from merlin.config import results_backend - -from .utils import mkfile - - -class TestResultsBackend(unittest.TestCase): - def setUp(self): - self.tmpdir = tempfile.mkdtemp() - - # Create test files. - self.tmpfile1 = mkfile(self.tmpdir, "mysql_test1.txt") - self.tmpfile2 = mkfile(self.tmpdir, "mysql_test2.txt") - - def tearDown(self): - shutil.rmtree(self.tmpdir, ignore_errors=True) - - def test_mysql_config(self): - """ - Given the path to a directory containing the MySQL cert files and a - dictionary of files to look for, then find and return the full path to - all the certs. - """ - certs = {"test1": "mysql_test1.txt", "test2": "mysql_test2.txt"} - - # This will just be the above dictionary with the full file paths. - expected = { - "test1": os.path.join(self.tmpdir, certs["test1"]), - "test2": os.path.join(self.tmpdir, certs["test2"]), - } - results = results_backend.get_mysql_config(self.tmpdir, certs) - self.assertDictEqual(results, expected) - - def test_mysql_config_no_files(self): - """ - Given the path to a directory containing the MySQL cert files and - an empty dictionary, then `get_mysql_config` should return an empty - dictionary. - """ - files = {} - result = results_backend.get_mysql_config(self.tmpdir, files) - self.assertEqual(result, {}) - - -class TestConfingMysqlErrorPath(unittest.TestCase): - """ - Test `get_mysql_config` against cases were the given path does not exist. - """ - - def test_mysql_config_false(self): - """ - Given a path that does not exist, then `get_mysql_config` should return - False. - """ - path = "invalid/path" - - # We don't need the dictionary populated for this test. The function - # should return False before trying to process the dictionary. - certs = {} - result = results_backend.get_mysql_config(path, certs) - self.assertFalse(result) diff --git a/tests/unit/config/test_configfile.py b/tests/unit/config/test_configfile.py new file mode 100644 index 000000000..5d635e79b --- /dev/null +++ b/tests/unit/config/test_configfile.py @@ -0,0 +1,705 @@ +""" +Tests for the configfile.py module. +""" +import getpass +import os +import shutil +import ssl +from copy import copy, deepcopy + +import pytest +import yaml + +from merlin.config.configfile import ( + CONFIG, + default_config_info, + find_config_file, + get_cert_file, + get_config, + get_ssl_entries, + is_debug, + load_config, + load_default_celery, + load_default_user_names, + load_defaults, + merge_sslmap, + process_ssl_map, +) +from tests.conftest import CERT_FILES + + +CONFIGFILE_DIR = "{temp_output_dir}/test_configfile" +COPIED_APP_FILENAME = "app_copy.yaml" +DUMMY_APP_FILEPATH = f"{os.path.dirname(__file__)}/dummy_app.yaml" + + +def create_configfile_dir(temp_output_dir: str): + """ + Create the configfile dir if it doesn't exist yet. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + full_configfile_dirpath = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + if not os.path.exists(full_configfile_dirpath): + os.mkdir(full_configfile_dirpath) + + +def create_app_yaml(app_yaml_filepath: str): + """ + Create a dummy app.yaml file at `app_yaml_filepath`. + + :param app_yaml_filepath: The location to create an app.yaml file at + """ + full_app_yaml_filepath = f"{app_yaml_filepath}/app.yaml" + if not os.path.exists(full_app_yaml_filepath): + shutil.copy(DUMMY_APP_FILEPATH, full_app_yaml_filepath) + + +def test_load_config(temp_output_dir: str): + """ + Test the `load_config` function. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + create_configfile_dir(temp_output_dir) + configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_app_yaml(configfile_dir) + + with open(DUMMY_APP_FILEPATH, "r") as dummy_app_file: + expected = yaml.load(dummy_app_file, yaml.Loader) + + actual = load_config(f"{configfile_dir}/app.yaml") + assert actual == expected + + +def test_load_config_invalid_file(): + """ + Test the `load_config` function with an invalid filepath. + """ + assert load_config("invalid/filepath") is None + + +def test_find_config_file_valid_path(temp_output_dir: str): + """ + Test the `find_config_file` function with passing a valid path in. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + create_configfile_dir(temp_output_dir) + configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_app_yaml(configfile_dir) + + assert find_config_file(configfile_dir) == f"{configfile_dir}/app.yaml" + + +def test_find_config_file_invalid_path(): + """ + Test the `find_config_file` function with passing an invalid path in. + """ + assert find_config_file("invalid/path") is None + + +def test_find_config_file_local_path(temp_output_dir: str): + """ + Test the `find_config_file` function by having it find a local (in our cwd) app.yaml file. + We'll use the `temp_output_dir` fixture so that our current working directory is in a temp + location. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # Create the configfile directory and put an app.yaml file there + create_configfile_dir(temp_output_dir) + configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_app_yaml(configfile_dir) + + # Move into the configfile directory and run the test + os.chdir(configfile_dir) + try: + assert find_config_file() == f"{os.getcwd()}/app.yaml" + except AssertionError as exc: + # Move back to the temp output directory even if the test fails + os.chdir(temp_output_dir) + raise AssertionError from exc + + # Move back to the temp output directory + os.chdir(temp_output_dir) + + +def test_find_config_file_merlin_home_path(temp_output_dir: str): + """ + Test the `find_config_file` function by having it find an app.yaml file in our merlin directory. + We'll use the `temp_output_dir` fixture so that our current working directory is in a temp + location. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + merlin_home = os.path.expanduser("~/.merlin") + if not os.path.exists(merlin_home): + os.mkdir(merlin_home) + create_app_yaml(merlin_home) + assert find_config_file() == f"{merlin_home}/app.yaml" + + +def check_for_and_move_app_yaml(dir_to_check: str) -> bool: + """ + Check for any app.yaml files in `dir_to_check`. If one is found, rename it. + Return True if an app.yaml was found, false otherwise. + + :param dir_to_check: The directory to search for an app.yaml in + :returns: True if an app.yaml was found. False otherwise. + """ + for filename in os.listdir(dir_to_check): + full_path = os.path.join(dir_to_check, filename) + if os.path.isfile(full_path) and filename == "app.yaml": + os.rename(full_path, f"{dir_to_check}/{COPIED_APP_FILENAME}") + return True + return False + + +def test_find_config_file_no_path(temp_output_dir: str): + """ + Test the `find_config_file` function by making it unable to find any app.yaml path. + We'll use the `temp_output_dir` fixture so that our current working directory is in a temp + location. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # Rename any app.yaml in the cwd + cwd_path = os.getcwd() + cwd_had_app_yaml = check_for_and_move_app_yaml(cwd_path) + + # Rename any app.yaml in the merlin home directory + merlin_home_dir = os.path.expanduser("~/.merlin") + merlin_home_had_app_yaml = check_for_and_move_app_yaml(merlin_home_dir) + + try: + assert find_config_file() is None + except AssertionError as exc: + # Reset the cwd app.yaml even if the test fails + if cwd_had_app_yaml: + os.rename(f"{cwd_path}/{COPIED_APP_FILENAME}", f"{cwd_path}/app.yaml") + + # Reset the merlin home app.yaml even if the test fails + if merlin_home_had_app_yaml: + os.rename(f"{merlin_home_dir}/{COPIED_APP_FILENAME}", f"{merlin_home_dir}/app.yaml") + + raise AssertionError from exc + + # Reset the cwd app.yaml + if cwd_had_app_yaml: + os.rename(f"{cwd_path}/{COPIED_APP_FILENAME}", f"{cwd_path}/app.yaml") + + # Reset the merlin home app.yaml + if merlin_home_had_app_yaml: + os.rename(f"{merlin_home_dir}/{COPIED_APP_FILENAME}", f"{merlin_home_dir}/app.yaml") + + +def test_load_default_user_names_nothing_to_load(): + """ + Test the `load_default_user_names` function with nothing to load. In other words, in this + test the config dict will have a username and vhost already set for the broker. We'll + create the dict then make a copy of it to test against after calling the function. + """ + actual_config = {"broker": {"username": "default", "vhost": "host4testing"}} + expected_config = deepcopy(actual_config) + assert actual_config is not expected_config + + load_default_user_names(actual_config) + + # Ensure that nothing was modified after our call to load_default_user_names + assert actual_config == expected_config + + +def test_load_default_user_names_no_username(): + """ + Test the `load_default_user_names` function with no username. In other words, in this + test the config dict will have vhost already set for the broker but not a username. + """ + expected_config = {"broker": {"username": getpass.getuser(), "vhost": "host4testing"}} + actual_config = {"broker": {"vhost": "host4testing"}} + load_default_user_names(actual_config) + + # Ensure that the username was set in the call to load_default_user_names + assert actual_config == expected_config + + +def test_load_default_user_names_no_vhost(): + """ + Test the `load_default_user_names` function with no vhost. In other words, in this + test the config dict will have username already set for the broker but not a vhost. + """ + expected_config = {"broker": {"username": "default", "vhost": getpass.getuser()}} + actual_config = {"broker": {"username": "default"}} + load_default_user_names(actual_config) + + # Ensure that the vhost was set in the call to load_default_user_names + assert actual_config == expected_config + + +def test_load_default_celery_nothing_to_load(): + """ + Test the `load_default_celery` function with nothing to load. In other words, in this + test the config dict will have a celery entry containing omit_queue_tag, queue_tag, and + override. We'll create the dict then make a copy of it to test against after calling + the function. + """ + actual_config = {"celery": {"omit_queue_tag": False, "queue_tag": "[merlin]_", "override": None}} + expected_config = deepcopy(actual_config) + assert actual_config is not expected_config + + load_default_celery(actual_config) + + # Ensure that nothing was modified after our call to load_default_celery + assert actual_config == expected_config + + +def test_load_default_celery_no_omit_queue_tag(): + """ + Test the `load_default_celery` function with no omit_queue_tag. The function should + create a default entry of False for this. + """ + actual_config = {"celery": {"queue_tag": "[merlin]_", "override": None}} + expected_config = {"celery": {"omit_queue_tag": False, "queue_tag": "[merlin]_", "override": None}} + load_default_celery(actual_config) + + # Ensure that the omit_queue_tag was set in the call to load_default_celery + assert actual_config == expected_config + + +def test_load_default_celery_no_queue_tag(): + """ + Test the `load_default_celery` function with no queue_tag. The function should + create a default entry of '[merlin]_' for this. + """ + actual_config = {"celery": {"omit_queue_tag": False, "override": None}} + expected_config = {"celery": {"omit_queue_tag": False, "queue_tag": "[merlin]_", "override": None}} + load_default_celery(actual_config) + + # Ensure that the queue_tag was set in the call to load_default_celery + assert actual_config == expected_config + + +def test_load_default_celery_no_override(): + """ + Test the `load_default_celery` function with no override. The function should + create a default entry of None for this. + """ + actual_config = {"celery": {"omit_queue_tag": False, "queue_tag": "[merlin]_"}} + expected_config = {"celery": {"omit_queue_tag": False, "queue_tag": "[merlin]_", "override": None}} + load_default_celery(actual_config) + + # Ensure that the override was set in the call to load_default_celery + assert actual_config == expected_config + + +def test_load_default_celery_no_celery_block(): + """ + Test the `load_default_celery` function with no celery block. The function should + create a default entry of + {"celery": {"omit_queue_tag": False, "queue_tag": "[merlin]_", "override": None}} for this. + """ + actual_config = {} + expected_config = {"celery": {"omit_queue_tag": False, "queue_tag": "[merlin]_", "override": None}} + load_default_celery(actual_config) + + # Ensure that the celery block was set in the call to load_default_celery + assert actual_config == expected_config + + +def test_load_defaults(): + """ + Test that the `load_defaults` function loads the user names and the celery block properly. + """ + actual_config = {"broker": {}} + expected_config = { + "broker": {"username": getpass.getuser(), "vhost": getpass.getuser()}, + "celery": {"omit_queue_tag": False, "queue_tag": "[merlin]_", "override": None}, + } + load_defaults(actual_config) + + assert actual_config == expected_config + + +def test_get_config(temp_output_dir: str): + """ + Test the `get_config` function. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # Create the configfile directory and put an app.yaml file there + create_configfile_dir(temp_output_dir) + configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_app_yaml(configfile_dir) + + # Load up the contents of the dummy app.yaml file that we copied + with open(DUMMY_APP_FILEPATH, "r") as dummy_app_file: + expected = yaml.load(dummy_app_file, yaml.Loader) + + # Add in default settings that should be added + expected["celery"]["omit_queue_tag"] = False + expected["celery"]["queue_tag"] = "[merlin]_" + + actual = get_config(configfile_dir) + + assert actual == expected + + +def test_get_config_invalid_path(): + """ + Test the `get_config` function with an invalid path. This should raise a ValueError. + """ + with pytest.raises(ValueError) as excinfo: + get_config("invalid/path") + + assert "Cannot find a merlin config file!" in str(excinfo.value) + + +def test_is_debug_no_merlin_debug(): + """ + Test the `is_debug` function without having MERLIN_DEBUG in the environment. + This should return False. + """ + + # Delete the current val of MERLIN_DEBUG and store it (if there is one) + reset_merlin_debug = False + debug_val = None + if "MERLIN_DEBUG" in os.environ: + debug_val = copy(os.environ["MERLIN_DEBUG"]) + del os.environ["MERLIN_DEBUG"] + reset_merlin_debug = True + + # Run the test + try: + assert is_debug() is False + except AssertionError as exc: + # Make sure to reset the value of MERLIN_DEBUG even if the test fails + if reset_merlin_debug: + os.environ["MERLIN_DEBUG"] = debug_val + raise AssertionError from exc + + # Reset the value of MERLIN_DEBUG + if reset_merlin_debug: + os.environ["MERLIN_DEBUG"] = debug_val + + +def test_is_debug_with_merlin_debug(): + """ + Test the `is_debug` function with having MERLIN_DEBUG in the environment. + This should return True. + """ + + # Grab the current value of MERLIN_DEBUG if there is one + reset_merlin_debug = False + debug_val = None + if "MERLIN_DEBUG" in os.environ and int(os.environ["MERLIN_DEBUG"]) != 1: + debug_val = copy(os.environ["MERLIN_DEBUG"]) + reset_merlin_debug = True + + # Set the MERLIN_DEBUG value to be 1 + os.environ["MERLIN_DEBUG"] = "1" + + try: + assert is_debug() is True + except AssertionError as exc: + # Make sure to reset the value of MERLIN_DEBUG even if the test fails + if reset_merlin_debug: + os.environ["MERLIN_DEBUG"] = debug_val + raise AssertionError from exc + + # Reset the value of MERLIN_DEBUG + if reset_merlin_debug: + os.environ["MERLIN_DEBUG"] = debug_val + + +def test_default_config_info(temp_output_dir: str): + """ + Test the `default_config_info` function. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # Create the configfile directory and put an app.yaml file there + create_configfile_dir(temp_output_dir) + configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_app_yaml(configfile_dir) + cwd = os.getcwd() + os.chdir(configfile_dir) + + # Delete the current val of MERLIN_DEBUG and store it (if there is one) + reset_merlin_debug = False + debug_val = None + if "MERLIN_DEBUG" in os.environ: + debug_val = copy(os.environ["MERLIN_DEBUG"]) + del os.environ["MERLIN_DEBUG"] + reset_merlin_debug = True + + # Create the merlin home directory if it doesn't already exist + merlin_home = f"{os.path.expanduser('~')}/.merlin" + remove_merlin_home = False + if not os.path.exists(merlin_home): + os.mkdir(merlin_home) + remove_merlin_home = True + + # Run the test + try: + expected = { + "config_file": f"{configfile_dir}/app.yaml", + "is_debug": False, + "merlin_home": merlin_home, + "merlin_home_exists": True, + } + actual = default_config_info() + assert actual == expected + except AssertionError as exc: + # Make sure to reset values even if the test fails + if reset_merlin_debug: + os.environ["MERLIN_DEBUG"] = debug_val + if remove_merlin_home: + os.rmdir(merlin_home) + raise AssertionError from exc + + # Reset values if necessary + if reset_merlin_debug: + os.environ["MERLIN_DEBUG"] = debug_val + if remove_merlin_home: + os.rmdir(merlin_home) + + os.chdir(cwd) + + +def test_get_cert_file_all_valid_args(mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + """ + Test the `get_cert_file` function with all valid arguments. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param merlin_server_dir: The path to the temporary merlin server directory that's housing our cert files + """ + expected = f"{merlin_server_dir}/{CERT_FILES['ssl_key']}" + actual = get_cert_file( + server_type="Results Backend", config=CONFIG.results_backend, cert_name="keyfile", cert_path=merlin_server_dir + ) + assert actual == expected + + +def test_get_cert_file_invalid_cert_name(mysql_results_backend_config: "fixture", merlin_server_dir: str): # noqa: F821 + """ + Test the `get_cert_file` function with an invalid cert_name argument. This should just return None. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param merlin_server_dir: The path to the temporary merlin server directory that's housing our cert files + """ + actual = get_cert_file( + server_type="Results Backend", config=CONFIG.results_backend, cert_name="invalid", cert_path=merlin_server_dir + ) + assert actual is None + + +def test_get_cert_file_nonexistent_cert_path( + mysql_results_backend_config: "fixture", temp_output_dir: str, merlin_server_dir: str # noqa: F821 +): + """ + Test the `get_cert_file` function with cert_path argument that doesn't exist. + This should still return the nonexistent path at the root of our temporary directory for testing. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :param merlin_server_dir: The path to the temporary merlin server directory that's housing our cert files + """ + CONFIG.results_backend.certfile = "new_certfile.pem" + expected = f"{temp_output_dir}/new_certfile.pem" + actual = get_cert_file( + server_type="Results Backend", config=CONFIG.results_backend, cert_name="certfile", cert_path=merlin_server_dir + ) + assert actual == expected + + +def test_get_ssl_entries_required_certs(mysql_results_backend_config: "fixture", temp_output_dir: str): # noqa: F821 + """ + Test the `get_ssl_entries` function with mysql as the results_backend. For this test we'll make + cert reqs be required. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + CONFIG.results_backend.cert_reqs = "required" + + expected = { + "ssl_key": f"{temp_output_dir}/{CERT_FILES['ssl_key']}", + "ssl_cert": f"{temp_output_dir}/{CERT_FILES['ssl_cert']}", + "ssl_ca": f"{temp_output_dir}/{CERT_FILES['ssl_ca']}", + "cert_reqs": ssl.CERT_REQUIRED, + } + actual = get_ssl_entries( + server_type="Results Backend", server_name="mysql", server_config=CONFIG.results_backend, cert_path=temp_output_dir + ) + assert expected == actual + + +def test_get_ssl_entries_optional_certs(mysql_results_backend_config: "fixture", temp_output_dir: str): # noqa: F821 + """ + Test the `get_ssl_entries` function with mysql as the results_backend. For this test we'll make + cert reqs be optional. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + CONFIG.results_backend.cert_reqs = "optional" + + expected = { + "ssl_key": f"{temp_output_dir}/{CERT_FILES['ssl_key']}", + "ssl_cert": f"{temp_output_dir}/{CERT_FILES['ssl_cert']}", + "ssl_ca": f"{temp_output_dir}/{CERT_FILES['ssl_ca']}", + "cert_reqs": ssl.CERT_OPTIONAL, + } + actual = get_ssl_entries( + server_type="Results Backend", server_name="mysql", server_config=CONFIG.results_backend, cert_path=temp_output_dir + ) + assert expected == actual + + +def test_get_ssl_entries_none_certs(mysql_results_backend_config: "fixture", temp_output_dir: str): # noqa: F821 + """ + Test the `get_ssl_entries` function with mysql as the results_backend. For this test we won't require + any cert reqs. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + CONFIG.results_backend.cert_reqs = "none" + + expected = { + "ssl_key": f"{temp_output_dir}/{CERT_FILES['ssl_key']}", + "ssl_cert": f"{temp_output_dir}/{CERT_FILES['ssl_cert']}", + "ssl_ca": f"{temp_output_dir}/{CERT_FILES['ssl_ca']}", + "cert_reqs": ssl.CERT_NONE, + } + actual = get_ssl_entries( + server_type="Results Backend", server_name="mysql", server_config=CONFIG.results_backend, cert_path=temp_output_dir + ) + assert expected == actual + + +def test_get_ssl_entries_omit_certs(mysql_results_backend_config: "fixture", temp_output_dir: str): # noqa: F821 + """ + Test the `get_ssl_entries` function with mysql as the results_backend. For this test we'll completely + omit the cert_reqs option + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + del CONFIG.results_backend.cert_reqs + + expected = { + "ssl_key": f"{temp_output_dir}/{CERT_FILES['ssl_key']}", + "ssl_cert": f"{temp_output_dir}/{CERT_FILES['ssl_cert']}", + "ssl_ca": f"{temp_output_dir}/{CERT_FILES['ssl_ca']}", + "cert_reqs": ssl.CERT_REQUIRED, + } + actual = get_ssl_entries( + server_type="Results Backend", server_name="mysql", server_config=CONFIG.results_backend, cert_path=temp_output_dir + ) + assert expected == actual + + +def test_get_ssl_entries_with_ssl_protocol(mysql_results_backend_config: "fixture", temp_output_dir: str): # noqa: F821 + """ + Test the `get_ssl_entries` function with mysql as the results_backend. For this test we'll add in a + dummy ssl_protocol value that should get added to the dict that's output. + + :param mysql_results_backend_config: A fixture to set the CONFIG object to a test configuration that we'll use here + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + protocol = "test_protocol" + CONFIG.results_backend.ssl_protocol = protocol + + expected = { + "ssl_key": f"{temp_output_dir}/{CERT_FILES['ssl_key']}", + "ssl_cert": f"{temp_output_dir}/{CERT_FILES['ssl_cert']}", + "ssl_ca": f"{temp_output_dir}/{CERT_FILES['ssl_ca']}", + "cert_reqs": ssl.CERT_NONE, + "ssl_protocol": protocol, + } + actual = get_ssl_entries( + server_type="Results Backend", server_name="mysql", server_config=CONFIG.results_backend, cert_path=temp_output_dir + ) + assert expected == actual + + +def test_process_ssl_map_mysql(): + """Test the `process_ssl_map` function with mysql as the server name.""" + expected = {"keyfile": "ssl_key", "certfile": "ssl_cert", "ca_certs": "ssl_ca"} + actual = process_ssl_map("mysql") + assert actual == expected + + +def test_process_ssl_map_rediss(): + """Test the `process_ssl_map` function with rediss as the server name.""" + expected = { + "keyfile": "ssl_keyfile", + "certfile": "ssl_certfile", + "ca_certs": "ssl_ca_certs", + "cert_reqs": "ssl_cert_reqs", + } + actual = process_ssl_map("rediss") + assert actual == expected + + +def test_merge_sslmap_all_keys_present(): + """ + Test the `merge_sslmap` function with all keys from server_ssl in ssl_map. + We'll assume we're using a rediss server for this. + """ + expected = { + "ssl_keyfile": "/path/to/keyfile", + "ssl_certfile": "/path/to/certfile", + "ssl_ca_certs": "/path/to/ca_file", + "ssl_cert_reqs": ssl.CERT_NONE, + } + test_server_ssl = { + "keyfile": "/path/to/keyfile", + "certfile": "/path/to/certfile", + "ca_certs": "/path/to/ca_file", + "cert_reqs": ssl.CERT_NONE, + } + test_ssl_map = { + "keyfile": "ssl_keyfile", + "certfile": "ssl_certfile", + "ca_certs": "ssl_ca_certs", + "cert_reqs": "ssl_cert_reqs", + } + actual = merge_sslmap(test_server_ssl, test_ssl_map) + assert actual == expected + + +def test_merge_sslmap_some_keys_present(): + """ + Test the `merge_sslmap` function with some keys from server_ssl in ssl_map and others not. + We'll assume we're using a rediss server for this. + """ + expected = { + "ssl_keyfile": "/path/to/keyfile", + "ssl_certfile": "/path/to/certfile", + "ssl_ca_certs": "/path/to/ca_file", + "ssl_cert_reqs": ssl.CERT_NONE, + "new_key": "new_val", + "second_new_key": "second_new_val", + } + test_server_ssl = { + "keyfile": "/path/to/keyfile", + "certfile": "/path/to/certfile", + "ca_certs": "/path/to/ca_file", + "cert_reqs": ssl.CERT_NONE, + "new_key": "new_val", + "second_new_key": "second_new_val", + } + test_ssl_map = { + "keyfile": "ssl_keyfile", + "certfile": "ssl_certfile", + "ca_certs": "ssl_ca_certs", + "cert_reqs": "ssl_cert_reqs", + } + actual = merge_sslmap(test_server_ssl, test_ssl_map) + assert actual == expected From 9669bd0cc559be15c689512016a14fe8cfd3c544 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 19 Dec 2023 10:40:13 -0800 Subject: [PATCH 22/44] add tests for the utils.py file in config/ --- tests/unit/config/test_utils.py | 83 +++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 tests/unit/config/test_utils.py diff --git a/tests/unit/config/test_utils.py b/tests/unit/config/test_utils.py new file mode 100644 index 000000000..a02bc1ff1 --- /dev/null +++ b/tests/unit/config/test_utils.py @@ -0,0 +1,83 @@ +""" +Tests for the merlin/config/utils.py module. +""" + +import pytest + +from merlin.config.configfile import CONFIG +from merlin.config.utils import Priority, get_priority, is_rabbit_broker, is_redis_broker + + +def test_is_rabbit_broker(): + """Test the `is_rabbit_broker` by passing in rabbit as the broker""" + assert is_rabbit_broker("rabbitmq") is True + assert is_rabbit_broker("amqp") is True + assert is_rabbit_broker("amqps") is True + + +def test_is_rabbit_broker_invalid(): + """Test the `is_rabbit_broker` by passing in an invalid broker""" + assert is_rabbit_broker("redis") is False + assert is_rabbit_broker("") is False + + +def test_is_redis_broker(): + """Test the `is_redis_broker` by passing in redis as the broker""" + assert is_redis_broker("redis") is True + assert is_redis_broker("rediss") is True + assert is_redis_broker("redis+socket") is True + + +def test_is_redis_broker_invalid(): + """Test the `is_redis_broker` by passing in an invalid broker""" + assert is_redis_broker("rabbitmq") is False + assert is_redis_broker("") is False + + +def test_get_priority_rabbit_broker(rabbit_broker_config: "fixture"): # noqa: F821 + """ + Test the `get_priority` function with rabbit as the broker. + Low priority for rabbit is 1 and high is 10. + + :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + assert get_priority(Priority.LOW) == 1 + assert get_priority(Priority.MID) == 5 + assert get_priority(Priority.HIGH) == 10 + + +def test_get_priority_redis_broker(redis_broker_config: "fixture"): # noqa: F821 + """ + Test the `get_priority` function with redis as the broker. + Low priority for redis is 10 and high is 1. + + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + assert get_priority(Priority.LOW) == 10 + assert get_priority(Priority.MID) == 5 + assert get_priority(Priority.HIGH) == 1 + + +def test_get_priority_invalid_broker(redis_broker_config: "fixture"): # noqa: F821 + """ + Test the `get_priority` function with an invalid broker. + This should raise a ValueError. + + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + CONFIG.broker.name = "invalid" + with pytest.raises(ValueError) as excinfo: + get_priority(Priority.LOW) + assert "Function get_priority has reached unknown state! Maybe unsupported broker invalid?" in str(excinfo.value) + + +def test_get_priority_invalid_priority(redis_broker_config: "fixture"): # noqa: F821 + """ + Test the `get_priority` function with an invalid priority. + This should raise a TypeError. + + :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here + """ + with pytest.raises(TypeError) as excinfo: + get_priority("invalid_priority") + assert "Unrecognized priority 'invalid_priority'!" in str(excinfo.value) From 2a209e58056f5c2091cc0f066869ae1dff4dcb70 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 19 Dec 2023 13:10:02 -0800 Subject: [PATCH 23/44] create utilities file and constants file --- tests/conftest.py | 60 +---------------------- tests/constants.py | 10 ++++ tests/unit/config/test_broker.py | 3 +- tests/unit/config/test_results_backend.py | 3 +- tests/unit/config/utils.py | 23 --------- tests/utils.py | 35 +++++++++++++ 6 files changed, 51 insertions(+), 83 deletions(-) create mode 100644 tests/constants.py delete mode 100644 tests/unit/config/utils.py create mode 100644 tests/utils.py diff --git a/tests/conftest.py b/tests/conftest.py index 446e3118b..20749d4cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,16 +42,10 @@ from celery.canvas import Signature from merlin.config.configfile import CONFIG +from tests.constants import SERVER_PASS, CERT_FILES from tests.context_managers.celery_workers_manager import CeleryWorkersManager from tests.context_managers.server_manager import RedisServerManager - - -SERVER_PASS = "merlin-test-server" -CERT_FILES = { - "ssl_cert": "test-rabbit-client-cert.pem", - "ssl_ca": "test-mysql-ca-cert.pem", - "ssl_key": "test-rabbit-client-key.pem", -} +from tests.utils import create_cert_files, create_pass_file ####################################### @@ -59,18 +53,6 @@ ####################################### -def create_pass_file(pass_filepath: str): - """ - Check if a password file already exists (it will if the redis server has been started) - and if it hasn't then create one and write the password to the file. - - :param pass_filepath: The path to the password file that we need to check for/create - """ - if not os.path.exists(pass_filepath): - with open(pass_filepath, "w") as pass_file: - pass_file.write(SERVER_PASS) - - def create_encryption_file(key_filepath: str, encryption_key: bytes, app_yaml_filepath: str = None): """ Check if an encryption file already exists (it will if the redis server has been started) @@ -97,44 +79,6 @@ def create_encryption_file(key_filepath: str, encryption_key: bytes, app_yaml_fi yaml.dump(app_yaml, app_yaml_file) -def create_cert_files(cert_filepath: str, cert_files: Dict[str, str]): - """ - Check if cert files already exist and if they don't then create them. - - :param cert_filepath: The path to the cert files - :param cert_files: A dict of certification files to create - """ - for cert_file in cert_files.values(): - full_cert_filepath = f"{cert_filepath}/{cert_file}" - if not os.path.exists(full_cert_filepath): - with open(full_cert_filepath, "w"): - pass - - -def set_config(broker: Dict[str, str], results_backend: Dict[str, str]): - """ - Given configuration options for the broker and results_backend, update - the CONFIG object. - - :param broker: A dict of the configuration settings for the broker - :param results_backend: A dict of configuration settings for the results_backend - """ - # Set the broker configuration for testing - CONFIG.broker.password = broker["password"] - CONFIG.broker.port = broker["port"] - CONFIG.broker.server = broker["server"] - CONFIG.broker.username = broker["username"] - CONFIG.broker.vhost = broker["vhost"] - CONFIG.broker.name = broker["name"] - - # Set the results_backend configuration for testing - CONFIG.results_backend.password = results_backend["password"] - CONFIG.results_backend.port = results_backend["port"] - CONFIG.results_backend.server = results_backend["server"] - CONFIG.results_backend.username = results_backend["username"] - CONFIG.results_backend.encryption_key = results_backend["encryption_key"] - - ####################################### ######### Fixture Definitions ######### ####################################### diff --git a/tests/constants.py b/tests/constants.py new file mode 100644 index 000000000..a2b354146 --- /dev/null +++ b/tests/constants.py @@ -0,0 +1,10 @@ +""" +This module will store constants that will be used throughout our test suite. +""" + +SERVER_PASS = "merlin-test-server" +CERT_FILES = { + "ssl_cert": "test-rabbit-client-cert.pem", + "ssl_ca": "test-mysql-ca-cert.pem", + "ssl_key": "test-rabbit-client-key.pem", +} \ No newline at end of file diff --git a/tests/unit/config/test_broker.py b/tests/unit/config/test_broker.py index 490b47649..8af1dda75 100644 --- a/tests/unit/config/test_broker.py +++ b/tests/unit/config/test_broker.py @@ -18,7 +18,8 @@ read_file, ) from merlin.config.configfile import CONFIG -from tests.conftest import SERVER_PASS, create_pass_file +from tests.constants import SERVER_PASS +from tests.utils import create_pass_file def test_read_file(merlin_server_dir: str): diff --git a/tests/unit/config/test_results_backend.py b/tests/unit/config/test_results_backend.py index 59e53a5ae..314df6ce7 100644 --- a/tests/unit/config/test_results_backend.py +++ b/tests/unit/config/test_results_backend.py @@ -19,7 +19,8 @@ get_redis, get_ssl_config, ) -from tests.conftest import CERT_FILES, SERVER_PASS, create_cert_files, create_pass_file +from tests.constants import CERT_FILES, SERVER_PASS +from tests.utils import create_cert_files, create_pass_file RESULTS_BACKEND_DIR = "{temp_output_dir}/test_results_backend" diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py deleted file mode 100644 index 11510c5fd..000000000 --- a/tests/unit/config/utils.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Utils module for common test functionality. -""" -import os - - -def mkfile(tmpdir, filename, content=""): - """ - A simple function for creating a file and returning the path. This is to - abstract out file creation logic in the tests. - - :param tmpdir: (str) The path to the temp directory. - :param filename: (str) The name of the file. - :param contents: (str) Optional contents to write to the file. Defaults to - an empty string. - :returns: (str) The appended path of the given tempdir and filename. - """ - filepath = os.path.join(tmpdir, filename) - - with open(filepath, "w") as f: - f.write(content) - - return filepath diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..51fbd56cc --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,35 @@ +""" +Utility functions for our test suite. +""" +import os +from typing import Dict + +from tests.constants import SERVER_PASS + + +def create_pass_file(pass_filepath: str): + """ + Check if a password file already exists (it will if the redis server has been started) + and if it hasn't then create one and write the password to the file. + + :param pass_filepath: The path to the password file that we need to check for/create + """ + if not os.path.exists(pass_filepath): + with open(pass_filepath, "w") as pass_file: + pass_file.write(SERVER_PASS) + + +def create_cert_files(cert_filepath: str, cert_files: Dict[str, str]): + """ + Check if cert files already exist and if they don't then create them. + + :param cert_filepath: The path to the cert files + :param cert_files: A dict of certification files to create + """ + for cert_file in cert_files.values(): + full_cert_filepath = f"{cert_filepath}/{cert_file}" + if not os.path.exists(full_cert_filepath): + with open(full_cert_filepath, "w"): + pass + + From ddb0588cd9170a077ded446a6fb9072a4e563add Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 19 Dec 2023 16:42:23 -0800 Subject: [PATCH 24/44] move create_dir function to utils.py --- tests/unit/config/test_configfile.py | 24 +++++++----------------- tests/utils.py | 9 +++++++++ 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/unit/config/test_configfile.py b/tests/unit/config/test_configfile.py index 5d635e79b..aeb1da941 100644 --- a/tests/unit/config/test_configfile.py +++ b/tests/unit/config/test_configfile.py @@ -25,7 +25,8 @@ merge_sslmap, process_ssl_map, ) -from tests.conftest import CERT_FILES +from tests.constants import CERT_FILES +from tests.utils import create_dir CONFIGFILE_DIR = "{temp_output_dir}/test_configfile" @@ -33,17 +34,6 @@ DUMMY_APP_FILEPATH = f"{os.path.dirname(__file__)}/dummy_app.yaml" -def create_configfile_dir(temp_output_dir: str): - """ - Create the configfile dir if it doesn't exist yet. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - full_configfile_dirpath = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) - if not os.path.exists(full_configfile_dirpath): - os.mkdir(full_configfile_dirpath) - - def create_app_yaml(app_yaml_filepath: str): """ Create a dummy app.yaml file at `app_yaml_filepath`. @@ -61,8 +51,8 @@ def test_load_config(temp_output_dir: str): :param temp_output_dir: The path to the temporary output directory we'll be using for this test run """ - create_configfile_dir(temp_output_dir) configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_dir(configfile_dir) create_app_yaml(configfile_dir) with open(DUMMY_APP_FILEPATH, "r") as dummy_app_file: @@ -85,8 +75,8 @@ def test_find_config_file_valid_path(temp_output_dir: str): :param temp_output_dir: The path to the temporary output directory we'll be using for this test run """ - create_configfile_dir(temp_output_dir) configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_dir(configfile_dir) create_app_yaml(configfile_dir) assert find_config_file(configfile_dir) == f"{configfile_dir}/app.yaml" @@ -109,8 +99,8 @@ def test_find_config_file_local_path(temp_output_dir: str): """ # Create the configfile directory and put an app.yaml file there - create_configfile_dir(temp_output_dir) configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_dir(configfile_dir) create_app_yaml(configfile_dir) # Move into the configfile directory and run the test @@ -330,8 +320,8 @@ def test_get_config(temp_output_dir: str): """ # Create the configfile directory and put an app.yaml file there - create_configfile_dir(temp_output_dir) configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_dir(configfile_dir) create_app_yaml(configfile_dir) # Load up the contents of the dummy app.yaml file that we copied @@ -422,8 +412,8 @@ def test_default_config_info(temp_output_dir: str): """ # Create the configfile directory and put an app.yaml file there - create_configfile_dir(temp_output_dir) configfile_dir = CONFIGFILE_DIR.format(temp_output_dir=temp_output_dir) + create_dir(configfile_dir) create_app_yaml(configfile_dir) cwd = os.getcwd() os.chdir(configfile_dir) diff --git a/tests/utils.py b/tests/utils.py index 51fbd56cc..3a75622b8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,3 +33,12 @@ def create_cert_files(cert_filepath: str, cert_files: Dict[str, str]): pass +def create_dir(dirpath: str): + """ + Check if `dirpath` exists and if it doesn't then create it. + + :param dirpath: The directory to create + """ + if not os.path.exists(dirpath): + os.mkdir(dirpath) + From e5bc0fe239df705ce536ae67e5264c737e3f4176 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 19 Dec 2023 16:42:52 -0800 Subject: [PATCH 25/44] add tests for merlin/examples/generator.py --- merlin/examples/generator.py | 8 + setup.cfg | 1 + tests/unit/test_examples_generator.py | 575 ++++++++++++++++++++++++++ 3 files changed, 584 insertions(+) create mode 100644 tests/unit/test_examples_generator.py diff --git a/merlin/examples/generator.py b/merlin/examples/generator.py index 2fa5e61ce..f1e58f430 100644 --- a/merlin/examples/generator.py +++ b/merlin/examples/generator.py @@ -48,6 +48,14 @@ EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workflows") +# TODO modify the example command to eliminate redundancy +# - e.g. running `merlin example flux_local` will produce the same output +# as running `merlin example flux_par` or `merlin example flux_par_restart`. +# This should just be `merlin example flux`. +# - restart and restart delay should be one example +# - feature demo and remote feature demo should be one example +# - all openfoam examples should just be under one openfoam label + def gather_example_dirs(): """Get all the example directories""" diff --git a/setup.cfg b/setup.cfg index 6b4278799..77ac2d84f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,3 +31,4 @@ ignore_missing_imports=true omit = merlin/ascii_art.py merlin/config/celeryconfig.py + merlin/examples/examples.py diff --git a/tests/unit/test_examples_generator.py b/tests/unit/test_examples_generator.py new file mode 100644 index 000000000..7d7ccc5bf --- /dev/null +++ b/tests/unit/test_examples_generator.py @@ -0,0 +1,575 @@ +""" +Tests for the `merlin/examples/generator.py` module. +""" +import os +import pathlib +from typing import List + +from tabulate import tabulate + +from merlin.examples.generator import ( + EXAMPLES_DIR, + gather_all_examples, + gather_example_dirs, + list_examples, + setup_example, + write_example +) +from tests.utils import create_dir + + +EXAMPLES_GENERATOR_DIR = "{temp_output_dir}/examples_generator" + + +def test_gather_example_dirs(): + """Test the `gather_example_dirs` function.""" + example_workflows = [ + "feature_demo", + "flux", + "hello", + "hpc_demo", + "iterative_demo", + "lsf", + "null_spec", + "openfoam_wf", + "openfoam_wf_no_docker", + "openfoam_wf_singularity", + "optimization", + "remote_feature_demo", + "restart", + "restart_delay", + "simple_chain", + "slurm" + ] + expected = {} + for wf_dir in example_workflows: + expected[wf_dir] = wf_dir + actual = gather_example_dirs() + assert actual == expected + + +def test_gather_all_examples(): + """Test the `gather_all_examples` function.""" + expected = [ + f"{EXAMPLES_DIR}/feature_demo/feature_demo.yaml", + f"{EXAMPLES_DIR}/flux/flux_local.yaml", + f"{EXAMPLES_DIR}/flux/flux_par_restart.yaml", + f"{EXAMPLES_DIR}/flux/flux_par.yaml", + f"{EXAMPLES_DIR}/flux/paper.yaml", + f"{EXAMPLES_DIR}/hello/hello_samples.yaml", + f"{EXAMPLES_DIR}/hello/hello.yaml", + f"{EXAMPLES_DIR}/hello/my_hello.yaml", + f"{EXAMPLES_DIR}/hpc_demo/hpc_demo.yaml", + f"{EXAMPLES_DIR}/iterative_demo/iterative_demo.yaml", + f"{EXAMPLES_DIR}/lsf/lsf_par_srun.yaml", + f"{EXAMPLES_DIR}/lsf/lsf_par.yaml", + f"{EXAMPLES_DIR}/null_spec/null_chain.yaml", + f"{EXAMPLES_DIR}/null_spec/null_spec.yaml", + f"{EXAMPLES_DIR}/openfoam_wf/openfoam_wf_template.yaml", + f"{EXAMPLES_DIR}/openfoam_wf/openfoam_wf.yaml", + f"{EXAMPLES_DIR}/openfoam_wf_no_docker/openfoam_wf_no_docker_template.yaml", + f"{EXAMPLES_DIR}/openfoam_wf_no_docker/openfoam_wf_no_docker.yaml", + f"{EXAMPLES_DIR}/openfoam_wf_singularity/openfoam_wf_singularity.yaml", + f"{EXAMPLES_DIR}/optimization/optimization_basic.yaml", + f"{EXAMPLES_DIR}/remote_feature_demo/remote_feature_demo.yaml", + f"{EXAMPLES_DIR}/restart/restart.yaml", + f"{EXAMPLES_DIR}/restart_delay/restart_delay.yaml", + f"{EXAMPLES_DIR}/simple_chain/simple_chain.yaml", + f"{EXAMPLES_DIR}/slurm/slurm_par_restart.yaml", + f"{EXAMPLES_DIR}/slurm/slurm_par.yaml" + ] + actual = gather_all_examples() + assert sorted(actual) == sorted(expected) + + +def test_write_example_dir(temp_output_dir: str): + """ + Test the `write_example` function with the src_path as a directory. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) + dir_to_copy = f"{EXAMPLES_DIR}/feature_demo/" + + write_example(dir_to_copy, generator_dir) + assert sorted(os.listdir(dir_to_copy)) == sorted(os.listdir(generator_dir)) + + +def test_write_example_file(temp_output_dir: str): + """ + Test the `write_example` function with the src_path as a file. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) + create_dir(generator_dir) + + dst_path = f"{generator_dir}/flux_par.yaml" + file_to_copy = f"{EXAMPLES_DIR}/flux/flux_par.yaml" + + write_example(file_to_copy, generator_dir) + assert os.path.exists(dst_path) + + +def test_list_examples(): + """Test the `list_examples` function to see if it gives us all of the examples that we want.""" + expected_headers = ["name", "description"] + expected_rows = [ + ["openfoam_wf_no_docker", "A parameter study that includes initializing, running,\n" \ + "post-processing, collecting, learning and vizualizing OpenFOAM runs\n" \ + "without using docker."], + ["optimization_basic", "Design Optimization Template\n" \ + "To use,\n" \ + "1. Specify the first three variables here (N_DIMS, TEST_FUNCTION, DEBUG)\n" \ + "2. Run the template_config file in current directory using `python template_config.py`\n" \ + "3. Merlin run as usual (merlin run optimization.yaml)\n" \ + "* MAX_ITER and the N_SAMPLES options use default values unless using DEBUG mode\n" \ + "* BOUNDS_X and UNCERTS_X are configured using the template_config.py scripts"], + ["feature_demo", "Run 10 hello worlds."], + ["flux_local", "Run a scan through Merlin/Maestro"], + ["flux_par", "A simple ensemble of parallel MPI jobs run by flux."], + ["flux_par_restart", "A simple ensemble of parallel MPI jobs run by flux."], + ["paper_flux", "Use flux to run single core MPI jobs and record timings."], + ["lsf_par", "A simple ensemble of parallel MPI jobs run by lsf (jsrun)."], + ["lsf_par_srun", "A simple ensemble of parallel MPI jobs run by lsf using the srun wrapper (srun)."], + ["restart", "A simple ensemble of with restarts."], + ["restart_delay", "A simple ensemble of with restart delay times."], + ["simple_chain", "test to see that chains are not run in parallel"], + ["slurm_par", "A simple ensemble of parallel MPI jobs run by slurm (srun)."], + ["slurm_par_restart", "A simple ensemble of parallel MPI jobs run by slurm (srun)."], + ["remote_feature_demo", "Run 10 hello worlds."], + ["hello", "a very simple merlin workflow"], + ["hello_samples", "a very simple merlin workflow, with samples"], + ["hpc_demo", "Demo running a workflow on HPC machines"], + ["openfoam_wf", "A parameter study that includes initializing, running,\n" \ + "post-processing, collecting, learning and visualizing OpenFOAM runs\n" \ + "using docker."], + ["openfoam_wf_singularity", "A parameter study that includes initializing, running,\n" \ + "post-processing, collecting, learning and visualizing OpenFOAM runs\n" \ + "using singularity."], + ["null_chain", "Run N_SAMPLES steps of TIME seconds each at CONC concurrency.\n" \ + "May be used to measure overhead in merlin.\n" \ + "Iterates thru a chain of workflows."], + ["null_spec", "run N_SAMPLES null steps at CONC concurrency for TIME seconds each. May be used to measure overhead in merlin."], + ["iterative_demo", "Demo of a workflow with self driven iteration/looping"], + ] + expected = "\n" + tabulate(expected_rows, expected_headers) + "\n" + actual = list_examples() + assert actual == expected + + +def test_setup_example_invalid_name(): + """ + Test the `setup_example` function with an invalid example name. + This should just return None. + """ + assert setup_example("invalid_example_name", None) is None + + +def test_setup_example_no_outdir(temp_output_dir: str): + """ + Test the `setup_example` function with an invalid example name. + This should create a directory with the example name (in this case hello) + and copy all of the example contents to this folder. + We'll create a directory specifically for this test and move into it so that + the `setup_example` function creates the hello/ subdirectory in a directory with + the name of this test (setup_no_outdir). + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + cwd = os.getcwd() + + # Create the temp path to store this setup and move into that directory + generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) + create_dir(generator_dir) + setup_example_dir = os.path.join(generator_dir, "setup_no_outdir") + create_dir(setup_example_dir) + os.chdir(setup_example_dir) + + # This should still work and return to us the name of the example + try: + assert setup_example("hello", None) == "hello" + except AssertionError as exc: + os.chdir(cwd) + raise AssertionError from exc + + # All files from this example should be written to a directory with the example name + full_output_path = os.path.join(setup_example_dir, "hello") + expected_files = [ + os.path.join(full_output_path, "hello_samples.yaml"), + os.path.join(full_output_path, "hello.yaml"), + os.path.join(full_output_path, "my_hello.yaml"), + os.path.join(full_output_path, "requirements.txt"), + os.path.join(full_output_path, "make_samples.py"), + ] + try: + for file in expected_files: + assert os.path.exists(file) + except AssertionError as exc: + os.chdir(cwd) + raise AssertionError from exc + + +def test_setup_example_outdir_exists(temp_output_dir: str): + """ + Test the `setup_example` function with an output directory that already exists. + This should just return None. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) + create_dir(generator_dir) + + assert setup_example("hello", generator_dir) is None + + +##################################### +# Tests for setting up each example # +##################################### + + +def run_setup_example(temp_output_dir: str, example_name: str, example_files: List[str], expected_return: str): + """ + Helper function to run tests for the `setup_example` function. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :param example_name: The name of the example to setup + :param example_files: A list of filenames that should be copied by setup_example + :param expected_return: The expected return value from `setup_example` + """ + # Create the temp path to store this setup + generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) + create_dir(generator_dir) + setup_example_dir = os.path.join(generator_dir, f"setup_{example_name}") + + # Ensure that the example name is returned + actual = setup_example(example_name, setup_example_dir) + assert actual == expected_return + + # Ensure all of the files that should've been copied were copied + expected_files = [os.path.join(setup_example_dir, expected_file) for expected_file in example_files] + for file in expected_files: + assert os.path.exists(file) + + +def test_setup_example_feature_demo(temp_output_dir: str): + """ + Test the `setup_example` function for the feature_demo example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "feature_demo" + example_files = [ + ".gitignore", + "feature_demo.yaml", + "requirements.txt", + "scripts/features.json", + "scripts/hello_world.py", + "scripts/pgen.py", + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_flux(temp_output_dir: str): + """ + Test the `setup_example` function for the flux example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_files = [ + "flux_local.yaml", + "flux_par_restart.yaml", + "flux_par.yaml", + "paper.yaml", + "requirements.txt", + "scripts/flux_info.py", + "scripts/hello_sleep.c", + "scripts/hello.c", + "scripts/make_samples.py", + "scripts/paper_workers.sbatch", + "scripts/test_workers.sbatch", + "scripts/workers.sbatch", + "scripts/workers.bsub", + ] + + run_setup_example(temp_output_dir, "flux_local", example_files, "flux") + + +def test_setup_example_lsf(temp_output_dir: str): + """ + Test the `setup_example` function for the lsf example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # TODO should there be a workers.bsub for this example? + example_files = [ + "lsf_par_srun.yaml", + "lsf_par.yaml", + "scripts/hello.c", + "scripts/make_samples.py", + ] + + run_setup_example(temp_output_dir, "lsf_par", example_files, "lsf") + + +def test_setup_example_slurm(temp_output_dir: str): + """ + Test the `setup_example` function for the slurm example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_files = [ + "slurm_par.yaml", + "slurm_par_restart.yaml", + "requirements.txt", + "scripts/hello.c", + "scripts/make_samples.py", + "scripts/test_workers.sbatch", + "scripts/workers.sbatch", + ] + + run_setup_example(temp_output_dir, "slurm_par", example_files, "slurm") + + +def test_setup_example_hello(temp_output_dir: str): + """ + Test the `setup_example` function for the hello example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "hello" + example_files = [ + "hello_samples.yaml", + "hello.yaml", + "my_hello.yaml", + "requirements.txt", + "make_samples.py", + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_hpc(temp_output_dir: str): + """ + Test the `setup_example` function for the hpc_demo example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "hpc_demo" + example_files = [ + "hpc_demo.yaml", + "cumulative_sample_processor.py", + "faker_sample.py", + "sample_collector.py", + "sample_processor.py", + "requirements.txt", + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_iterative(temp_output_dir: str): + """ + Test the `setup_example` function for the iterative_demo example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "iterative_demo" + example_files = [ + "iterative_demo.yaml", + "cumulative_sample_processor.py", + "faker_sample.py", + "sample_collector.py", + "sample_processor.py", + "requirements.txt", + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_null(temp_output_dir: str): + """ + Test the `setup_example` function for the null_spec example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "null_spec" + example_files = [ + "null_spec.yaml", + "null_chain.yaml", + ".gitignore", + "Makefile", + "requirements.txt", + "scripts/aggregate_chain_output.sh", + "scripts/aggregate_output.sh", + "scripts/check_completion.sh", + "scripts/kill_all.sh", + "scripts/launch_chain_job.py", + "scripts/launch_jobs.py", + "scripts/make_samples.py", + "scripts/read_output_chain.py", + "scripts/read_output.py", + "scripts/search.sh", + "scripts/submit_chain.sbatch", + "scripts/submit.sbatch", + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_openfoam(temp_output_dir: str): + """ + Test the `setup_example` function for the openfoam_wf example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "openfoam_wf" + example_files = [ + "openfoam_wf.yaml", + "openfoam_wf_template.yaml", + "README.md", + "requirements.txt", + "scripts/make_samples.py", + "scripts/blockMesh_template.txt", + "scripts/cavity_setup.sh", + "scripts/combine_outputs.py", + "scripts/learn.py", + "scripts/mesh_param_script.py", + "scripts/run_openfoam", + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_openfoam_no_docker(temp_output_dir: str): + """ + Test the `setup_example` function for the openfoam_wf_no_docker example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "openfoam_wf_no_docker" + example_files = [ + "openfoam_wf_no_docker.yaml", + "openfoam_wf_no_docker_template.yaml", + "requirements.txt", + "scripts/make_samples.py", + "scripts/blockMesh_template.txt", + "scripts/cavity_setup.sh", + "scripts/combine_outputs.py", + "scripts/learn.py", + "scripts/mesh_param_script.py", + "scripts/run_openfoam", + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_openfoam_singularity(temp_output_dir: str): + """ + Test the `setup_example` function for the openfoam_wf_singularity example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "openfoam_wf_singularity" + example_files = [ + "openfoam_wf_singularity.yaml", + "requirements.txt", + "scripts/make_samples.py", + "scripts/blockMesh_template.txt", + "scripts/cavity_setup.sh", + "scripts/combine_outputs.py", + "scripts/learn.py", + "scripts/mesh_param_script.py", + "scripts/run_openfoam", + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_optimization(temp_output_dir: str): + """ + Test the `setup_example` function for the optimization example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_files = [ + "optimization_basic.yaml", + "requirements.txt", + "template_config.py", + "template_optimization.temp", + "scripts/collector.py", + "scripts/optimizer.py", + "scripts/test_functions.py", + "scripts/visualizer.py", + ] + + run_setup_example(temp_output_dir, "optimization_basic", example_files, "optimization") + + +def test_setup_example_remote_feature_demo(temp_output_dir: str): + """ + Test the `setup_example` function for the remote_feature_demo example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "remote_feature_demo" + example_files = [ + ".gitignore", + "remote_feature_demo.yaml", + "requirements.txt", + "scripts/features.json", + "scripts/hello_world.py", + "scripts/pgen.py", + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_restart(temp_output_dir: str): + """ + Test the `setup_example` function for the restart example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "restart" + example_files = [ + "restart.yaml", + "scripts/make_samples.py" + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_restart_delay(temp_output_dir: str): + """ + Test the `setup_example` function for the restart_delay example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + example_name = "restart_delay" + example_files = [ + "restart_delay.yaml", + "scripts/make_samples.py" + ] + + run_setup_example(temp_output_dir, example_name, example_files, example_name) + + +def test_setup_example_simple_chain(temp_output_dir: str): + """ + Test the `setup_example` function for the simple_chain example. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # Create the temp path to store this setup + generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) + create_dir(generator_dir) + output_file = os.path.join(generator_dir, "simple_chain.yaml") + + # Ensure that the example name is returned + actual = setup_example("simple_chain", output_file) + assert actual == "simple_chain" + assert os.path.exists(output_file) From 681bd717a06abe33d658fd7093787e65c28935e2 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 19 Dec 2023 16:47:28 -0800 Subject: [PATCH 26/44] run fix-style and update changelog --- CHANGELOG.md | 5 +- tests/conftest.py | 2 +- tests/constants.py | 3 +- tests/unit/test_examples_generator.py | 79 +++++++++++++++------------ tests/utils.py | 1 - 5 files changed, 51 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d0bea05d..01cc3b35e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Coverage to the test suite. This includes adding tests for: - `merlin/common/` - `merlin/config/` + - `merlin/examples/` - `celeryadapter.py` - Context managers for the `conftest.py` file to ensure safe spin up and shutdown of fixtures - - RedisServerManager: context to help with starting/stopping a redis server for tests - - CeleryWorkersManager: context to help with starting/stopping workers for tests + - `RedisServerManager`: context to help with starting/stopping a redis server for tests + - `CeleryWorkersManager`: context to help with starting/stopping workers for tests - Ability to copy and print the `Config` object from `merlin/config/__init__.py` ### Fixed diff --git a/tests/conftest.py b/tests/conftest.py index 20749d4cd..c444a2168 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,7 +42,7 @@ from celery.canvas import Signature from merlin.config.configfile import CONFIG -from tests.constants import SERVER_PASS, CERT_FILES +from tests.constants import CERT_FILES, SERVER_PASS from tests.context_managers.celery_workers_manager import CeleryWorkersManager from tests.context_managers.server_manager import RedisServerManager from tests.utils import create_cert_files, create_pass_file diff --git a/tests/constants.py b/tests/constants.py index a2b354146..26cfe4c0a 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -3,8 +3,9 @@ """ SERVER_PASS = "merlin-test-server" + CERT_FILES = { "ssl_cert": "test-rabbit-client-cert.pem", "ssl_ca": "test-mysql-ca-cert.pem", "ssl_key": "test-rabbit-client-key.pem", -} \ No newline at end of file +} diff --git a/tests/unit/test_examples_generator.py b/tests/unit/test_examples_generator.py index 7d7ccc5bf..97948feaf 100644 --- a/tests/unit/test_examples_generator.py +++ b/tests/unit/test_examples_generator.py @@ -2,7 +2,6 @@ Tests for the `merlin/examples/generator.py` module. """ import os -import pathlib from typing import List from tabulate import tabulate @@ -13,7 +12,7 @@ gather_example_dirs, list_examples, setup_example, - write_example + write_example, ) from tests.utils import create_dir @@ -39,7 +38,7 @@ def test_gather_example_dirs(): "restart", "restart_delay", "simple_chain", - "slurm" + "slurm", ] expected = {} for wf_dir in example_workflows: @@ -76,7 +75,7 @@ def test_gather_all_examples(): f"{EXAMPLES_DIR}/restart_delay/restart_delay.yaml", f"{EXAMPLES_DIR}/simple_chain/simple_chain.yaml", f"{EXAMPLES_DIR}/slurm/slurm_par_restart.yaml", - f"{EXAMPLES_DIR}/slurm/slurm_par.yaml" + f"{EXAMPLES_DIR}/slurm/slurm_par.yaml", ] actual = gather_all_examples() assert sorted(actual) == sorted(expected) @@ -85,7 +84,7 @@ def test_gather_all_examples(): def test_write_example_dir(temp_output_dir: str): """ Test the `write_example` function with the src_path as a directory. - + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run """ generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) @@ -98,7 +97,7 @@ def test_write_example_dir(temp_output_dir: str): def test_write_example_file(temp_output_dir: str): """ Test the `write_example` function with the src_path as a file. - + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run """ generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) @@ -115,16 +114,22 @@ def test_list_examples(): """Test the `list_examples` function to see if it gives us all of the examples that we want.""" expected_headers = ["name", "description"] expected_rows = [ - ["openfoam_wf_no_docker", "A parameter study that includes initializing, running,\n" \ - "post-processing, collecting, learning and vizualizing OpenFOAM runs\n" \ - "without using docker."], - ["optimization_basic", "Design Optimization Template\n" \ - "To use,\n" \ - "1. Specify the first three variables here (N_DIMS, TEST_FUNCTION, DEBUG)\n" \ - "2. Run the template_config file in current directory using `python template_config.py`\n" \ - "3. Merlin run as usual (merlin run optimization.yaml)\n" \ - "* MAX_ITER and the N_SAMPLES options use default values unless using DEBUG mode\n" \ - "* BOUNDS_X and UNCERTS_X are configured using the template_config.py scripts"], + [ + "openfoam_wf_no_docker", + "A parameter study that includes initializing, running,\n" + "post-processing, collecting, learning and vizualizing OpenFOAM runs\n" + "without using docker.", + ], + [ + "optimization_basic", + "Design Optimization Template\n" + "To use,\n" + "1. Specify the first three variables here (N_DIMS, TEST_FUNCTION, DEBUG)\n" + "2. Run the template_config file in current directory using `python template_config.py`\n" + "3. Merlin run as usual (merlin run optimization.yaml)\n" + "* MAX_ITER and the N_SAMPLES options use default values unless using DEBUG mode\n" + "* BOUNDS_X and UNCERTS_X are configured using the template_config.py scripts", + ], ["feature_demo", "Run 10 hello worlds."], ["flux_local", "Run a scan through Merlin/Maestro"], ["flux_par", "A simple ensemble of parallel MPI jobs run by flux."], @@ -141,16 +146,28 @@ def test_list_examples(): ["hello", "a very simple merlin workflow"], ["hello_samples", "a very simple merlin workflow, with samples"], ["hpc_demo", "Demo running a workflow on HPC machines"], - ["openfoam_wf", "A parameter study that includes initializing, running,\n" \ - "post-processing, collecting, learning and visualizing OpenFOAM runs\n" \ - "using docker."], - ["openfoam_wf_singularity", "A parameter study that includes initializing, running,\n" \ - "post-processing, collecting, learning and visualizing OpenFOAM runs\n" \ - "using singularity."], - ["null_chain", "Run N_SAMPLES steps of TIME seconds each at CONC concurrency.\n" \ - "May be used to measure overhead in merlin.\n" \ - "Iterates thru a chain of workflows."], - ["null_spec", "run N_SAMPLES null steps at CONC concurrency for TIME seconds each. May be used to measure overhead in merlin."], + [ + "openfoam_wf", + "A parameter study that includes initializing, running,\n" + "post-processing, collecting, learning and visualizing OpenFOAM runs\n" + "using docker.", + ], + [ + "openfoam_wf_singularity", + "A parameter study that includes initializing, running,\n" + "post-processing, collecting, learning and visualizing OpenFOAM runs\n" + "using singularity.", + ], + [ + "null_chain", + "Run N_SAMPLES steps of TIME seconds each at CONC concurrency.\n" + "May be used to measure overhead in merlin.\n" + "Iterates thru a chain of workflows.", + ], + [ + "null_spec", + "run N_SAMPLES null steps at CONC concurrency for TIME seconds each. May be used to measure overhead in merlin.", + ], ["iterative_demo", "Demo of a workflow with self driven iteration/looping"], ] expected = "\n" + tabulate(expected_rows, expected_headers) + "\n" @@ -534,10 +551,7 @@ def test_setup_example_restart(temp_output_dir: str): :param temp_output_dir: The path to the temporary output directory we'll be using for this test run """ example_name = "restart" - example_files = [ - "restart.yaml", - "scripts/make_samples.py" - ] + example_files = ["restart.yaml", "scripts/make_samples.py"] run_setup_example(temp_output_dir, example_name, example_files, example_name) @@ -549,10 +563,7 @@ def test_setup_example_restart_delay(temp_output_dir: str): :param temp_output_dir: The path to the temporary output directory we'll be using for this test run """ example_name = "restart_delay" - example_files = [ - "restart_delay.yaml", - "scripts/make_samples.py" - ] + example_files = ["restart_delay.yaml", "scripts/make_samples.py"] run_setup_example(temp_output_dir, example_name, example_files, example_name) diff --git a/tests/utils.py b/tests/utils.py index 3a75622b8..d883b83cd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -41,4 +41,3 @@ def create_dir(dirpath: str): """ if not os.path.exists(dirpath): os.mkdir(dirpath) - From 4b8fab51f16898a0340aea74749bf9c48536fdb7 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Wed, 14 Feb 2024 13:18:17 -0800 Subject: [PATCH 27/44] add a 'pip freeze' call in github workflow to view reqs versions --- .github/workflows/push-pr_workflow.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/push-pr_workflow.yml b/.github/workflows/push-pr_workflow.yml index eecbf3eeb..4b5de2373 100644 --- a/.github/workflows/push-pr_workflow.yml +++ b/.github/workflows/push-pr_workflow.yml @@ -95,6 +95,7 @@ jobs: python3 -m pip install --upgrade pip if [ -f requirements.txt ]; then pip install -r requirements.txt; fi pip3 install -r requirements/dev.txt + pip freeze - name: Install singularity run: | From 3099d4c0e877c0bdbd45e687ba4de2d27a49f38e Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 25 Apr 2024 12:42:09 -0700 Subject: [PATCH 28/44] re-delete the old config test files --- tests/unit/config/old_test_configfile.py | 97 ------------------- tests/unit/config/old_test_results_backend.py | 67 ------------- tests/unit/config/utils.py | 24 ----- 3 files changed, 188 deletions(-) delete mode 100644 tests/unit/config/old_test_configfile.py delete mode 100644 tests/unit/config/old_test_results_backend.py delete mode 100644 tests/unit/config/utils.py diff --git a/tests/unit/config/old_test_configfile.py b/tests/unit/config/old_test_configfile.py deleted file mode 100644 index 39139ec11..000000000 --- a/tests/unit/config/old_test_configfile.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Tests for the configfile module.""" - -import os -import shutil -import tempfile -import unittest -from getpass import getuser - -from merlin.config import configfile - -from .utils import mkfile - - -CONFIG_FILE_CONTENTS = """ -celery: - certs: path/to/celery/config/files - -broker: - name: rabbitmq - username: testuser - password: rabbit.password # The filename that contains the password. - server: jackalope.llnl.gov - -results_backend: - name: mysql - dbname: testuser - username: mlsi - password: mysql.password # The filename that contains the password. - server: rabbit.llnl.gov - -""" - - -class TestFindConfigFile(unittest.TestCase): - def setUp(self): - self.tmpdir = tempfile.mkdtemp() - self.appfile = mkfile(self.tmpdir, "app.yaml") - - def tearDown(self): - shutil.rmtree(self.tmpdir, ignore_errors=True) - - def test_tempdir(self): - self.assertTrue(os.path.isdir(self.tmpdir)) - - def test_find_config_file(self): - """ - Given the path to a vaild config file, find and return the full - filepath. - """ - path = configfile.find_config_file(path=self.tmpdir) - expected = os.path.join(self.tmpdir, self.appfile) - self.assertEqual(path, expected) - - def test_find_config_file_error(self): - """Given an invalid path, return None.""" - invalid = "invalid/path" - expected = None - - path = configfile.find_config_file(path=invalid) - self.assertEqual(path, expected) - - -class TestConfigFile(unittest.TestCase): - """Unit tests for loading the config file.""" - - def setUp(self): - self.tmpdir = tempfile.mkdtemp() - self.configfile = mkfile(self.tmpdir, "app.yaml", content=CONFIG_FILE_CONTENTS) - - def tearDown(self): - shutil.rmtree(self.tmpdir, ignore_errors=True) - - def test_get_config(self): - """ - Given the directory path to a valid merlin config file, then - `get_config` should find the merlin config file and load the YAML - contents to a dictionary. - """ - expected = { - "broker": { - "name": "rabbitmq", - "password": "rabbit.password", - "server": "jackalope.llnl.gov", - "username": "testuser", - "vhost": getuser(), - }, - "celery": {"certs": "path/to/celery/config/files"}, - "results_backend": { - "dbname": "testuser", - "name": "mysql", - "password": "mysql.password", - "server": "rabbit.llnl.gov", - "username": "mlsi", - }, - } - - self.assertDictEqual(configfile.get_config(self.tmpdir), expected) diff --git a/tests/unit/config/old_test_results_backend.py b/tests/unit/config/old_test_results_backend.py deleted file mode 100644 index 638f13eb8..000000000 --- a/tests/unit/config/old_test_results_backend.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Tests for the results_backend module.""" - -import os -import shutil -import tempfile -import unittest - -from merlin.config import results_backend - -from .utils import mkfile - - -class TestResultsBackend(unittest.TestCase): - def setUp(self): - self.tmpdir = tempfile.mkdtemp() - - # Create test files. - self.tmpfile1 = mkfile(self.tmpdir, "mysql_test1.txt") - self.tmpfile2 = mkfile(self.tmpdir, "mysql_test2.txt") - - def tearDown(self): - shutil.rmtree(self.tmpdir, ignore_errors=True) - - def test_mysql_config(self): - """ - Given the path to a directory containing the MySQL cert files and a - dictionary of files to look for, then find and return the full path to - all the certs. - """ - certs = {"test1": "mysql_test1.txt", "test2": "mysql_test2.txt"} - - # This will just be the above dictionary with the full file paths. - expected = { - "test1": os.path.join(self.tmpdir, certs["test1"]), - "test2": os.path.join(self.tmpdir, certs["test2"]), - } - results = results_backend.get_mysql_config(self.tmpdir, certs) - self.assertDictEqual(results, expected) - - def test_mysql_config_no_files(self): - """ - Given the path to a directory containing the MySQL cert files and - an empty dictionary, then `get_mysql_config` should return an empty - dictionary. - """ - files = {} - result = results_backend.get_mysql_config(self.tmpdir, files) - self.assertEqual(result, {}) - - -class TestConfingMysqlErrorPath(unittest.TestCase): - """ - Test `get_mysql_config` against cases were the given path does not exist. - """ - - def test_mysql_config_false(self): - """ - Given a path that does not exist, then `get_mysql_config` should return - False. - """ - path = "invalid/path" - - # We don't need the dictionary populated for this test. The function - # should return False before trying to process the dictionary. - certs = {} - result = results_backend.get_mysql_config(path, certs) - self.assertFalse(result) diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py deleted file mode 100644 index 1765e8478..000000000 --- a/tests/unit/config/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -Utils module for common test functionality. -""" - -import os - - -def mkfile(tmpdir, filename, content=""): - """ - A simple function for creating a file and returning the path. This is to - abstract out file creation logic in the tests. - - :param tmpdir: (str) The path to the temp directory. - :param filename: (str) The name of the file. - :param contents: (str) Optional contents to write to the file. Defaults to - an empty string. - :returns: (str) The appended path of the given tempdir and filename. - """ - filepath = os.path.join(tmpdir, filename) - - with open(filepath, "w") as f: - f.write(content) - - return filepath From 37839f63db701001fb646ce2aa0e0928cfb93681 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 25 Apr 2024 14:04:43 -0700 Subject: [PATCH 29/44] fix tests/bugs introduced by merging in develop --- merlin/config/utils.py | 10 +++- merlin/examples/generator.py | 1 + ...m_wf.yaml => openfoam_wf_singularity.yaml} | 0 tests/unit/config/test_utils.py | 50 ++++++++++++++++--- tests/unit/test_examples_generator.py | 6 ++- 5 files changed, 55 insertions(+), 12 deletions(-) rename merlin/examples/workflows/openfoam_wf_singularity/{openfoam_wf.yaml => openfoam_wf_singularity.yaml} (100%) diff --git a/merlin/config/utils.py b/merlin/config/utils.py index bb0dcd58b..6bb3186df 100644 --- a/merlin/config/utils.py +++ b/merlin/config/utils.py @@ -77,8 +77,14 @@ def get_priority(priority: Priority) -> int: :param priority: The priority value that we want :returns: The priority value as an integer """ - if priority not in Priority: - raise ValueError(f"Invalid priority: {priority}") + priority_err_msg = f"Invalid priority: {priority}" + try: + # In python 3.12+ if something is not in the enum it will just return False + if priority not in Priority: + raise ValueError(priority_err_msg) + # In python 3.11 and below, a TypeError is raised when looking for something in an enum that is not there + except TypeError: + raise ValueError(priority_err_msg) priority_map = determine_priority_map(CONFIG.broker.name.lower()) return priority_map.get(priority, priority_map[Priority.MID]) # Default to MID priority for unknown priorities diff --git a/merlin/examples/generator.py b/merlin/examples/generator.py index cb214fed4..120d2defd 100644 --- a/merlin/examples/generator.py +++ b/merlin/examples/generator.py @@ -146,4 +146,5 @@ def setup_example(name, outdir): LOG.info(f"Copying example '{name}' to {outdir}") write_example(src_path, outdir) + print(f'example: {example}') return example diff --git a/merlin/examples/workflows/openfoam_wf_singularity/openfoam_wf.yaml b/merlin/examples/workflows/openfoam_wf_singularity/openfoam_wf_singularity.yaml similarity index 100% rename from merlin/examples/workflows/openfoam_wf_singularity/openfoam_wf.yaml rename to merlin/examples/workflows/openfoam_wf_singularity/openfoam_wf_singularity.yaml diff --git a/tests/unit/config/test_utils.py b/tests/unit/config/test_utils.py index a02bc1ff1..9d64c10c7 100644 --- a/tests/unit/config/test_utils.py +++ b/tests/unit/config/test_utils.py @@ -5,7 +5,7 @@ import pytest from merlin.config.configfile import CONFIG -from merlin.config.utils import Priority, get_priority, is_rabbit_broker, is_redis_broker +from merlin.config.utils import Priority, determine_priority_map, get_priority, is_rabbit_broker, is_redis_broker def test_is_rabbit_broker(): @@ -37,25 +37,27 @@ def test_is_redis_broker_invalid(): def test_get_priority_rabbit_broker(rabbit_broker_config: "fixture"): # noqa: F821 """ Test the `get_priority` function with rabbit as the broker. - Low priority for rabbit is 1 and high is 10. + Low priority for rabbit is 1 and high is 9. :param rabbit_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ assert get_priority(Priority.LOW) == 1 assert get_priority(Priority.MID) == 5 - assert get_priority(Priority.HIGH) == 10 + assert get_priority(Priority.HIGH) == 9 + assert get_priority(Priority.RETRY) == 10 def test_get_priority_redis_broker(redis_broker_config: "fixture"): # noqa: F821 """ Test the `get_priority` function with redis as the broker. - Low priority for redis is 10 and high is 1. + Low priority for redis is 10 and high is 2. :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ assert get_priority(Priority.LOW) == 10 assert get_priority(Priority.MID) == 5 - assert get_priority(Priority.HIGH) == 1 + assert get_priority(Priority.HIGH) == 2 + assert get_priority(Priority.RETRY) == 1 def test_get_priority_invalid_broker(redis_broker_config: "fixture"): # noqa: F821 @@ -68,7 +70,7 @@ def test_get_priority_invalid_broker(redis_broker_config: "fixture"): # noqa: F CONFIG.broker.name = "invalid" with pytest.raises(ValueError) as excinfo: get_priority(Priority.LOW) - assert "Function get_priority has reached unknown state! Maybe unsupported broker invalid?" in str(excinfo.value) + assert "Unsupported broker name: invalid" in str(excinfo.value) def test_get_priority_invalid_priority(redis_broker_config: "fixture"): # noqa: F821 @@ -78,6 +80,38 @@ def test_get_priority_invalid_priority(redis_broker_config: "fixture"): # noqa: :param redis_broker_config: A fixture to set the CONFIG object to a test configuration that we'll use here """ - with pytest.raises(TypeError) as excinfo: + with pytest.raises(ValueError) as excinfo: get_priority("invalid_priority") - assert "Unrecognized priority 'invalid_priority'!" in str(excinfo.value) + assert "Invalid priority: invalid_priority" in str(excinfo.value) + + +def test_determine_priority_map_rabbit(): + """ + Test the `determine_priority_map` function with rabbit as the broker. + This should return the following map: + {Priority.LOW: 1, Priority.MID: 5, Priority.HIGH: 9, Priority.RETRY: 10} + """ + expected = {Priority.LOW: 1, Priority.MID: 5, Priority.HIGH: 9, Priority.RETRY: 10} + actual = determine_priority_map("rabbitmq") + assert actual == expected + + +def test_determine_priority_map_redis(): + """ + Test the `determine_priority_map` function with redis as the broker. + This should return the following map: + {Priority.LOW: 10, Priority.MID: 5, Priority.HIGH: 2, Priority.RETRY: 1} + """ + expected = {Priority.LOW: 10, Priority.MID: 5, Priority.HIGH: 2, Priority.RETRY: 1} + actual = determine_priority_map("redis") + assert actual == expected + + +def test_determine_priority_map_invalid(): + """ + Test the `determine_priority_map` function with an invalid broker. + This should raise a ValueError. + """ + with pytest.raises(ValueError) as excinfo: + determine_priority_map("invalid_broker") + assert "Unsupported broker name: invalid_broker" in str(excinfo.value) diff --git a/tests/unit/test_examples_generator.py b/tests/unit/test_examples_generator.py index 97948feaf..5a05e3599 100644 --- a/tests/unit/test_examples_generator.py +++ b/tests/unit/test_examples_generator.py @@ -64,11 +64,12 @@ def test_gather_all_examples(): f"{EXAMPLES_DIR}/lsf/lsf_par.yaml", f"{EXAMPLES_DIR}/null_spec/null_chain.yaml", f"{EXAMPLES_DIR}/null_spec/null_spec.yaml", - f"{EXAMPLES_DIR}/openfoam_wf/openfoam_wf_template.yaml", + f"{EXAMPLES_DIR}/openfoam_wf/openfoam_wf_docker_template.yaml", f"{EXAMPLES_DIR}/openfoam_wf/openfoam_wf.yaml", f"{EXAMPLES_DIR}/openfoam_wf_no_docker/openfoam_wf_no_docker_template.yaml", f"{EXAMPLES_DIR}/openfoam_wf_no_docker/openfoam_wf_no_docker.yaml", f"{EXAMPLES_DIR}/openfoam_wf_singularity/openfoam_wf_singularity.yaml", + f"{EXAMPLES_DIR}/openfoam_wf_singularity/openfoam_wf_singularity_template.yaml", f"{EXAMPLES_DIR}/optimization/optimization_basic.yaml", f"{EXAMPLES_DIR}/remote_feature_demo/remote_feature_demo.yaml", f"{EXAMPLES_DIR}/restart/restart.yaml", @@ -445,7 +446,7 @@ def test_setup_example_openfoam(temp_output_dir: str): example_name = "openfoam_wf" example_files = [ "openfoam_wf.yaml", - "openfoam_wf_template.yaml", + "openfoam_wf_docker_template.yaml", "README.md", "requirements.txt", "scripts/make_samples.py", @@ -492,6 +493,7 @@ def test_setup_example_openfoam_singularity(temp_output_dir: str): example_name = "openfoam_wf_singularity" example_files = [ "openfoam_wf_singularity.yaml", + "openfoam_wf_singularity_template.yaml", "requirements.txt", "scripts/make_samples.py", "scripts/blockMesh_template.txt", From b8185cc0fbabf730dc77adf44e1a78701fb52953 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 25 Apr 2024 16:36:46 -0700 Subject: [PATCH 30/44] add a unit test file for the dumper module --- tests/unit/common/test_dumper.py | 156 +++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 tests/unit/common/test_dumper.py diff --git a/tests/unit/common/test_dumper.py b/tests/unit/common/test_dumper.py new file mode 100644 index 000000000..7c437fde9 --- /dev/null +++ b/tests/unit/common/test_dumper.py @@ -0,0 +1,156 @@ +""" +Tests for the `dumper.py` file. +""" +import csv +import json +import os +import pytest + +from datetime import datetime +from time import sleep + +from merlin.common.dumper import dump_handler + +NUM_ROWS = 5 +CSV_INFO_TO_DUMP = {"row_num": [i for i in range(1, NUM_ROWS+1)], "other_info": [f"test_info_{i}" for i in range(1, NUM_ROWS+1)]} +JSON_INFO_TO_DUMP = {str(i): {f"other_info_{i}": f"test_info_{i}"} for i in range(1, NUM_ROWS+1)} +DUMP_HANDLER_DIR = "{temp_output_dir}/dump_handler" + +def test_dump_handler_invalid_dump_file(): + """ + This is really testing the initialization of the Dumper class with an invalid file type. + This should raise a ValueError. + """ + with pytest.raises(ValueError) as excinfo: + dump_handler("bad_file.txt", CSV_INFO_TO_DUMP) + assert "Invalid file type for bad_file.txt. Supported file types are: ['csv', 'json']" in str(excinfo.value) + +def get_output_file(temp_dir: str, file_name: str): + """ + Helper function to get a full path to the temporary output file. + + :param temp_dir: The path to the temporary output directory that pytest gives us + :param file_name: The name of the file + """ + dump_dir = DUMP_HANDLER_DIR.format(temp_output_dir=temp_dir) + if not os.path.exists(dump_dir): + os.mkdir(dump_dir) + dump_file = f"{dump_dir}/{file_name}" + return dump_file + +def run_csv_dump_test(dump_file: str, fmode: str): + """ + Run the test for csv dump. + + :param dump_file: The file that the dump was written to + :param fmode: The type of write that we're testing ("w" for write, "a" for append) + """ + + # Check that the file exists and that read in the contents of the file + assert os.path.exists(dump_file) + with open(dump_file, "r") as df: + reader = csv.reader(df) + written_data = list(reader) + + expected_rows = NUM_ROWS*2 if fmode == "a" else NUM_ROWS + assert len(written_data) == expected_rows+1 # Adding one because of the header row + for i, row in enumerate(written_data): + assert len(row) == 2 # Check number of columns + if i == 0: # Checking the header row + assert row[0] == "row_num" + assert row[1] == "other_info" + else: # Checking the data rows + assert row[0] == str(CSV_INFO_TO_DUMP["row_num"][(i%NUM_ROWS)-1]) + assert row[1] == str(CSV_INFO_TO_DUMP["other_info"][(i%NUM_ROWS)-1]) + +def test_dump_handler_csv_write(temp_output_dir: str): + """ + This is really testing the write method of the Dumper class. + This should create a csv file and write to it. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # Create the path to the file we'll write to + dump_file = get_output_file(temp_output_dir, "csv_write.csv") + + # Run the actual call to dump to the file + dump_handler(dump_file, CSV_INFO_TO_DUMP) + + # Assert that everything ran properly + run_csv_dump_test(dump_file, "w") + +def test_dump_handler_csv_append(temp_output_dir: str): + """ + This is really testing the write method of the Dumper class with the file write mode set to append. + We'll write to a csv file first and then run again to make sure we can append to it properly. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # Create the path to the file we'll write to + dump_file = get_output_file(temp_output_dir, "csv_append.csv") + + # Run the first call to create the csv file + dump_handler(dump_file, CSV_INFO_TO_DUMP) + + # Run the second call to append to the csv file + dump_handler(dump_file, CSV_INFO_TO_DUMP) + + # Assert that everything ran properly + run_csv_dump_test(dump_file, "a") + +def test_dump_handler_json_write(temp_output_dir: str): + """ + This is really testing the write method of the Dumper class. + This should create a json file and write to it. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # Create the path to the file we'll write to + dump_file = get_output_file(temp_output_dir, "json_write.json") + + # Run the actual call to dump to the file + dump_handler(dump_file, JSON_INFO_TO_DUMP) + + # Check that the file exists and that the contents are correct + assert os.path.exists(dump_file) + with open(dump_file, "r") as df: + contents = json.load(df) + assert contents == JSON_INFO_TO_DUMP + +def test_dump_handler_json_append(temp_output_dir: str): + """ + This is really testing the write method of the Dumper class with the file write mode set to append. + We'll write to a json file first and then run again to make sure we can append to it properly. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + + # Create the path to the file we'll write to + dump_file = get_output_file(temp_output_dir, "json_append.json") + + # Run the first call to create the file + timestamp_1 = str(datetime.now()) + first_dump = {timestamp_1: JSON_INFO_TO_DUMP} + dump_handler(dump_file, first_dump) + + # Sleep so we don't accidentally get the same timestamp + sleep(.5) + + # Run the second call to append to the file + timestamp_2 = str(datetime.now()) + second_dump = {timestamp_2: JSON_INFO_TO_DUMP} + dump_handler(dump_file, second_dump) + + # Check that the file exists and that the contents are correct + assert os.path.exists(dump_file) + with open(dump_file, "r") as df: + contents = json.load(df) + keys = contents.keys() + assert len(keys) == 2 + assert timestamp_1 in keys + assert timestamp_2 in keys + assert contents[timestamp_1] == JSON_INFO_TO_DUMP + assert contents[timestamp_2] == JSON_INFO_TO_DUMP \ No newline at end of file From e48fe32514ab772466ad5a79ec3c288463eadead Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 7 May 2024 14:58:41 -0700 Subject: [PATCH 31/44] begin work on server tests and modular fixtures --- merlin/server/server_util.py | 39 +++- tests/fixtures/server.py | 84 ++++++++ tests/unit/server/__init__.py | 0 tests/unit/server/test_server_util.py | 295 ++++++++++++++++++++++++++ 4 files changed, 411 insertions(+), 7 deletions(-) create mode 100644 tests/fixtures/server.py create mode 100644 tests/unit/server/__init__.py create mode 100644 tests/unit/server/test_server_util.py diff --git a/merlin/server/server_util.py b/merlin/server/server_util.py index c10e0e1d9..bdfd3652e 100644 --- a/merlin/server/server_util.py +++ b/merlin/server/server_util.py @@ -60,7 +60,7 @@ def valid_ipv4(ip: str) -> bool: # pylint: disable=C0103 return False for i in arr: - if int(i) < 0 and int(i) > 255: + if int(i) < 0 or int(i) > 255: return False return True @@ -121,6 +121,15 @@ def __init__(self, data: dict) -> None: self.pass_file = data["pass_file"] if "pass_file" in data else self.PASSWORD_FILE self.user_file = data["user_file"] if "user_file" in data else self.USERS_FILE + def __eq__(self, other: "ContainerFormatConfig"): + """ + Equality magic method used for testing this class + + :param other: Another ContainerFormatConfig object to check if they're the same + """ + variables = ("format", "image_type", "image", "url", "config", "config_dir", "pfile", "pass_file", "user_file") + return all(getattr(self, attr) == getattr(other, attr) for attr in variables) + def get_format(self) -> str: """Getter method to get the container format""" return self.format @@ -208,6 +217,15 @@ def __init__(self, data: dict) -> None: self.stop_command = data["stop_command"] if "stop_command" in data else self.STOP_COMMAND self.pull_command = data["pull_command"] if "pull_command" in data else self.PULL_COMMAND + def __eq__(self, other: "ContainerFormatConfig"): + """ + Equality magic method used for testing this class + + :param other: Another ContainerFormatConfig object to check if they're the same + """ + variables = ("command", "run_command", "stop_command", "pull_command") + return all(getattr(self, attr) == getattr(other, attr) for attr in variables) + def get_command(self) -> str: """Getter method to get the container command""" return self.command @@ -242,6 +260,15 @@ def __init__(self, data: dict) -> None: self.status = data["status"] if "status" in data else self.STATUS_COMMAND self.kill = data["kill"] if "kill" in data else self.KILL_COMMAND + def __eq__(self, other: "ProcessConfig"): + """ + Equality magic method used for testing this class + + :param other: Another ProcessConfig object to check if they're the same + """ + variables = ("status", "kill") + return all(getattr(self, attr) == getattr(other, attr) for attr in variables) + def get_status_command(self) -> str: """Getter method to get the status command""" return self.status @@ -264,12 +291,10 @@ class ServerConfig: # pylint: disable=R0903 container_format: ContainerFormatConfig = None def __init__(self, data: dict) -> None: - if "container" in data: - self.container = ContainerConfig(data["container"]) - if "process" in data: - self.process = ProcessConfig(data["process"]) - if self.container.get_format() in data: - self.container_format = ContainerFormatConfig(data[self.container.get_format()]) + self.container = ContainerConfig(data["container"]) if "container" in data else None + self.process = ProcessConfig(data["process"]) if "process" in data else None + container_format_data = data.get(self.container.get_format() if self.container else None) + self.container_format = ContainerFormatConfig(container_format_data) if container_format_data else None class RedisConfig: diff --git a/tests/fixtures/server.py b/tests/fixtures/server.py new file mode 100644 index 000000000..35efdcd65 --- /dev/null +++ b/tests/fixtures/server.py @@ -0,0 +1,84 @@ +""" +Fixtures specifically for help testing the modules in the server/ directory. +""" +import pytest +import shutil +from typing import Dict + +@pytest.fixture(scope="class") +def server_container_config_data(temp_output_dir: str): + """ + Fixture to provide sample data for ContainerConfig tests + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + return { + "format": "docker", + "image_type": "postgres", + "image": "postgres:latest", + "url": "postgres://localhost", + "config": "postgres.conf", + "config_dir": "/path/to/config", + "pfile": "merlin_server_postgres.pf", + "pass_file": f"{temp_output_dir}/postgres.pass", + "user_file": "postgres.users", + } + +@pytest.fixture(scope="class") +def server_container_format_config_data(): + """ + Fixture to provide sample data for ContainerFormatConfig tests + """ + return { + "command": "docker", + "run_command": "{command} run --name {name} -d {image}", + "stop_command": "{command} stop {name}", + "pull_command": "{command} pull {url}", + } + +@pytest.fixture(scope="class") +def server_process_config_data(): + """ + Fixture to provide sample data for ProcessConfig tests + """ + return { + "status": "status {pid}", + "kill": "terminate {pid}", + } + +@pytest.fixture(scope="class") +def server_server_config( + server_container_config_data: Dict[str, str], + server_process_config_data: Dict[str, str], + server_container_format_config_data: Dict[str, str], +): + """ + Fixture to provide sample data for ServerConfig tests + + :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class + :param server_process_config_data: A pytest fixture of test data to pass to the ProcessConfig class + :param server_container_format_config_data: A pytest fixture of test data to pass to the ContainerFormatConfig class + """ + return { + "container": server_container_config_data, + "process": server_process_config_data, + "docker": server_container_format_config_data, + } + + +@pytest.fixture(scope="class") +def server_redis_conf_file(temp_output_dir: str): + """ + Fixture to copy the redis.conf file from the merlin/server/ directory to the + temporary output directory and provide the path to the copied file + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + """ + # TODO + # - will probably have to do more than just copying over the conf file + # - likely want to create our own test conf file with the settings that + # can be modified by RedisConf instead + path_to_redis_conf = f"{os.path.dirname(os.path.abspath(__file__))}/../../merlin/server/redis.conf" + path_to_copied_redis = f"{temp_output_dir}/redis.conf" + shutil.copy(path_to_redis_conf, path_to_copied_redis) + return path_to_copied_redis \ No newline at end of file diff --git a/tests/unit/server/__init__.py b/tests/unit/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/server/test_server_util.py b/tests/unit/server/test_server_util.py new file mode 100644 index 000000000..384e1ea37 --- /dev/null +++ b/tests/unit/server/test_server_util.py @@ -0,0 +1,295 @@ +""" +Tests for the `server_util.py` module. +""" +import os +import pytest +from typing import Callable, Dict, Union + +from merlin.server.server_util import ( + AppYaml, + ContainerConfig, + ContainerFormatConfig, + ProcessConfig, + RedisConfig, + RedisUsers, + ServerConfig, + valid_ipv4, + valid_port +) + +@pytest.mark.parametrize("valid_ip", [ + "0.0.0.0", + "127.0.0.1", + "14.105.200.58", + "255.255.255.255", +]) +def test_valid_ipv4_valid_ip(valid_ip: str): + """ + Test the `valid_ipv4` function with valid IPs. + This should return True. + + :param valid_ip: A valid port to test. + These are pulled from the parametrized list defined above this test. + """ + assert valid_ipv4(valid_ip) + +@pytest.mark.parametrize("invalid_ip", [ + "256.0.0.1", + "-1.0.0.1", + None, + "127.0.01", +]) +def test_valid_ipv4_invalid_ip(invalid_ip: Union[str, None]): + """ + Test the `valid_ipv4` function with invalid IPs. + An IP is valid if every integer separated by the '.' delimiter are between 0 and 255. + This should return False for both IPs tested here. + + :param invalid_ip: An invalid port to test. + These are pulled from the parametrized list defined above this test. + """ + assert not valid_ipv4(invalid_ip) + +@pytest.mark.parametrize("valid_input", [ + 1, + 433, + 65535, +]) +def test_valid_port_valid_input(valid_input: int): + """ + Test the `valid_port` function with valid port numbers. + Valid ports are ports between 1 and 65535. + This should return True. + + :param valid_input: A valid input value to test. + These are pulled from the parametrized list defined above this test. + """ + assert valid_port(valid_input) + +@pytest.mark.parametrize("invalid_input", [ + -1, + 0, + 65536, +]) +def test_valid_port_invalid_input(invalid_input: int): + """ + Test the `valid_port` function with invalid inputs. + Valid ports are ports between 1 and 65535. + This should return False for each invalid input tested. + + :param invalid_input: An invalid input value to test. + These are pulled from the parametrized list defined above this test. + """ + assert not valid_port(invalid_input) + + +class TestContainerConfig: + """Tests for the ContainerConfig class.""" + + def test_init_with_complete_data(self, server_container_config_data: Dict[str, str]): + """ + Tests that __init__ populates attributes correctly with complete data + + :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class + """ + config = ContainerConfig(server_container_config_data) + assert config.format == server_container_config_data["format"] + assert config.image_type == server_container_config_data["image_type"] + assert config.image == server_container_config_data["image"] + assert config.url == server_container_config_data["url"] + assert config.config == server_container_config_data["config"] + assert config.config_dir == server_container_config_data["config_dir"] + assert config.pfile == server_container_config_data["pfile"] + assert config.pass_file == server_container_config_data["pass_file"] + assert config.user_file == server_container_config_data["user_file"] + + def test_init_with_missing_data(self): + """ + Tests that __init__ uses defaults for missing data + """ + incomplete_data = {"format": "docker"} + config = ContainerConfig(incomplete_data) + assert config.format == incomplete_data["format"] + assert config.image_type == ContainerConfig.IMAGE_TYPE + assert config.image == ContainerConfig.IMAGE_NAME + assert config.url == ContainerConfig.REDIS_URL + assert config.config == ContainerConfig.CONFIG_FILE + assert config.config_dir == ContainerConfig.CONFIG_DIR + assert config.pfile == ContainerConfig.PROCESS_FILE + assert config.pass_file == ContainerConfig.PASSWORD_FILE + assert config.user_file == ContainerConfig.USERS_FILE + + @pytest.mark.parametrize("attr_name", [ + "image", + "config", + "pfile", + "pass_file", + "user_file", + ]) + def test_get_path_methods(self, server_container_config_data: Dict[str, str], attr_name: str): + """ + Tests that get_*_path methods construct the correct path + + :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class + :param attr_name: Name of the attribute to be tested. These are pulled from the parametrized list defined above this test. + """ + config = ContainerConfig(server_container_config_data) + get_path_method = getattr(config, f"get_{attr_name}_path") # Dynamically get the method based on attr_name + expected_path = os.path.join(server_container_config_data["config_dir"], server_container_config_data[attr_name]) + assert get_path_method() == expected_path + + @pytest.mark.parametrize("getter_name, expected_attr", [ + ("get_format", "format"), + ("get_image_type", "image_type"), + ("get_image_name", "image"), + ("get_image_url", "url"), + ("get_config_name", "config"), + ("get_config_dir", "config_dir"), + ("get_pfile_name", "pfile"), + ("get_pass_file_name", "pass_file"), + ("get_user_file_name", "user_file"), + ]) + def test_getter_methods(self, server_container_config_data: Dict[str, str], getter_name: str, expected_attr: str): + """ + Tests that all getter methods return the correct attribute values + + :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class + :param getter_name: Name of the getter method to test. This is pulled from the parametrized list defined above this test. + :param expected_attr: Name of the corresponding attribute. This is pulled from the parametrized list defined above this test. + """ + config = ContainerConfig(server_container_config_data) + getter = getattr(config, getter_name) + assert getter() == server_container_config_data[expected_attr] + + def test_get_container_password(self, server_container_config_data: Dict[str, str]): + """ + Test that the get_container_password is reading the password file properly + + :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class + """ + # Write a fake password to the password file + test_password = "super-secret-password" + with open(server_container_config_data["pass_file"], "w") as pass_file: + pass_file.write(test_password) + + # Run the test + config = ContainerConfig(server_container_config_data) + assert config.get_container_password() == test_password + + +class TestContainerFormatConfig: + """Tests for the ContainerFormatConfig class.""" + + def test_init_with_complete_data(self, server_container_format_config_data: Dict[str, str]): + """ + Tests that __init__ populates attributes correctly with complete data + + :param server_container_format_config_data: A pytest fixture of test data to pass to the ContainerFormatConfig class + """ + config = ContainerFormatConfig(server_container_format_config_data) + assert config.command == server_container_format_config_data["command"] + assert config.run_command == server_container_format_config_data["run_command"] + assert config.stop_command == server_container_format_config_data["stop_command"] + assert config.pull_command == server_container_format_config_data["pull_command"] + + def test_init_with_missing_data(self): + """ + Tests that __init__ uses defaults for missing data + """ + incomplete_data = {"command": "docker"} + config = ContainerFormatConfig(incomplete_data) + assert config.command == incomplete_data["command"] + assert config.run_command == config.RUN_COMMAND + assert config.stop_command == config.STOP_COMMAND + assert config.pull_command == config.PULL_COMMAND + + @pytest.mark.parametrize("getter_name, expected_attr", [ + ("get_command", "command"), + ("get_run_command", "run_command"), + ("get_stop_command", "stop_command"), + ("get_pull_command", "pull_command"), + ]) + def test_getter_methods(self, server_container_format_config_data: Dict[str, str], getter_name: str, expected_attr: str): + """ + Tests that all getter methods return the correct attribute values + + :param server_container_format_config_data: A pytest fixture of test data to pass to the ContainerFormatConfig class + :param getter_name: Name of the getter method to test. This is pulled from the parametrized list defined above this test. + :param expected_attr: Name of the corresponding attribute. This is pulled from the parametrized list defined above this test. + """ + config = ContainerFormatConfig(server_container_format_config_data) + getter = getattr(config, getter_name) + assert getter() == server_container_format_config_data[expected_attr] + + +class TestProcessConfig: + """Tests for the ProcessConfig class.""" + + def test_init_with_complete_data(self, server_process_config_data: Dict[str, str]): + """ + Tests that __init__ populates attributes correctly with complete data + + :param server_process_config_data: A pytest fixture of test data to pass to the ProcessConfig class + """ + config = ProcessConfig(server_process_config_data) + assert config.status == server_process_config_data["status"] + assert config.kill == server_process_config_data["kill"] + + def test_init_with_missing_data(self): + """ + Tests that __init__ uses defaults for missing data + """ + incomplete_data = {"status": "status {pid}"} + config = ProcessConfig(incomplete_data) + assert config.status == incomplete_data["status"] + assert config.kill == config.KILL_COMMAND + + @pytest.mark.parametrize("getter_name, expected_attr", [ + ("get_status_command", "status"), + ("get_kill_command", "kill"), + ]) + def test_getter_methods(self, server_process_config_data: Dict[str, str], getter_name: str, expected_attr: str): + """ + Tests that all getter methods return the correct attribute values + + :param server_process_config_data: A pytest fixture of test data to pass to the ProcessConfig class + :param getter_name: Name of the getter method to test. This is pulled from the parametrized list defined above this test. + :param expected_attr: Name of the corresponding attribute. This is pulled from the parametrized list defined above this test. + """ + config = ProcessConfig(server_process_config_data) + getter = getattr(config, getter_name) + assert getter() == server_process_config_data[expected_attr] + + +class TestServerConfig: + """Tests for the ServerConfig class.""" + + def test_init_with_complete_data(self, server_server_config: Dict[str, str]): + """ + Tests that __init__ populates attributes correctly with complete data + + :param server_server_config: A pytest fixture of test data to pass to the ServerConfig class + """ + config = ServerConfig(server_server_config) + assert config.container == ContainerConfig(server_server_config["container"]) + assert config.process == ProcessConfig(server_server_config["process"]) + assert config.container_format == ContainerFormatConfig(server_server_config["docker"]) + + def test_init_with_missing_data(self, server_process_config_data: Dict[str, str]): + """ + Tests that __init__ uses None for missing data + + :param server_process_config_data: A pytest fixture of test data to pass to the ContainerConfig class + """ + incomplete_data = {"process": server_process_config_data} + config = ServerConfig(incomplete_data) + assert config.process == ProcessConfig(server_process_config_data) + assert config.container is None + assert config.container_format is None + + +# class TestRedisConfig: +# """Tests for the RedisConfig class.""" + +# def test_parse(self, server_redis_conf_file): +# raise ValueError From e1f667ddea170cbcc251182c4ad041aa699730ea Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 23 May 2024 08:00:48 -0700 Subject: [PATCH 32/44] start work on tests for RedisConfig --- tests/fixtures/server.py | 84 +++++++++++++++++++++------ tests/unit/server/test_server_util.py | 62 ++++++++++++++++++-- 2 files changed, 125 insertions(+), 21 deletions(-) diff --git a/tests/fixtures/server.py b/tests/fixtures/server.py index 35efdcd65..04c858f46 100644 --- a/tests/fixtures/server.py +++ b/tests/fixtures/server.py @@ -1,16 +1,17 @@ """ Fixtures specifically for help testing the modules in the server/ directory. """ +import os import pytest -import shutil from typing import Dict @pytest.fixture(scope="class") -def server_container_config_data(temp_output_dir: str): +def server_container_config_data(temp_output_dir: str) -> Dict[str, str]: """ Fixture to provide sample data for ContainerConfig tests :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :returns: A dict containing the necessary key/values for the ContainerConfig object """ return { "format": "docker", @@ -25,9 +26,11 @@ def server_container_config_data(temp_output_dir: str): } @pytest.fixture(scope="class") -def server_container_format_config_data(): +def server_container_format_config_data() -> Dict[str, str]: """ Fixture to provide sample data for ContainerFormatConfig tests + + :returns: A dict containing the necessary key/values for the ContainerFormatConfig object """ return { "command": "docker", @@ -37,9 +40,11 @@ def server_container_format_config_data(): } @pytest.fixture(scope="class") -def server_process_config_data(): +def server_process_config_data() -> Dict[str, str]: """ Fixture to provide sample data for ProcessConfig tests + + :returns: A dict containing the necessary key/values for the ProcessConfig object """ return { "status": "status {pid}", @@ -51,13 +56,14 @@ def server_server_config( server_container_config_data: Dict[str, str], server_process_config_data: Dict[str, str], server_container_format_config_data: Dict[str, str], -): +) -> Dict[str, Dict[str, str]]: """ Fixture to provide sample data for ServerConfig tests :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class :param server_process_config_data: A pytest fixture of test data to pass to the ProcessConfig class :param server_container_format_config_data: A pytest fixture of test data to pass to the ContainerFormatConfig class + :returns: A dictionary containing each of the configuration dicts we'll need """ return { "container": server_container_config_data, @@ -66,19 +72,63 @@ def server_server_config( } -@pytest.fixture(scope="class") -def server_redis_conf_file(temp_output_dir: str): +@pytest.fixture(scope="session") +def server_testing_dir(temp_output_dir: str) -> str: """ - Fixture to copy the redis.conf file from the merlin/server/ directory to the - temporary output directory and provide the path to the copied file + Fixture to create a temporary output directory for tests related to the server functionality. :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :returns: The path to the temporary testing directory for server tests + """ + testing_dir = f"{temp_output_dir}/server_testing/" + if not os.path.exists(testing_dir): + os.mkdir(testing_dir) + + return testing_dir + + +@pytest.fixture(scope="session") +def server_redis_conf_file(server_testing_dir: str) -> str: + """ + Fixture to copy the redis.conf file from the merlin/server/ directory to the + temporary output directory and provide the path to the copied file. + + If a test will modify this file with a file write, you should make a copy of + this file to modify instead. + + :param server_testing_dir: A pytest fixture that defines a path to the the output directory we'll write to + :returns: The path to the redis configuration file we'll use for testing """ - # TODO - # - will probably have to do more than just copying over the conf file - # - likely want to create our own test conf file with the settings that - # can be modified by RedisConf instead - path_to_redis_conf = f"{os.path.dirname(os.path.abspath(__file__))}/../../merlin/server/redis.conf" - path_to_copied_redis = f"{temp_output_dir}/redis.conf" - shutil.copy(path_to_redis_conf, path_to_copied_redis) - return path_to_copied_redis \ No newline at end of file + redis_conf_file = f"{server_testing_dir}/redis.conf" + file_contents = """ + # ip address + bind 127.0.0.1 + + # port + port 6379 + + # password + requirepass merlin_password + + # directory + dir ./ + + # snapshot + save 300 100 + + # db file + dbfilename dump.rdb + + # append mode + appendfsync everysec + + # append file + appendfilename appendonly.aof + + # dummy trailing comment + """.strip().replace(" ", "") + + with open(redis_conf_file, "w") as rcf: + rcf.write(file_contents) + + return redis_conf_file diff --git a/tests/unit/server/test_server_util.py b/tests/unit/server/test_server_util.py index 384e1ea37..c71e854eb 100644 --- a/tests/unit/server/test_server_util.py +++ b/tests/unit/server/test_server_util.py @@ -1,8 +1,10 @@ """ Tests for the `server_util.py` module. """ +import filecmp import os import pytest +import shutil from typing import Callable, Dict, Union from merlin.server.server_util import ( @@ -288,8 +290,60 @@ def test_init_with_missing_data(self, server_process_config_data: Dict[str, str] assert config.container_format is None -# class TestRedisConfig: -# """Tests for the RedisConfig class.""" +class TestRedisConfig: + """Tests for the RedisConfig class.""" + + def test_initialization(self, server_redis_conf_file: str): + """ + Using a dummy redis configuration file, test that the initialization + of the RedisConfig class behaves as expected. + + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + expected_entries = { + "bind": "127.0.0.1", + "port": "6379", + "requirepass": "merlin_password", + "dir": "./", + "save": "300 100", + "dbfilename": "dump.rdb", + "appendfsync": "everysec", + "appendfilename": "appendonly.aof", + } + expected_comments = { + "bind": "# ip address\n", + "port": "\n# port\n", + "requirepass": "\n# password\n", + "dir": "\n# directory\n", + "save": "\n# snapshot\n", + "dbfilename": "\n# db file\n", + "appendfsync": "\n# append mode\n", + "appendfilename": "\n# append file\n", + } + expected_trailing_comment = "\n# dummy trailing comment" + expected_entry_order = list(expected_entries.keys()) + redis_config = RedisConfig(server_redis_conf_file) + assert redis_config.filename == server_redis_conf_file + assert not redis_config.changed + assert redis_config.entries == expected_entries + assert redis_config.entry_order == expected_entry_order + assert redis_config.comments == expected_comments + assert redis_config.trailing_comments == expected_trailing_comment + + def test_write(self, server_redis_conf_file: str, server_testing_dir: str): + """ + """ + copy_redis_conf_file = f"{server_testing_dir}/redis_copy.conf" + + # Create a RedisConf object with the basic redis conf file + redis_config = RedisConfig(server_redis_conf_file) + + # Change the filepath of the redis config file to be the copy that we'll write to + redis_config.filename = copy_redis_conf_file + + # Run the test + redis_config.write() + + # Check that the contents of the copied file match the contents of the basic file + assert filecmp.cmp(server_redis_conf_file, copy_redis_conf_file) -# def test_parse(self, server_redis_conf_file): -# raise ValueError From 9997d8e8f442349b4cc94c29a99cec3f4710b3f8 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 4 Jun 2024 15:07:05 -0700 Subject: [PATCH 33/44] add tests for RedisConfig object --- merlin/server/server_commands.py | 4 +- merlin/server/server_util.py | 79 ++-- tests/unit/server/test_RedisConfig.py | 538 ++++++++++++++++++++++++++ tests/unit/server/test_server_util.py | 64 +-- 4 files changed, 577 insertions(+), 108 deletions(-) create mode 100644 tests/unit/server/test_RedisConfig.py diff --git a/merlin/server/server_commands.py b/merlin/server/server_commands.py index 65d17c42b..40f2689d0 100644 --- a/merlin/server/server_commands.py +++ b/merlin/server/server_commands.py @@ -98,9 +98,7 @@ def config_server(args: Namespace) -> None: # pylint: disable=R0912 redis_config.set_directory(args.directory) - redis_config.set_snapshot_seconds(args.snapshot_seconds) - - redis_config.set_snapshot_changes(args.snapshot_changes) + redis_config.set_snapshot(seconds=args.snapshot_seconds, changes=args.snapshot_changes) redis_config.set_snapshot_file(args.snapshot_file) diff --git a/merlin/server/server_util.py b/merlin/server/server_util.py index aff641d4d..27a83376d 100644 --- a/merlin/server/server_util.py +++ b/merlin/server/server_util.py @@ -304,16 +304,14 @@ class RedisConfig: to write those changes into a redis readable config file. """ - filename = "" - entry_order = [] - entries = {} - comments = {} - trailing_comments = "" - changed = False - def __init__(self, filename) -> None: self.filename = filename self.changed = False + self.entry_order = [] + self.entries = {} + self.comments = {} + self.trailing_comments = "" + self.changed = False self.parse() def parse(self) -> None: @@ -393,7 +391,7 @@ def get_port(self) -> str: """Getter method to get the port from the redis config""" return self.get_config_value("port") - def set_port(self, port: str) -> bool: + def set_port(self, port: int) -> bool: """Validates and sets a given port""" if port is None: return False @@ -428,59 +426,56 @@ def set_directory(self, directory: str) -> bool: """ if directory is None: return False + # Create the directory if it doesn't exist if not os.path.exists(directory): os.mkdir(directory) LOG.info(f"Created directory {directory}") - # Validate the directory input - if os.path.exists(directory): - # Set the save directory to the redis config - if not self.set_config_value("dir", directory): - LOG.error("Unable to set directory for redis config") - return False - else: - LOG.error(f"Directory {directory} given does not exist and could not be created.") + # Set the save directory to the redis config + if not self.set_config_value("dir", directory): + LOG.error("Unable to set directory for redis config") return False LOG.info(f"Directory is set to {directory}") return True - def set_snapshot_seconds(self, seconds: int) -> bool: - """Sets the snapshot wait time""" - if seconds is None: - return False - # Set the snapshot second in the redis config - value = self.get_config_value("save") - if value is None: - LOG.error("Unable to get exisiting parameter values for snapshot") - return False + def set_snapshot(self, seconds: int = None, changes: int = None) -> bool: + """ + Sets the 'seconds' and/or 'changes' values of the snapshot setting, + depending on what the user requests. + + :param seconds: The first value of snapshot to change. If we're leaving it the + same this will be None. + :param changes: The second value of snapshot to change. If we're leaving it the + same this will be None. + :returns: True if successful, False otherwise. + """ - value = value.split() - value[0] = str(seconds) - value = " ".join(value) - if not self.set_config_value("save", value): - LOG.error("Unable to set snapshot value seconds") + # If both values are None, this method is doing nothing + if seconds is None and changes is None: return False - LOG.info(f"Snapshot wait time is set to {seconds} seconds") - return True - - def set_snapshot_changes(self, changes: int) -> bool: - """Sets the snapshot threshold""" - if changes is None: - return False - # Set the snapshot changes into the redis config + # Grab the snapshot value from the redis config value = self.get_config_value("save") if value is None: LOG.error("Unable to get exisiting parameter values for snapshot") return False + # Update the snapshot value value = value.split() - value[1] = str(changes) + log_msg = "" + if seconds is not None: + value[0] = str(seconds) + log_msg += f"Snapshot wait time is set to {seconds} seconds. " + if changes is not None: + value[1] = str(changes) + log_msg += f"Snapshot threshold is set to {changes} changes." value = " ".join(value) + + # Set the new snapshot value if not self.set_config_value("save", value): - LOG.error("Unable to set snapshot value seconds") + LOG.error("Unable to set snapshot value") return False - LOG.info(f"Snapshot threshold is set to {changes} changes") + LOG.info(log_msg) return True def set_snapshot_file(self, file: str) -> bool: @@ -508,7 +503,7 @@ def set_append_mode(self, mode: str) -> bool: LOG.error("Unable to set append_mode in redis config") return False else: - LOG.error("Not a valid append_mode(Only valid modes are always, everysec, no)") + LOG.error("Not a valid append_mode (Only valid modes are always, everysec, no)") return False LOG.info(f"Append mode is set to {mode}") diff --git a/tests/unit/server/test_RedisConfig.py b/tests/unit/server/test_RedisConfig.py new file mode 100644 index 000000000..12880d4d6 --- /dev/null +++ b/tests/unit/server/test_RedisConfig.py @@ -0,0 +1,538 @@ +""" +Tests for the RedisConfig class of the `server_util.py` module. + +This class is especially large so that's why these tests have been +moved to their own file. +""" +import filecmp +import logging +import pytest +from typing import Any + +from merlin.server.server_util import RedisConfig + +class TestRedisConfig: + """Tests for the RedisConfig class.""" + + def test_initialization(self, server_redis_conf_file: str): + """ + Using a dummy redis configuration file, test that the initialization + of the RedisConfig class behaves as expected. + + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + expected_entries = { + "bind": "127.0.0.1", + "port": "6379", + "requirepass": "merlin_password", + "dir": "./", + "save": "300 100", + "dbfilename": "dump.rdb", + "appendfsync": "everysec", + "appendfilename": "appendonly.aof", + } + expected_comments = { + "bind": "# ip address\n", + "port": "\n# port\n", + "requirepass": "\n# password\n", + "dir": "\n# directory\n", + "save": "\n# snapshot\n", + "dbfilename": "\n# db file\n", + "appendfsync": "\n# append mode\n", + "appendfilename": "\n# append file\n", + } + expected_trailing_comment = "\n# dummy trailing comment" + expected_entry_order = list(expected_entries.keys()) + redis_config = RedisConfig(server_redis_conf_file) + assert redis_config.filename == server_redis_conf_file + assert not redis_config.changed + assert redis_config.entries == expected_entries + assert redis_config.entry_order == expected_entry_order + assert redis_config.comments == expected_comments + assert redis_config.trailing_comments == expected_trailing_comment + + def test_write(self, server_redis_conf_file: str, server_testing_dir: str): + """ + Test that the write functionality works by writing the contents of a dummy + configuration file to a blank configuration file. + + :param server_redis_conf_file: The path to a dummy redis configuration file + :param server_testing_dir: The path to the the temp output directory for server tests + """ + copy_redis_conf_file = f"{server_testing_dir}/redis_copy.conf" + + # Create a RedisConf object with the basic redis conf file + redis_config = RedisConfig(server_redis_conf_file) + + # Change the filepath of the redis config file to be the copy that we'll write to + redis_config.set_filename(copy_redis_conf_file) + + # Run the test + redis_config.write() + + # Check that the contents of the copied file match the contents of the basic file + assert filecmp.cmp(server_redis_conf_file, copy_redis_conf_file) + + @pytest.mark.parametrize("key, val, expected_return", [ + ("port", 1234, True), + ("invalid_key", "dummy_val", False) + ]) + def test_set_config_value(self, server_redis_conf_file: str, key: str, val: Any, expected_return: bool): + """ + Test the `set_config_value` method with valid and invalid keys. + + :param server_redis_conf_file: The path to a dummy redis configuration file + :param key: The key value to modify with `set_config_value` + :param val: The value to set `key` to + :param expected_return: The expected return from `set_config_value` + """ + redis_config = RedisConfig(server_redis_conf_file) + actual_return = redis_config.set_config_value(key, val) + assert actual_return == expected_return + if expected_return: + assert redis_config.entries[key] == val + assert redis_config.changes_made() + else: + assert not redis_config.changes_made() + + @pytest.mark.parametrize("key, expected_val", [ + ("bind", "127.0.0.1"), + ("port", "6379"), + ("requirepass", "merlin_password"), + ("dir", "./"), + ("save", "300 100"), + ("dbfilename", "dump.rdb"), + ("appendfsync", "everysec"), + ("appendfilename", "appendonly.aof"), + ("invalid_key", None) + ]) + def test_get_config_value(self, server_redis_conf_file: str, key: str, expected_val: str): + """ + Test the `get_config_value` method with valid and invalid keys. + + :param server_redis_conf_file: The path to a dummy redis configuration file + :param key: The key value to modify with `set_config_value` + :param expected_val: The value we're expecting to get by querying `key` + """ + redis_conf = RedisConfig(server_redis_conf_file) + assert redis_conf.get_config_value(key) == expected_val + + @pytest.mark.parametrize("ip_to_set", [ + "127.0.0.1", # Most common IP + "0.0.0.0", # Edge case (low) + "255.255.255.255", # Edge case (high) + "123.222.199.20", # Random valid IP + ]) + def test_set_ip_address_valid( + self, + caplog: "Fixture", # noqa: F821 + server_redis_conf_file: str, + ip_to_set: str + ): + """ + Test the `set_ip_address` method with valid ips. These should all return True + and set the 'bind' value to whatever `ip_to_set` is. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + :param ip_to_set: The ip address to set + """ + caplog.set_level(logging.INFO) + redis_config = RedisConfig(server_redis_conf_file) + assert redis_config.set_ip_address(ip_to_set) + assert f"Ipaddress is set to {ip_to_set}" in caplog.text, "Missing expected log message" + assert redis_config.get_ip_address() == ip_to_set + + @pytest.mark.parametrize("ip_to_set, expected_log", [ + (None, None), # No IP + ("0.0.0", "Invalid IPv4 address given."), # Invalid IPv4 + ("bind-unset", "Unable to set ip address for redis config"), # Special invalid case where bind doesn't exist + ]) + def test_set_ip_address_invalid( + self, + caplog: "Fixture", # noqa: F821 + server_redis_conf_file: str, + ip_to_set: str, + expected_log: str, + ): + """ + Test the `set_ip_address` method with invalid ips. These should all return False. + and not modify the 'bind' setting. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + :param ip_to_set: The ip address to set + :param expected_log: The string we're expecting the logger to log + """ + redis_config = RedisConfig(server_redis_conf_file) + # For the test where bind is unset, delete bind from dict and set new ip val to a valid value + if ip_to_set == "bind-unset": + del redis_config.entries["bind"] + ip_to_set = "127.0.0.1" + assert not redis_config.set_ip_address(ip_to_set) + assert redis_config.get_ip_address() != ip_to_set + if expected_log is not None: + assert expected_log in caplog.text, "Missing expected log message" + + @pytest.mark.parametrize("port_to_set", [ + 6379, # Most common port + 1, # Edge case (low) + 65535, # Edge case (high) + 12345, # Random valid port + ]) + def test_set_port_valid( + self, + caplog: "Fixture", # noqa: F821 + server_redis_conf_file: str, + port_to_set: str, + ): + """ + Test the `set_port` method with valid ports. These should all return True + and set the 'port' value to whatever `port_to_set` is. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + :param port_to_set: The port to set + """ + caplog.set_level(logging.INFO) + redis_config = RedisConfig(server_redis_conf_file) + assert redis_config.set_port(port_to_set) + assert redis_config.get_port() == port_to_set + assert f"Port is set to {port_to_set}" in caplog.text, "Missing expected log message" + + @pytest.mark.parametrize("port_to_set, expected_log", [ + (None, None), # No port + (0, "Invalid port given."), # Edge case (low) + (65536, "Invalid port given."), # Edge case (high) + ("port-unset", "Unable to set port for redis config"), # Special invalid case where port doesn't exist + ]) + def test_set_port_invalid( + self, + caplog: "Fixture", # noqa: F821 + server_redis_conf_file: str, + port_to_set: str, + expected_log: str, + ): + """ + Test the `set_port` method with invalid inputs. These should all return False + and not modify the 'port' setting. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + :param port_to_set: The port to set + :param expected_log: The string we're expecting the logger to log + """ + redis_config = RedisConfig(server_redis_conf_file) + # For the test where port is unset, delete port from dict and set port val to a valid value + if port_to_set == "port-unset": + del redis_config.entries["port"] + port_to_set = 5 + assert not redis_config.set_port(port_to_set) + assert redis_config.get_port() != port_to_set + if expected_log is not None: + assert expected_log in caplog.text, "Missing expected log message" + + @pytest.mark.parametrize("pass_to_set, expected_return", [ + ("valid_password", True), # Valid password + (None, False), # Invalid password + ]) + def test_set_password( + self, + caplog: "Fixture", # noqa: F821 + server_redis_conf_file: str, + pass_to_set: str, + expected_return: bool, + ): + """ + Test the `set_password` method with both valid and invalid input. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + :param pass_to_set: The password to set + :param expected_return: The expected return value + """ + caplog.set_level(logging.INFO) + redis_conf = RedisConfig(server_redis_conf_file) + assert redis_conf.set_password(pass_to_set) == expected_return + if expected_return: + assert redis_conf.get_password() == pass_to_set + assert "New password set" in caplog.text, "Missing expected log message" + + def test_set_directory_valid( + self, + caplog: "Fixture", # noqa: F821 + server_redis_conf_file: str, + server_testing_dir: str, + ): + """ + Test the `set_directory` method with valid input. This should return True, modify the + 'dir' value, and log some messages about creating/setting the directory. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + :param server_testing_dir: The path to the the temp output directory for server tests + """ + caplog.set_level(logging.INFO) + redis_config = RedisConfig(server_redis_conf_file) + dir_to_set = f"{server_testing_dir}/dummy_dir" + assert redis_config.set_directory(dir_to_set) + assert redis_config.get_config_value("dir") == dir_to_set + assert f"Created directory {dir_to_set}" in caplog.text, "Missing created log message" + assert f"Directory is set to {dir_to_set}" in caplog.text, "Missing set log message" + + def test_set_directory_none(self, server_redis_conf_file: str): + """ + Test the `set_directory` method with None as the input. This should return False + and not modify the 'dir' setting. + + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_config = RedisConfig(server_redis_conf_file) + assert not redis_config.set_directory(None) + assert redis_config.get_config_value("dir") != None + + def test_set_directory_dir_unset( + self, + caplog: "Fixture", # noqa: F821 + server_redis_conf_file: str, + server_testing_dir: str, + ): + """ + Test the `set_directory` method with the 'dir' setting not existing. This should + return False and log an error message. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + :param server_testing_dir: The path to the the temp output directory for server tests + """ + redis_config = RedisConfig(server_redis_conf_file) + del redis_config.entries["dir"] + dir_to_set = f"{server_testing_dir}/dummy_dir" + assert not redis_config.set_directory(dir_to_set) + assert "Unable to set directory for redis config" in caplog.text, "Missing expected log message" + + def test_set_snapshot_valid(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_snapshot` method with a valid input for 'seconds' and 'changes'. + This should return True and modify both values of 'save'. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + caplog.set_level(logging.INFO) + redis_conf = RedisConfig(server_redis_conf_file) + snap_sec_to_set = 20 + snap_changes_to_set = 30 + assert redis_conf.set_snapshot(seconds=snap_sec_to_set, changes=snap_changes_to_set) + save_val = redis_conf.get_config_value("save").split() + assert save_val[0] == str(snap_sec_to_set) + assert save_val[1] == str(snap_changes_to_set) + expected_log = f"Snapshot wait time is set to {snap_sec_to_set} seconds. " \ + f"Snapshot threshold is set to {snap_changes_to_set} changes" + assert expected_log in caplog.text, "Missing expected log message" + + def test_set_snapshot_just_seconds(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_snapshot` method with a valid input for 'seconds'. This should + return True and modify the first value of 'save'. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + caplog.set_level(logging.INFO) + redis_conf = RedisConfig(server_redis_conf_file) + orig_save = redis_conf.get_config_value("save").split() + snap_sec_to_set = 20 + assert redis_conf.set_snapshot(seconds=snap_sec_to_set) + save_val = redis_conf.get_config_value("save").split() + assert save_val[0] == str(snap_sec_to_set) + assert save_val[1] == orig_save[1] + expected_log = f"Snapshot wait time is set to {snap_sec_to_set} seconds. " + assert expected_log in caplog.text, "Missing expected log message" + + def test_set_snapshot_just_changes(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_snapshot` method with a valid input for 'changes'. This should + return True and modify the second value of 'save'. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + caplog.set_level(logging.INFO) + redis_conf = RedisConfig(server_redis_conf_file) + orig_save = redis_conf.get_config_value("save").split() + snap_changes_to_set = 30 + assert redis_conf.set_snapshot(changes=snap_changes_to_set) + save_val = redis_conf.get_config_value("save").split() + assert save_val[0] == orig_save[0] + assert save_val[1] == str(snap_changes_to_set) + expected_log = f"Snapshot threshold is set to {snap_changes_to_set} changes" + assert expected_log in caplog.text, "Missing expected log message" + + def test_set_snapshot_none(self, server_redis_conf_file: str): + """ + Test the `set_snapshot` method with None as the input for both seconds + and changes. This should return False. + + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_conf = RedisConfig(server_redis_conf_file) + assert not redis_conf.set_snapshot(seconds=None, changes=None) + + def test_set_snapshot_save_unset(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_snapshot` method with the 'save' setting not existing. This should + return False and log an error message. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_conf = RedisConfig(server_redis_conf_file) + del redis_conf.entries["save"] + assert not redis_conf.set_snapshot(seconds=20) + assert "Unable to get exisiting parameter values for snapshot" in caplog.text, "Missing expected log message" + + def test_set_snapshot_file_valid(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_snapshot_file` method with a valid input. This should + return True and modify the value of 'dbfilename'. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + caplog.set_level(logging.INFO) + redis_conf = RedisConfig(server_redis_conf_file) + filename = "dummy_file.rdb" + assert redis_conf.set_snapshot_file(filename) + assert redis_conf.get_config_value("dbfilename") == filename + assert f"Snapshot file is set to {filename}" in caplog.text, "Missing expected log message" + + def test_set_snapshot_file_none(self, server_redis_conf_file: str): + """ + Test the `set_snapshot_file` method with None as the input. + This should return False. + + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_conf = RedisConfig(server_redis_conf_file) + assert not redis_conf.set_snapshot_file(None) + + def test_set_snapshot_file_dbfilename_unset(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_snapshot` method with the 'dbfilename' setting not existing. This should + return False and log an error message. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_conf = RedisConfig(server_redis_conf_file) + del redis_conf.entries["dbfilename"] + filename = "dummy_file.rdb" + assert not redis_conf.set_snapshot_file(filename) + assert redis_conf.get_config_value("dbfilename") != filename + assert "Unable to set snapshot_file name" in caplog.text, "Missing expected log message" + + @pytest.mark.parametrize("mode_to_set", [ + "always", + "everysec", + "no", + ]) + def test_set_append_mode_valid( + self, + caplog: "Fixture", # noqa: F821 + server_redis_conf_file: str, + mode_to_set: str, + ): + """ + Test the `set_append_mode` method with valid modes. These should all return True + and modify the value of 'appendfsync'. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + :param mode_to_set: The mode to set + """ + caplog.set_level(logging.INFO) + redis_conf = RedisConfig(server_redis_conf_file) + assert redis_conf.set_append_mode(mode_to_set) + assert redis_conf.get_config_value("appendfsync") == mode_to_set + assert f"Append mode is set to {mode_to_set}" in caplog.text, "Missing expected log message" + + def test_set_append_mode_invalid(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_append_mode` method with an invalid mode. This should return False + and log an error message. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_conf = RedisConfig(server_redis_conf_file) + invalid_mode = "invalid" + assert not redis_conf.set_append_mode(invalid_mode) + assert redis_conf.get_config_value("appendfsync") != invalid_mode + expected_log = "Not a valid append_mode (Only valid modes are always, everysec, no)" + assert expected_log in caplog.text, "Missing expected log message" + + def test_set_append_mode_none(self, server_redis_conf_file: str): + """ + Test the `set_append_mode` method with None as the input. + This should return False. + + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_conf = RedisConfig(server_redis_conf_file) + assert not redis_conf.set_append_mode(None) + + def test_set_append_mode_appendfsync_unset(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_append_mode` method with the 'appendfsync' setting not existing. This should + return False and log an error message. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_conf = RedisConfig(server_redis_conf_file) + del redis_conf.entries["appendfsync"] + mode = "no" + assert not redis_conf.set_append_mode(mode) + assert redis_conf.get_config_value("appendfsync") != mode + assert "Unable to set append_mode in redis config" in caplog.text, "Missing expected log message" + + def test_set_append_file_valid(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_append_file` method with a valid file. This should return True + and modify the value of 'appendfilename'. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + caplog.set_level(logging.INFO) + redis_conf = RedisConfig(server_redis_conf_file) + valid_file = "valid" + assert redis_conf.set_append_file(valid_file) + assert redis_conf.get_config_value("appendfilename") == f'"{valid_file}"' + assert f"Append file is set to {valid_file}" in caplog.text, "Missing expected log message" + + def test_set_append_file_none(self, server_redis_conf_file: str): + """ + Test the `set_append_file` method with None as the input. + This should return False. + + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_conf = RedisConfig(server_redis_conf_file) + assert not redis_conf.set_append_file(None) + + def test_set_append_file_appendfilename_unset(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 + """ + Test the `set_append_file` method with the 'appendfilename' setting not existing. This should + return False and log an error message. + + :param caplog: A built-in fixture from the pytest library to capture logs + :param server_redis_conf_file: The path to a dummy redis configuration file + """ + redis_conf = RedisConfig(server_redis_conf_file) + del redis_conf.entries["appendfilename"] + filename = "valid_filename" + assert not redis_conf.set_append_file(filename) + assert redis_conf.get_config_value("appendfilename") != filename + assert "Unable to set append filename." in caplog.text, "Missing expected log message" diff --git a/tests/unit/server/test_server_util.py b/tests/unit/server/test_server_util.py index c71e854eb..0332be944 100644 --- a/tests/unit/server/test_server_util.py +++ b/tests/unit/server/test_server_util.py @@ -1,18 +1,15 @@ """ Tests for the `server_util.py` module. """ -import filecmp import os import pytest -import shutil -from typing import Callable, Dict, Union +from typing import Dict, Union from merlin.server.server_util import ( AppYaml, ContainerConfig, ContainerFormatConfig, ProcessConfig, - RedisConfig, RedisUsers, ServerConfig, valid_ipv4, @@ -288,62 +285,3 @@ def test_init_with_missing_data(self, server_process_config_data: Dict[str, str] assert config.process == ProcessConfig(server_process_config_data) assert config.container is None assert config.container_format is None - - -class TestRedisConfig: - """Tests for the RedisConfig class.""" - - def test_initialization(self, server_redis_conf_file: str): - """ - Using a dummy redis configuration file, test that the initialization - of the RedisConfig class behaves as expected. - - :param server_redis_conf_file: The path to a dummy redis configuration file - """ - expected_entries = { - "bind": "127.0.0.1", - "port": "6379", - "requirepass": "merlin_password", - "dir": "./", - "save": "300 100", - "dbfilename": "dump.rdb", - "appendfsync": "everysec", - "appendfilename": "appendonly.aof", - } - expected_comments = { - "bind": "# ip address\n", - "port": "\n# port\n", - "requirepass": "\n# password\n", - "dir": "\n# directory\n", - "save": "\n# snapshot\n", - "dbfilename": "\n# db file\n", - "appendfsync": "\n# append mode\n", - "appendfilename": "\n# append file\n", - } - expected_trailing_comment = "\n# dummy trailing comment" - expected_entry_order = list(expected_entries.keys()) - redis_config = RedisConfig(server_redis_conf_file) - assert redis_config.filename == server_redis_conf_file - assert not redis_config.changed - assert redis_config.entries == expected_entries - assert redis_config.entry_order == expected_entry_order - assert redis_config.comments == expected_comments - assert redis_config.trailing_comments == expected_trailing_comment - - def test_write(self, server_redis_conf_file: str, server_testing_dir: str): - """ - """ - copy_redis_conf_file = f"{server_testing_dir}/redis_copy.conf" - - # Create a RedisConf object with the basic redis conf file - redis_config = RedisConfig(server_redis_conf_file) - - # Change the filepath of the redis config file to be the copy that we'll write to - redis_config.filename = copy_redis_conf_file - - # Run the test - redis_config.write() - - # Check that the contents of the copied file match the contents of the basic file - assert filecmp.cmp(server_redis_conf_file, copy_redis_conf_file) - From 52213f245a2d5e11d94f2fc58bb4ccfa8f5cbf56 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Tue, 4 Jun 2024 16:41:12 -0700 Subject: [PATCH 34/44] add tests for RedisUsers class --- merlin/server/server_util.py | 2 +- tests/fixtures/server.py | 48 +++++++- tests/unit/server/test_server_util.py | 167 ++++++++++++++++++++++++++ 3 files changed, 214 insertions(+), 3 deletions(-) diff --git a/merlin/server/server_util.py b/merlin/server/server_util.py index 27a83376d..9b7233097 100644 --- a/merlin/server/server_util.py +++ b/merlin/server/server_util.py @@ -623,7 +623,7 @@ def set_password(self, user: str, password: str): self.users[user].set_password(password) return True - def remove_user(self, user) -> bool: + def remove_user(self, user: str) -> bool: """Remove a user from the dict of users""" if user in self.users: del self.users[user] diff --git a/tests/fixtures/server.py b/tests/fixtures/server.py index 04c858f46..5c5dea102 100644 --- a/tests/fixtures/server.py +++ b/tests/fixtures/server.py @@ -3,6 +3,7 @@ """ import os import pytest +import yaml from typing import Dict @pytest.fixture(scope="class") @@ -90,8 +91,7 @@ def server_testing_dir(temp_output_dir: str) -> str: @pytest.fixture(scope="session") def server_redis_conf_file(server_testing_dir: str) -> str: """ - Fixture to copy the redis.conf file from the merlin/server/ directory to the - temporary output directory and provide the path to the copied file. + Fixture to write a redis.conf file to the temporary output directory. If a test will modify this file with a file write, you should make a copy of this file to modify instead. @@ -132,3 +132,47 @@ def server_redis_conf_file(server_testing_dir: str) -> str: rcf.write(file_contents) return redis_conf_file + +@pytest.fixture(scope="session") +def server_users() -> dict: + """ + Create a dictionary of two test users with identical configuration settings. + + :returns: A dict containing the two test users and their settings + """ + users = { + "default": { + "channels": '*', + "commands": '@all', + "hash_password": '1ba9249af0c73dacb0f9a70567126624076b5bee40de811e65f57eabcdaf490a', + "keys": '*', + "status": 'on', + }, + "test_user": { + "channels": '*', + "commands": '@all', + "hash_password": '1ba9249af0c73dacb0f9a70567126624076b5bee40de811e65f57eabcdaf490a', + "keys": '*', + "status": 'on', + } + } + return users + +@pytest.fixture(scope="session") +def server_redis_users_file(server_testing_dir: str, server_users: dict) -> str: + """ + Fixture to write a redis.users file to the temporary output directory. + + If a test will modify this file with a file write, you should make a copy of + this file to modify instead. + + :param server_testing_dir: A pytest fixture that defines a path to the the output directory we'll write to + :param server_users: A dict of test user configurations + :returns: The path to the redis user configuration file we'll use for testing + """ + redis_users_file = f"{server_testing_dir}/redis.users" + + with open(redis_users_file, "w") as ruf: + yaml.dump(server_users, ruf) + + return redis_users_file \ No newline at end of file diff --git a/tests/unit/server/test_server_util.py b/tests/unit/server/test_server_util.py index 0332be944..61f29293c 100644 --- a/tests/unit/server/test_server_util.py +++ b/tests/unit/server/test_server_util.py @@ -1,6 +1,8 @@ """ Tests for the `server_util.py` module. """ +import filecmp +import hashlib import os import pytest from typing import Dict, Union @@ -285,3 +287,168 @@ def test_init_with_missing_data(self, server_process_config_data: Dict[str, str] assert config.process == ProcessConfig(server_process_config_data) assert config.container is None assert config.container_format is None + +class TestRedisUsers: + """ + Tests for the RedisUsers class. + + TODO add integration test(s) for `apply_to_redis` method of this class. + """ + + class TestUser: + """Tests for the RedisUsers.User class""" + + def test_initializaiton(self): + """Test the initialization process of the User class.""" + user = RedisUsers.User() + assert user.status == "on" + assert user.hash_password == hashlib.sha256(b"password").hexdigest() + assert user.keys == "*" + assert user.channels == "*" + assert user.commands == "@all" + + def test_parse_dict(self): + """Test the `parse_dict` method of the User class.""" + test_dict = { + "status": "test_status", + "hash_password": "test_password", + "keys": "test_keys", + "channels": "test_channels", + "commands": "test_commands", + } + user = RedisUsers.User() + user.parse_dict(test_dict) + assert user.status == test_dict["status"] + assert user.hash_password == test_dict["hash_password"] + assert user.keys == test_dict["keys"] + assert user.channels == test_dict["channels"] + assert user.commands == test_dict["commands"] + + def test_get_user_dict(self): + """Test the `get_user_dict` method of the User class.""" + test_dict = { + "status": "test_status", + "hash_password": "test_password", + "keys": "test_keys", + "channels": "test_channels", + "commands": "test_commands", + "invalid_key": "invalid_val", + } + user = RedisUsers.User() + user.parse_dict(test_dict) # Set the test values + actual_dict = user.get_user_dict() + assert "invalid_key" not in actual_dict # Check that the invalid key isn't parsed + + # Check that the values are as expected + for key, val in actual_dict.items(): + if key == "status": + assert val == "on" + else: + assert val == test_dict[key] + + def test_set_password(self): + """Test the `set_password` method of the User class.""" + user = RedisUsers.User() + pass_to_set = "dummy_password" + user.set_password(pass_to_set) + assert user.hash_password == hashlib.sha256(bytes(pass_to_set, "utf-8")).hexdigest() + + def test_initialization(self, server_redis_users_file: str, server_users: dict): + """ + Test the initialization process of the RedisUsers class. + + :param server_redis_users_file: The path to a dummy redis users file + :param server_users: A dict of test user configurations + """ + redis_users = RedisUsers(server_redis_users_file) + assert redis_users.filename == server_redis_users_file + assert len(redis_users.users) == len(server_users) + + def test_write(self, server_redis_users_file: str, server_testing_dir: str): + """ + Test that the write functionality works by writing the contents of a dummy + users file to a blank users file. + + :param server_redis_users_file: The path to a dummy redis users file + :param server_testing_dir: The path to the the temp output directory for server tests + """ + copy_redis_users_file = f"{server_testing_dir}/redis_copy.users" + + # Create a RedisUsers object with the basic redis users file + redis_users = RedisUsers(server_redis_users_file) + + # Change the filepath of the redis users file to be the copy that we'll write to + redis_users.filename = copy_redis_users_file + + # Run the test + redis_users.write() + + # Check that the contents of the copied file match the contents of the basic file + assert filecmp.cmp(server_redis_users_file, copy_redis_users_file) + + def test_add_user_nonexistent(self, server_redis_users_file: str): + """ + Test the `add_user` method with a user that doesn't exists. + This should return True and add the user to the list of users. + + :param server_redis_users_file: The path to a dummy redis users file + """ + redis_users = RedisUsers(server_redis_users_file) + num_users_before = len(redis_users.users) + assert redis_users.add_user("new_user") + assert len(redis_users.users) == num_users_before + 1 + + def test_add_user_exists(self, server_redis_users_file: str): + """ + Test the `add_user` method with a user that already exists. + This should return False. + + :param server_redis_users_file: The path to a dummy redis users file + """ + redis_users = RedisUsers(server_redis_users_file) + assert not redis_users.add_user("test_user") + + def test_set_password_valid(self, server_redis_users_file: str): + """ + Test the `set_password` method with a user that exists. + This should return True and change the password for the user. + + :param server_redis_users_file: The path to a dummy redis users file + """ + redis_users = RedisUsers(server_redis_users_file) + pass_to_set = "new_password" + assert redis_users.set_password("test_user", pass_to_set) + expected_hash_pass = hashlib.sha256(bytes(pass_to_set, "utf-8")).hexdigest() + assert redis_users.users["test_user"].hash_password == expected_hash_pass + + def test_set_password_invalid(self, server_redis_users_file: str): + """ + Test the `set_password` method with a user that doesn't exist. + This should return False. + + :param server_redis_users_file: The path to a dummy redis users file + """ + redis_users = RedisUsers(server_redis_users_file) + assert not redis_users.set_password("nonexistent_user", "new_password") + + def test_remove_user_valid(self, server_redis_users_file: str): + """ + Test the `remove_user` method with a user that exists. + This should return True and remove the user from the list of users. + + :param server_redis_users_file: The path to a dummy redis users file + """ + redis_users = RedisUsers(server_redis_users_file) + num_users_before = len(redis_users.users) + assert redis_users.remove_user("test_user") + assert len(redis_users.users) == num_users_before - 1 + + def test_remove_user_invalid(self, server_redis_users_file: str): + """ + Test the `remove_user` method with a user that doesn't exist. + This should return False and not modify the user list. + + :param server_redis_users_file: The path to a dummy redis users file + """ + redis_users = RedisUsers(server_redis_users_file) + assert not redis_users.remove_user("nonexistent_user") \ No newline at end of file From a59243ffc65a77c9ce3dc2b67efd62a8dce06fc2 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 6 Jun 2024 10:01:03 -0700 Subject: [PATCH 35/44] change server fixtures to use redis config files --- tests/fixtures/server.py | 174 +++++++++++++++----------- tests/unit/server/test_server_util.py | 6 +- 2 files changed, 107 insertions(+), 73 deletions(-) diff --git a/tests/fixtures/server.py b/tests/fixtures/server.py index 5c5dea102..ae3a966c8 100644 --- a/tests/fixtures/server.py +++ b/tests/fixtures/server.py @@ -6,72 +6,6 @@ import yaml from typing import Dict -@pytest.fixture(scope="class") -def server_container_config_data(temp_output_dir: str) -> Dict[str, str]: - """ - Fixture to provide sample data for ContainerConfig tests - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - :returns: A dict containing the necessary key/values for the ContainerConfig object - """ - return { - "format": "docker", - "image_type": "postgres", - "image": "postgres:latest", - "url": "postgres://localhost", - "config": "postgres.conf", - "config_dir": "/path/to/config", - "pfile": "merlin_server_postgres.pf", - "pass_file": f"{temp_output_dir}/postgres.pass", - "user_file": "postgres.users", - } - -@pytest.fixture(scope="class") -def server_container_format_config_data() -> Dict[str, str]: - """ - Fixture to provide sample data for ContainerFormatConfig tests - - :returns: A dict containing the necessary key/values for the ContainerFormatConfig object - """ - return { - "command": "docker", - "run_command": "{command} run --name {name} -d {image}", - "stop_command": "{command} stop {name}", - "pull_command": "{command} pull {url}", - } - -@pytest.fixture(scope="class") -def server_process_config_data() -> Dict[str, str]: - """ - Fixture to provide sample data for ProcessConfig tests - - :returns: A dict containing the necessary key/values for the ProcessConfig object - """ - return { - "status": "status {pid}", - "kill": "terminate {pid}", - } - -@pytest.fixture(scope="class") -def server_server_config( - server_container_config_data: Dict[str, str], - server_process_config_data: Dict[str, str], - server_container_format_config_data: Dict[str, str], -) -> Dict[str, Dict[str, str]]: - """ - Fixture to provide sample data for ServerConfig tests - - :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class - :param server_process_config_data: A pytest fixture of test data to pass to the ProcessConfig class - :param server_container_format_config_data: A pytest fixture of test data to pass to the ContainerFormatConfig class - :returns: A dictionary containing each of the configuration dicts we'll need - """ - return { - "container": server_container_config_data, - "process": server_process_config_data, - "docker": server_container_format_config_data, - } - @pytest.fixture(scope="session") def server_testing_dir(temp_output_dir: str) -> str: @@ -81,7 +15,7 @@ def server_testing_dir(temp_output_dir: str) -> str: :param temp_output_dir: The path to the temporary output directory we'll be using for this test run :returns: The path to the temporary testing directory for server tests """ - testing_dir = f"{temp_output_dir}/server_testing/" + testing_dir = f"{temp_output_dir}/server_testing" if not os.path.exists(testing_dir): os.mkdir(testing_dir) @@ -96,7 +30,7 @@ def server_redis_conf_file(server_testing_dir: str) -> str: If a test will modify this file with a file write, you should make a copy of this file to modify instead. - :param server_testing_dir: A pytest fixture that defines a path to the the output directory we'll write to + :param server_testing_dir: A pytest fixture that defines a path to the output directory we'll write to :returns: The path to the redis configuration file we'll use for testing """ redis_conf_file = f"{server_testing_dir}/redis.conf" @@ -133,6 +67,26 @@ def server_redis_conf_file(server_testing_dir: str) -> str: return redis_conf_file + +@pytest.fixture(scope="session") +def server_redis_pass_file(server_testing_dir: str) -> str: + """ + Fixture to create a redis password file in the temporary output directory. + + If a test will modify this file with a file write, you should make a copy of + this file to modify instead. + + :param server_testing_dir: A pytest fixture that defines a path to the output directory we'll write to + :returns: The path to the redis password file + """ + redis_pass_file = f"{server_testing_dir}/redis.pass" + + with open(redis_pass_file, "w") as rpf: + rpf.write("server-tests-password") + + return redis_pass_file + + @pytest.fixture(scope="session") def server_users() -> dict: """ @@ -158,6 +112,7 @@ def server_users() -> dict: } return users + @pytest.fixture(scope="session") def server_redis_users_file(server_testing_dir: str, server_users: dict) -> str: """ @@ -166,7 +121,7 @@ def server_redis_users_file(server_testing_dir: str, server_users: dict) -> str: If a test will modify this file with a file write, you should make a copy of this file to modify instead. - :param server_testing_dir: A pytest fixture that defines a path to the the output directory we'll write to + :param server_testing_dir: A pytest fixture that defines a path to the output directory we'll write to :param server_users: A dict of test user configurations :returns: The path to the redis user configuration file we'll use for testing """ @@ -175,4 +130,83 @@ def server_redis_users_file(server_testing_dir: str, server_users: dict) -> str: with open(redis_users_file, "w") as ruf: yaml.dump(server_users, ruf) - return redis_users_file \ No newline at end of file + return redis_users_file + + +@pytest.fixture(scope="class") +def server_container_config_data( + server_testing_dir: str, + server_redis_conf_file: str, + server_redis_pass_file: str, + server_redis_users_file: str, +) -> Dict[str, str]: + """ + Fixture to provide sample data for ContainerConfig tests. + + :param server_testing_dir: A pytest fixture that defines a path to the output directory we'll write to + :param server_redis_conf_file: A pytest fixture that defines a path to a redis configuration file + :param server_redis_pass_file: A pytest fixture that defines a path to a redis password file + :param server_redis_users_file: A pytest fixture that defines a path to a redis users file + :returns: A dict containing the necessary key/values for the ContainerConfig object + """ + + return { + "format": "singularity", + "image_type": "redis", + "image": "redis_latest.sif", + "url": "docker://redis", + "config": server_redis_conf_file.split("/")[-1], + "config_dir": server_testing_dir, + "pfile": "merlin_server.pf", + "pass_file": server_redis_pass_file.split("/")[-1], + "user_file": server_redis_users_file.split("/")[-1], + } + + +@pytest.fixture(scope="class") +def server_container_format_config_data() -> Dict[str, str]: + """ + Fixture to provide sample data for ContainerFormatConfig tests + + :returns: A dict containing the necessary key/values for the ContainerFormatConfig object + """ + return { + "command": "singularity", + "run_command": "{command} run -H {home_dir} {image} {config}", + "stop_command": "kill", + "pull_command": "{command} pull {image} {url}", + } + + +@pytest.fixture(scope="class") +def server_process_config_data() -> Dict[str, str]: + """ + Fixture to provide sample data for ProcessConfig tests + + :returns: A dict containing the necessary key/values for the ProcessConfig object + """ + return { + "status": "pgrep -P {pid}", + "kill": "kill {pid}", + } + + +@pytest.fixture(scope="class") +def server_server_config( + server_container_config_data: Dict[str, str], + server_process_config_data: Dict[str, str], + server_container_format_config_data: Dict[str, str], +) -> Dict[str, Dict[str, str]]: + """ + Fixture to provide sample data for ServerConfig tests + + :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class + :param server_process_config_data: A pytest fixture of test data to pass to the ProcessConfig class + :param server_container_format_config_data: A pytest fixture of test data to pass to the ContainerFormatConfig class + :returns: A dictionary containing each of the configuration dicts we'll need + """ + return { + "container": server_container_config_data, + "process": server_process_config_data, + "singularity": server_container_format_config_data, + } diff --git a/tests/unit/server/test_server_util.py b/tests/unit/server/test_server_util.py index 61f29293c..2986e22de 100644 --- a/tests/unit/server/test_server_util.py +++ b/tests/unit/server/test_server_util.py @@ -169,7 +169,7 @@ def test_get_container_password(self, server_container_config_data: Dict[str, st :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class """ # Write a fake password to the password file - test_password = "super-secret-password" + test_password = "server-tests-password" with open(server_container_config_data["pass_file"], "w") as pass_file: pass_file.write(test_password) @@ -274,7 +274,7 @@ def test_init_with_complete_data(self, server_server_config: Dict[str, str]): config = ServerConfig(server_server_config) assert config.container == ContainerConfig(server_server_config["container"]) assert config.process == ProcessConfig(server_server_config["process"]) - assert config.container_format == ContainerFormatConfig(server_server_config["docker"]) + assert config.container_format == ContainerFormatConfig(server_server_config["singularity"]) def test_init_with_missing_data(self, server_process_config_data: Dict[str, str]): """ @@ -451,4 +451,4 @@ def test_remove_user_invalid(self, server_redis_users_file: str): :param server_redis_users_file: The path to a dummy redis users file """ redis_users = RedisUsers(server_redis_users_file) - assert not redis_users.remove_user("nonexistent_user") \ No newline at end of file + assert not redis_users.remove_user("nonexistent_user") From 0ef586e4d60ab7d0ca228ba9c311ef4d3924211b Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 6 Jun 2024 10:55:35 -0700 Subject: [PATCH 36/44] add tests for AppYaml class --- tests/fixtures/server.py | 62 ++++++++++++++++++++++++ tests/unit/server/test_server_util.py | 70 +++++++++++++++++++++++++++ 2 files changed, 132 insertions(+) diff --git a/tests/fixtures/server.py b/tests/fixtures/server.py index ae3a966c8..284084e2c 100644 --- a/tests/fixtures/server.py +++ b/tests/fixtures/server.py @@ -210,3 +210,65 @@ def server_server_config( "process": server_process_config_data, "singularity": server_container_format_config_data, } + + +@pytest.fixture(scope="function") +def server_app_yaml_contents( + server_redis_pass_file: str, + server_container_config_data: Dict[str, str], + server_process_config_data: Dict[str, str], +) -> Dict[str, str]: + """ + Fixture to create the contents of an app.yaml file. + + :param server_redis_pass_file: A pytest fixture that defines a path to a redis password file + :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class + :param server_process_config_data: A pytest fixture of test data to pass to the ProcessConfig class + :returns: A dict with typical app.yaml contents + """ + contents = { + "broker": { + "cert_reqs": "none", + "name": "redis", + "password": server_redis_pass_file, + "port": 6379, + "server": "127.0.0.1", + "username": "default", + "vhost": "testhost", + }, + "container": server_container_config_data, + "process": server_process_config_data, + "results_backend": { + "cert_reqs": "none", + "db_num": 0, + "name": "redis", + "password": server_redis_pass_file, + "port": 6379, + "server": "127.0.0.1", + "username": "default", + } + } + return contents + + +@pytest.fixture(scope="function") +def server_app_yaml(server_testing_dir: str, server_app_yaml_contents: dict) -> str: + """ + Fixture to create an app.yaml file in the temporary output directory. + + If a test will modify this file with a file write, you should make a copy of + this file to modify instead. + + NOTE this must be function scoped since server_app_yaml_contents is function scoped. + + :param server_testing_dir: A pytest fixture that defines a path to the output directory we'll write to + :param server_app_yaml_contents: A pytest fixture that creates a dict of contents for an app.yaml file + :returns: The path to the app.yaml file + """ + app_yaml_file = f"{server_testing_dir}/app.yaml" + + if not os.path.exists(app_yaml_file): + with open(app_yaml_file, "w") as ayf: + yaml.dump(server_app_yaml_contents, ayf) + + return app_yaml_file \ No newline at end of file diff --git a/tests/unit/server/test_server_util.py b/tests/unit/server/test_server_util.py index 2986e22de..20ff922bb 100644 --- a/tests/unit/server/test_server_util.py +++ b/tests/unit/server/test_server_util.py @@ -12,6 +12,7 @@ ContainerConfig, ContainerFormatConfig, ProcessConfig, + RedisConfig, RedisUsers, ServerConfig, valid_ipv4, @@ -288,6 +289,7 @@ def test_init_with_missing_data(self, server_process_config_data: Dict[str, str] assert config.container is None assert config.container_format is None + class TestRedisUsers: """ Tests for the RedisUsers class. @@ -452,3 +454,71 @@ def test_remove_user_invalid(self, server_redis_users_file: str): """ redis_users = RedisUsers(server_redis_users_file) assert not redis_users.remove_user("nonexistent_user") + + +class TestAppYaml: + """Tests for the AppYaml class.""" + + def test_initialization(self, server_app_yaml: str, server_app_yaml_contents: dict): + """ + Test the initialization process of the AppYaml class. + + :param server_app_yaml: The path to an app.yaml file + :param server_app_yaml_contents: A dict of app.yaml configurations + """ + app_yaml = AppYaml(server_app_yaml) + assert app_yaml.get_data() == server_app_yaml_contents + + def test_apply_server_config(self, server_app_yaml: str, server_server_config: Dict[str, str]): + """ + Test the `apply_server_config` method. This should update the data attribute. + + :param server_app_yaml: The path to an app.yaml file + :param server_server_config: A pytest fixture of test data to pass to the ServerConfig class + """ + app_yaml = AppYaml(server_app_yaml) + server_config = ServerConfig(server_server_config) + redis_config = RedisConfig(server_config.container.get_config_path()) + app_yaml.apply_server_config(server_config) + + assert app_yaml.data[app_yaml.broker_name]["name"] == server_config.container.get_image_type() + assert app_yaml.data[app_yaml.broker_name]["username"] == "default" + assert app_yaml.data[app_yaml.broker_name]["password"] == server_config.container.get_pass_file_path() + assert app_yaml.data[app_yaml.broker_name]["server"] == redis_config.get_ip_address() + assert app_yaml.data[app_yaml.broker_name]["port"] == redis_config.get_port() + + assert app_yaml.data[app_yaml.results_name]["name"] == server_config.container.get_image_type() + assert app_yaml.data[app_yaml.results_name]["username"] == "default" + assert app_yaml.data[app_yaml.results_name]["password"] == server_config.container.get_pass_file_path() + assert app_yaml.data[app_yaml.results_name]["server"] == redis_config.get_ip_address() + assert app_yaml.data[app_yaml.results_name]["port"] == redis_config.get_port() + + def test_update_data(self, server_app_yaml: str): + """ + Test the `update_data` method. This should update the data attribute. + + :param server_app_yaml: The path to an app.yaml file + """ + app_yaml = AppYaml(server_app_yaml) + new_data = {app_yaml.broker_name: {"username": "new_user"}} + app_yaml.update_data(new_data) + + assert app_yaml.data[app_yaml.broker_name]["username"] == "new_user" + + def test_write(self, server_app_yaml: str, server_testing_dir: str): + """ + Test the `write` method. This should write data to a file. + + :param server_app_yaml: The path to an app.yaml file + :param server_testing_dir: The path to the the temp output directory for server tests + """ + copy_app_yaml = f"{server_testing_dir}/app_copy.yaml" + + # Create a AppYaml object with the basic app.yaml file + app_yaml = AppYaml(server_app_yaml) + + # Run the test + app_yaml.write(copy_app_yaml) + + # Check that the contents of the copied file match the contents of the basic file + assert filecmp.cmp(server_app_yaml, copy_app_yaml) From bde90797c0938ea6844fd74ed4e23a28eab341dd Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 6 Jun 2024 12:06:09 -0700 Subject: [PATCH 37/44] final cleanup of server_utils --- tests/fixtures/server.py | 9 ++-- tests/unit/server/test_server_util.py | 67 ++++++++++++++++----------- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/tests/fixtures/server.py b/tests/fixtures/server.py index 284084e2c..01db7bd56 100644 --- a/tests/fixtures/server.py +++ b/tests/fixtures/server.py @@ -4,7 +4,7 @@ import os import pytest import yaml -from typing import Dict +from typing import Dict, Union @pytest.fixture(scope="session") @@ -88,7 +88,7 @@ def server_redis_pass_file(server_testing_dir: str) -> str: @pytest.fixture(scope="session") -def server_users() -> dict: +def server_users() -> Dict[str, Dict[str, str]]: """ Create a dictionary of two test users with identical configuration settings. @@ -217,7 +217,7 @@ def server_app_yaml_contents( server_redis_pass_file: str, server_container_config_data: Dict[str, str], server_process_config_data: Dict[str, str], -) -> Dict[str, str]: +) -> Dict[str, Union[str, int]]: """ Fixture to create the contents of an app.yaml file. @@ -256,9 +256,6 @@ def server_app_yaml(server_testing_dir: str, server_app_yaml_contents: dict) -> """ Fixture to create an app.yaml file in the temporary output directory. - If a test will modify this file with a file write, you should make a copy of - this file to modify instead. - NOTE this must be function scoped since server_app_yaml_contents is function scoped. :param server_testing_dir: A pytest fixture that defines a path to the output directory we'll write to diff --git a/tests/unit/server/test_server_util.py b/tests/unit/server/test_server_util.py index 20ff922bb..c9b59e83e 100644 --- a/tests/unit/server/test_server_util.py +++ b/tests/unit/server/test_server_util.py @@ -30,8 +30,8 @@ def test_valid_ipv4_valid_ip(valid_ip: str): Test the `valid_ipv4` function with valid IPs. This should return True. - :param valid_ip: A valid port to test. - These are pulled from the parametrized list defined above this test. + :param valid_ip: A valid port to test. These are pulled from the parametrized + list defined above this test. """ assert valid_ipv4(valid_ip) @@ -47,8 +47,8 @@ def test_valid_ipv4_invalid_ip(invalid_ip: Union[str, None]): An IP is valid if every integer separated by the '.' delimiter are between 0 and 255. This should return False for both IPs tested here. - :param invalid_ip: An invalid port to test. - These are pulled from the parametrized list defined above this test. + :param invalid_ip: An invalid port to test. These are pulled from the parametrized + list defined above this test. """ assert not valid_ipv4(invalid_ip) @@ -63,8 +63,8 @@ def test_valid_port_valid_input(valid_input: int): Valid ports are ports between 1 and 65535. This should return True. - :param valid_input: A valid input value to test. - These are pulled from the parametrized list defined above this test. + :param valid_input: A valid input value to test. These are pulled from the parametrized + list defined above this test. """ assert valid_port(valid_input) @@ -79,8 +79,8 @@ def test_valid_port_invalid_input(invalid_input: int): Valid ports are ports between 1 and 65535. This should return False for each invalid input tested. - :param invalid_input: An invalid input value to test. - These are pulled from the parametrized list defined above this test. + :param invalid_input: An invalid input value to test. These are pulled from the parametrized + list defined above this test. """ assert not valid_port(invalid_input) @@ -90,7 +90,7 @@ class TestContainerConfig: def test_init_with_complete_data(self, server_container_config_data: Dict[str, str]): """ - Tests that __init__ populates attributes correctly with complete data + Tests that __init__ populates attributes correctly with complete data. :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class """ @@ -107,7 +107,7 @@ def test_init_with_complete_data(self, server_container_config_data: Dict[str, s def test_init_with_missing_data(self): """ - Tests that __init__ uses defaults for missing data + Tests that __init__ uses defaults for missing data. """ incomplete_data = {"format": "docker"} config = ContainerConfig(incomplete_data) @@ -130,7 +130,7 @@ def test_init_with_missing_data(self): ]) def test_get_path_methods(self, server_container_config_data: Dict[str, str], attr_name: str): """ - Tests that get_*_path methods construct the correct path + Tests that get_*_path methods construct the correct path. :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class :param attr_name: Name of the attribute to be tested. These are pulled from the parametrized list defined above this test. @@ -153,7 +153,7 @@ def test_get_path_methods(self, server_container_config_data: Dict[str, str], at ]) def test_getter_methods(self, server_container_config_data: Dict[str, str], getter_name: str, expected_attr: str): """ - Tests that all getter methods return the correct attribute values + Tests that all getter methods return the correct attribute values. :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class :param getter_name: Name of the getter method to test. This is pulled from the parametrized list defined above this test. @@ -163,20 +163,31 @@ def test_getter_methods(self, server_container_config_data: Dict[str, str], gett getter = getattr(config, getter_name) assert getter() == server_container_config_data[expected_attr] - def test_get_container_password(self, server_container_config_data: Dict[str, str]): + def test_get_container_password(self, server_testing_dir: str, server_container_config_data: Dict[str, str]): """ - Test that the get_container_password is reading the password file properly + Test that the `get_container_password` method is reading the password file properly. + :param server_testing_dir: The path to the the temp output directory for server tests :param server_container_config_data: A pytest fixture of test data to pass to the ContainerConfig class """ # Write a fake password to the password file - test_password = "server-tests-password" - with open(server_container_config_data["pass_file"], "w") as pass_file: + test_password = "super-secret-password" + temp_pass_file = f"{server_testing_dir}/temp.pass" + with open(temp_pass_file, "w") as pass_file: pass_file.write(test_password) - # Run the test - config = ContainerConfig(server_container_config_data) - assert config.get_container_password() == test_password + # Use temp pass file + orig_pass_file = server_container_config_data["pass_file"] + server_container_config_data["pass_file"] = temp_pass_file + + try: + # Run the test + config = ContainerConfig(server_container_config_data) + assert config.get_container_password() == test_password + except Exception as exc: + # If there was a problem, reset to the original password file + server_container_config_data["pass_file"] = orig_pass_file + raise exc class TestContainerFormatConfig: @@ -184,7 +195,7 @@ class TestContainerFormatConfig: def test_init_with_complete_data(self, server_container_format_config_data: Dict[str, str]): """ - Tests that __init__ populates attributes correctly with complete data + Tests that __init__ populates attributes correctly with complete data. :param server_container_format_config_data: A pytest fixture of test data to pass to the ContainerFormatConfig class """ @@ -196,7 +207,7 @@ def test_init_with_complete_data(self, server_container_format_config_data: Dict def test_init_with_missing_data(self): """ - Tests that __init__ uses defaults for missing data + Tests that __init__ uses defaults for missing data. """ incomplete_data = {"command": "docker"} config = ContainerFormatConfig(incomplete_data) @@ -213,7 +224,7 @@ def test_init_with_missing_data(self): ]) def test_getter_methods(self, server_container_format_config_data: Dict[str, str], getter_name: str, expected_attr: str): """ - Tests that all getter methods return the correct attribute values + Tests that all getter methods return the correct attribute values. :param server_container_format_config_data: A pytest fixture of test data to pass to the ContainerFormatConfig class :param getter_name: Name of the getter method to test. This is pulled from the parametrized list defined above this test. @@ -229,7 +240,7 @@ class TestProcessConfig: def test_init_with_complete_data(self, server_process_config_data: Dict[str, str]): """ - Tests that __init__ populates attributes correctly with complete data + Tests that __init__ populates attributes correctly with complete data. :param server_process_config_data: A pytest fixture of test data to pass to the ProcessConfig class """ @@ -239,7 +250,7 @@ def test_init_with_complete_data(self, server_process_config_data: Dict[str, str def test_init_with_missing_data(self): """ - Tests that __init__ uses defaults for missing data + Tests that __init__ uses defaults for missing data. """ incomplete_data = {"status": "status {pid}"} config = ProcessConfig(incomplete_data) @@ -252,7 +263,7 @@ def test_init_with_missing_data(self): ]) def test_getter_methods(self, server_process_config_data: Dict[str, str], getter_name: str, expected_attr: str): """ - Tests that all getter methods return the correct attribute values + Tests that all getter methods return the correct attribute values. :param server_process_config_data: A pytest fixture of test data to pass to the ProcessConfig class :param getter_name: Name of the getter method to test. This is pulled from the parametrized list defined above this test. @@ -268,7 +279,7 @@ class TestServerConfig: def test_init_with_complete_data(self, server_server_config: Dict[str, str]): """ - Tests that __init__ populates attributes correctly with complete data + Tests that __init__ populates attributes correctly with complete data. :param server_server_config: A pytest fixture of test data to pass to the ServerConfig class """ @@ -279,7 +290,7 @@ def test_init_with_complete_data(self, server_server_config: Dict[str, str]): def test_init_with_missing_data(self, server_process_config_data: Dict[str, str]): """ - Tests that __init__ uses None for missing data + Tests that __init__ uses None for missing data. :param server_process_config_data: A pytest fixture of test data to pass to the ContainerConfig class """ @@ -298,7 +309,7 @@ class TestRedisUsers: """ class TestUser: - """Tests for the RedisUsers.User class""" + """Tests for the RedisUsers.User class.""" def test_initializaiton(self): """Test the initialization process of the User class.""" From da94020e72e5aa66544a30efe18969dbd04ee704 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 6 Jun 2024 12:31:00 -0700 Subject: [PATCH 38/44] fix lint issues --- merlin/examples/generator.py | 2 +- merlin/server/server_util.py | 8 +- tests/conftest.py | 15 ++- tests/context_managers/server_manager.py | 1 + tests/fixtures/server.py | 37 +++--- tests/fixtures/status.py | 5 +- tests/unit/common/test_dumper.py | 34 ++++-- tests/unit/common/test_encryption.py | 1 + tests/unit/common/test_sample_index.py | 1 + tests/unit/common/test_util_sampling.py | 1 + tests/unit/config/test_broker.py | 1 + tests/unit/config/test_config_object.py | 1 + tests/unit/config/test_configfile.py | 1 + tests/unit/config/test_results_backend.py | 1 + tests/unit/server/test_RedisConfig.py | 132 ++++++++++++--------- tests/unit/server/test_server_util.py | 138 +++++++++++++--------- tests/unit/test_examples_generator.py | 1 + tests/utils.py | 1 + 18 files changed, 232 insertions(+), 149 deletions(-) diff --git a/merlin/examples/generator.py b/merlin/examples/generator.py index d05f5c234..285b946d8 100644 --- a/merlin/examples/generator.py +++ b/merlin/examples/generator.py @@ -146,5 +146,5 @@ def setup_example(name, outdir): LOG.info(f"Copying example '{name}' to {outdir}") write_example(src_path, outdir) - print(f'example: {example}') + print(f"example: {example}") return example diff --git a/merlin/server/server_util.py b/merlin/server/server_util.py index 9b7233097..741cdb832 100644 --- a/merlin/server/server_util.py +++ b/merlin/server/server_util.py @@ -124,7 +124,7 @@ def __init__(self, data: dict) -> None: def __eq__(self, other: "ContainerFormatConfig"): """ Equality magic method used for testing this class - + :param other: Another ContainerFormatConfig object to check if they're the same """ variables = ("format", "image_type", "image", "url", "config", "config_dir", "pfile", "pass_file", "user_file") @@ -220,7 +220,7 @@ def __init__(self, data: dict) -> None: def __eq__(self, other: "ContainerFormatConfig"): """ Equality magic method used for testing this class - + :param other: Another ContainerFormatConfig object to check if they're the same """ variables = ("command", "run_command", "stop_command", "pull_command") @@ -263,7 +263,7 @@ def __init__(self, data: dict) -> None: def __eq__(self, other: "ProcessConfig"): """ Equality magic method used for testing this class - + :param other: Another ProcessConfig object to check if they're the same """ variables = ("status", "kill") @@ -441,7 +441,7 @@ def set_snapshot(self, seconds: int = None, changes: int = None) -> bool: """ Sets the 'seconds' and/or 'changes' values of the snapshot setting, depending on what the user requests. - + :param seconds: The first value of snapshot to change. If we're leaving it the same this will be None. :param changes: The second value of snapshot to change. If we're leaving it the diff --git a/tests/conftest.py b/tests/conftest.py index ce5cf7571..6ddfb8474 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,6 +49,9 @@ from tests.utils import create_cert_files, create_pass_file +# pylint: disable=redefined-outer-name + + ####################################### # Loading in Module Specific Fixtures # ####################################### @@ -117,7 +120,7 @@ def temp_output_dir(tmp_path_factory: TempPathFactory) -> str: @pytest.fixture(scope="session") -def merlin_server_dir(temp_output_dir: str) -> str: # pylint: disable=redefined-outer-name +def merlin_server_dir(temp_output_dir: str) -> str: """ The path to the merlin_server directory that will be created by the `redis_server` fixture. @@ -131,7 +134,7 @@ def merlin_server_dir(temp_output_dir: str) -> str: # pylint: disable=redefined @pytest.fixture(scope="session") -def redis_server(merlin_server_dir: str, test_encryption_key: bytes) -> str: # pylint: disable=redefined-outer-name +def redis_server(merlin_server_dir: str, test_encryption_key: bytes) -> str: """ Start a redis server instance that runs on localhost:6379. This will yield the redis server uri that can be used to create a connection with celery. @@ -152,7 +155,7 @@ def redis_server(merlin_server_dir: str, test_encryption_key: bytes) -> str: # @pytest.fixture(scope="session") -def celery_app(redis_server: str) -> Celery: # pylint: disable=redefined-outer-name +def celery_app(redis_server: str) -> Celery: """ Create the celery app to be used throughout our integration tests. @@ -163,7 +166,7 @@ def celery_app(redis_server: str) -> Celery: # pylint: disable=redefined-outer- @pytest.fixture(scope="session") -def sleep_sig(celery_app: Celery) -> Signature: # pylint: disable=redefined-outer-name +def sleep_sig(celery_app: Celery) -> Signature: """ Create a task registered to our celery app and return a signature for it. Once requested by a test, you can set the queue you'd like to send this to @@ -195,7 +198,7 @@ def worker_queue_map() -> Dict[str, str]: @pytest.fixture(scope="class") -def launch_workers(celery_app: Celery, worker_queue_map: Dict[str, str]): # pylint: disable=redefined-outer-name +def launch_workers(celery_app: Celery, worker_queue_map: Dict[str, str]): """ Launch the workers on the celery app fixture using the worker and queue names defined in the worker_queue_map fixture. @@ -238,7 +241,7 @@ def test_encryption_key() -> bytes: @pytest.fixture(scope="function") -def config(merlin_server_dir: str, test_encryption_key: bytes): # pylint: disable=redefined-outer-name +def config(merlin_server_dir: str, test_encryption_key: bytes): """ DO NOT USE THIS FIXTURE IN A TEST, USE `redis_config` OR `rabbit_config` INSTEAD. This fixture is intended to be used strictly by the `redis_config` and `rabbit_config` diff --git a/tests/context_managers/server_manager.py b/tests/context_managers/server_manager.py index ea6a731ff..c88948772 100644 --- a/tests/context_managers/server_manager.py +++ b/tests/context_managers/server_manager.py @@ -2,6 +2,7 @@ Module to define functionality for managing the containerized server used for testing. """ + import os import signal import subprocess diff --git a/tests/fixtures/server.py b/tests/fixtures/server.py index 01db7bd56..156a374d7 100644 --- a/tests/fixtures/server.py +++ b/tests/fixtures/server.py @@ -1,10 +1,15 @@ """ Fixtures specifically for help testing the modules in the server/ directory. """ + import os +from typing import Dict, Union + import pytest import yaml -from typing import Dict, Union + + +# pylint: disable=redefined-outer-name @pytest.fixture(scope="session") @@ -60,7 +65,9 @@ def server_redis_conf_file(server_testing_dir: str) -> str: appendfilename appendonly.aof # dummy trailing comment - """.strip().replace(" ", "") + """.strip().replace( + " ", "" + ) with open(redis_conf_file, "w") as rcf: rcf.write(file_contents) @@ -96,19 +103,19 @@ def server_users() -> Dict[str, Dict[str, str]]: """ users = { "default": { - "channels": '*', - "commands": '@all', - "hash_password": '1ba9249af0c73dacb0f9a70567126624076b5bee40de811e65f57eabcdaf490a', - "keys": '*', - "status": 'on', + "channels": "*", + "commands": "@all", + "hash_password": "1ba9249af0c73dacb0f9a70567126624076b5bee40de811e65f57eabcdaf490a", + "keys": "*", + "status": "on", }, "test_user": { - "channels": '*', - "commands": '@all', - "hash_password": '1ba9249af0c73dacb0f9a70567126624076b5bee40de811e65f57eabcdaf490a', - "keys": '*', - "status": 'on', - } + "channels": "*", + "commands": "@all", + "hash_password": "1ba9249af0c73dacb0f9a70567126624076b5bee40de811e65f57eabcdaf490a", + "keys": "*", + "status": "on", + }, } return users @@ -246,7 +253,7 @@ def server_app_yaml_contents( "port": 6379, "server": "127.0.0.1", "username": "default", - } + }, } return contents @@ -268,4 +275,4 @@ def server_app_yaml(server_testing_dir: str, server_app_yaml_contents: dict) -> with open(app_yaml_file, "w") as ayf: yaml.dump(server_app_yaml_contents, ayf) - return app_yaml_file \ No newline at end of file + return app_yaml_file diff --git a/tests/fixtures/status.py b/tests/fixtures/status.py index ab0de5d1e..3ae8dcaa7 100644 --- a/tests/fixtures/status.py +++ b/tests/fixtures/status.py @@ -9,6 +9,9 @@ import pytest +# pylint: disable=redefined-outer-name + + @pytest.fixture(scope="class") def status_testing_dir(temp_output_dir: str) -> str: """ @@ -24,7 +27,7 @@ def status_testing_dir(temp_output_dir: str) -> str: @pytest.fixture(scope="class") -def status_empty_file(status_testing_dir: str) -> str: # pylint: disable=W0621 +def status_empty_file(status_testing_dir: str) -> str: """ A pytest fixture to create an empty status file. diff --git a/tests/unit/common/test_dumper.py b/tests/unit/common/test_dumper.py index 7c437fde9..c52e9fe90 100644 --- a/tests/unit/common/test_dumper.py +++ b/tests/unit/common/test_dumper.py @@ -1,21 +1,27 @@ """ Tests for the `dumper.py` file. """ + import csv import json import os -import pytest - from datetime import datetime from time import sleep +import pytest + from merlin.common.dumper import dump_handler + NUM_ROWS = 5 -CSV_INFO_TO_DUMP = {"row_num": [i for i in range(1, NUM_ROWS+1)], "other_info": [f"test_info_{i}" for i in range(1, NUM_ROWS+1)]} -JSON_INFO_TO_DUMP = {str(i): {f"other_info_{i}": f"test_info_{i}"} for i in range(1, NUM_ROWS+1)} +CSV_INFO_TO_DUMP = { + "row_num": [i for i in range(1, NUM_ROWS + 1)], + "other_info": [f"test_info_{i}" for i in range(1, NUM_ROWS + 1)], +} +JSON_INFO_TO_DUMP = {str(i): {f"other_info_{i}": f"test_info_{i}"} for i in range(1, NUM_ROWS + 1)} DUMP_HANDLER_DIR = "{temp_output_dir}/dump_handler" + def test_dump_handler_invalid_dump_file(): """ This is really testing the initialization of the Dumper class with an invalid file type. @@ -25,6 +31,7 @@ def test_dump_handler_invalid_dump_file(): dump_handler("bad_file.txt", CSV_INFO_TO_DUMP) assert "Invalid file type for bad_file.txt. Supported file types are: ['csv', 'json']" in str(excinfo.value) + def get_output_file(temp_dir: str, file_name: str): """ Helper function to get a full path to the temporary output file. @@ -38,6 +45,7 @@ def get_output_file(temp_dir: str, file_name: str): dump_file = f"{dump_dir}/{file_name}" return dump_file + def run_csv_dump_test(dump_file: str, fmode: str): """ Run the test for csv dump. @@ -52,16 +60,17 @@ def run_csv_dump_test(dump_file: str, fmode: str): reader = csv.reader(df) written_data = list(reader) - expected_rows = NUM_ROWS*2 if fmode == "a" else NUM_ROWS - assert len(written_data) == expected_rows+1 # Adding one because of the header row + expected_rows = NUM_ROWS * 2 if fmode == "a" else NUM_ROWS + assert len(written_data) == expected_rows + 1 # Adding one because of the header row for i, row in enumerate(written_data): assert len(row) == 2 # Check number of columns if i == 0: # Checking the header row assert row[0] == "row_num" assert row[1] == "other_info" else: # Checking the data rows - assert row[0] == str(CSV_INFO_TO_DUMP["row_num"][(i%NUM_ROWS)-1]) - assert row[1] == str(CSV_INFO_TO_DUMP["other_info"][(i%NUM_ROWS)-1]) + assert row[0] == str(CSV_INFO_TO_DUMP["row_num"][(i % NUM_ROWS) - 1]) + assert row[1] == str(CSV_INFO_TO_DUMP["other_info"][(i % NUM_ROWS) - 1]) + def test_dump_handler_csv_write(temp_output_dir: str): """ @@ -80,6 +89,7 @@ def test_dump_handler_csv_write(temp_output_dir: str): # Assert that everything ran properly run_csv_dump_test(dump_file, "w") + def test_dump_handler_csv_append(temp_output_dir: str): """ This is really testing the write method of the Dumper class with the file write mode set to append. @@ -93,13 +103,14 @@ def test_dump_handler_csv_append(temp_output_dir: str): # Run the first call to create the csv file dump_handler(dump_file, CSV_INFO_TO_DUMP) - + # Run the second call to append to the csv file dump_handler(dump_file, CSV_INFO_TO_DUMP) # Assert that everything ran properly run_csv_dump_test(dump_file, "a") + def test_dump_handler_json_write(temp_output_dir: str): """ This is really testing the write method of the Dumper class. @@ -120,6 +131,7 @@ def test_dump_handler_json_write(temp_output_dir: str): contents = json.load(df) assert contents == JSON_INFO_TO_DUMP + def test_dump_handler_json_append(temp_output_dir: str): """ This is really testing the write method of the Dumper class with the file write mode set to append. @@ -137,7 +149,7 @@ def test_dump_handler_json_append(temp_output_dir: str): dump_handler(dump_file, first_dump) # Sleep so we don't accidentally get the same timestamp - sleep(.5) + sleep(0.5) # Run the second call to append to the file timestamp_2 = str(datetime.now()) @@ -153,4 +165,4 @@ def test_dump_handler_json_append(temp_output_dir: str): assert timestamp_1 in keys assert timestamp_2 in keys assert contents[timestamp_1] == JSON_INFO_TO_DUMP - assert contents[timestamp_2] == JSON_INFO_TO_DUMP \ No newline at end of file + assert contents[timestamp_2] == JSON_INFO_TO_DUMP diff --git a/tests/unit/common/test_encryption.py b/tests/unit/common/test_encryption.py index d797f68c0..3e37cef84 100644 --- a/tests/unit/common/test_encryption.py +++ b/tests/unit/common/test_encryption.py @@ -1,6 +1,7 @@ """ Tests for the `encrypt.py` and `encrypt_backend_traffic.py` files. """ + import os import celery diff --git a/tests/unit/common/test_sample_index.py b/tests/unit/common/test_sample_index.py index cdb5b2f4f..d857b7ce5 100644 --- a/tests/unit/common/test_sample_index.py +++ b/tests/unit/common/test_sample_index.py @@ -1,6 +1,7 @@ """ Tests for the `sample_index.py` and `sample_index_factory.py` files. """ + import os import pytest diff --git a/tests/unit/common/test_util_sampling.py b/tests/unit/common/test_util_sampling.py index c957ac105..b4cc252d5 100644 --- a/tests/unit/common/test_util_sampling.py +++ b/tests/unit/common/test_util_sampling.py @@ -1,6 +1,7 @@ """ Tests for the `util_sampling.py` file. """ + import numpy as np import pytest diff --git a/tests/unit/config/test_broker.py b/tests/unit/config/test_broker.py index 8af1dda75..581b19488 100644 --- a/tests/unit/config/test_broker.py +++ b/tests/unit/config/test_broker.py @@ -1,6 +1,7 @@ """ Tests for the `broker.py` file. """ + import os from ssl import CERT_NONE from typing import Any, Dict diff --git a/tests/unit/config/test_config_object.py b/tests/unit/config/test_config_object.py index bd658bc66..64e56b7d9 100644 --- a/tests/unit/config/test_config_object.py +++ b/tests/unit/config/test_config_object.py @@ -1,6 +1,7 @@ """ Test the functionality of the Config object. """ + from copy import copy, deepcopy from types import SimpleNamespace diff --git a/tests/unit/config/test_configfile.py b/tests/unit/config/test_configfile.py index aeb1da941..975e19ee4 100644 --- a/tests/unit/config/test_configfile.py +++ b/tests/unit/config/test_configfile.py @@ -1,6 +1,7 @@ """ Tests for the configfile.py module. """ + import getpass import os import shutil diff --git a/tests/unit/config/test_results_backend.py b/tests/unit/config/test_results_backend.py index 314df6ce7..f49e3e897 100644 --- a/tests/unit/config/test_results_backend.py +++ b/tests/unit/config/test_results_backend.py @@ -1,6 +1,7 @@ """ Tests for the `results_backend.py` file. """ + import os from ssl import CERT_NONE from typing import Any, Dict diff --git a/tests/unit/server/test_RedisConfig.py b/tests/unit/server/test_RedisConfig.py index 12880d4d6..321d2f38a 100644 --- a/tests/unit/server/test_RedisConfig.py +++ b/tests/unit/server/test_RedisConfig.py @@ -4,13 +4,16 @@ This class is especially large so that's why these tests have been moved to their own file. """ + import filecmp import logging -import pytest from typing import Any +import pytest + from merlin.server.server_util import RedisConfig + class TestRedisConfig: """Tests for the RedisConfig class.""" @@ -73,10 +76,7 @@ def test_write(self, server_redis_conf_file: str, server_testing_dir: str): # Check that the contents of the copied file match the contents of the basic file assert filecmp.cmp(server_redis_conf_file, copy_redis_conf_file) - @pytest.mark.parametrize("key, val, expected_return", [ - ("port", 1234, True), - ("invalid_key", "dummy_val", False) - ]) + @pytest.mark.parametrize("key, val, expected_return", [("port", 1234, True), ("invalid_key", "dummy_val", False)]) def test_set_config_value(self, server_redis_conf_file: str, key: str, val: Any, expected_return: bool): """ Test the `set_config_value` method with valid and invalid keys. @@ -95,17 +95,20 @@ def test_set_config_value(self, server_redis_conf_file: str, key: str, val: Any, else: assert not redis_config.changes_made() - @pytest.mark.parametrize("key, expected_val", [ - ("bind", "127.0.0.1"), - ("port", "6379"), - ("requirepass", "merlin_password"), - ("dir", "./"), - ("save", "300 100"), - ("dbfilename", "dump.rdb"), - ("appendfsync", "everysec"), - ("appendfilename", "appendonly.aof"), - ("invalid_key", None) - ]) + @pytest.mark.parametrize( + "key, expected_val", + [ + ("bind", "127.0.0.1"), + ("port", "6379"), + ("requirepass", "merlin_password"), + ("dir", "./"), + ("save", "300 100"), + ("dbfilename", "dump.rdb"), + ("appendfsync", "everysec"), + ("appendfilename", "appendonly.aof"), + ("invalid_key", None), + ], + ) def test_get_config_value(self, server_redis_conf_file: str, key: str, expected_val: str): """ Test the `get_config_value` method with valid and invalid keys. @@ -117,18 +120,16 @@ def test_get_config_value(self, server_redis_conf_file: str, key: str, expected_ redis_conf = RedisConfig(server_redis_conf_file) assert redis_conf.get_config_value(key) == expected_val - @pytest.mark.parametrize("ip_to_set", [ - "127.0.0.1", # Most common IP - "0.0.0.0", # Edge case (low) - "255.255.255.255", # Edge case (high) - "123.222.199.20", # Random valid IP - ]) - def test_set_ip_address_valid( - self, - caplog: "Fixture", # noqa: F821 - server_redis_conf_file: str, - ip_to_set: str - ): + @pytest.mark.parametrize( + "ip_to_set", + [ + "127.0.0.1", # Most common IP + "0.0.0.0", # Edge case (low) + "255.255.255.255", # Edge case (high) + "123.222.199.20", # Random valid IP + ], + ) + def test_set_ip_address_valid(self, caplog: "Fixture", server_redis_conf_file: str, ip_to_set: str): # noqa: F821 """ Test the `set_ip_address` method with valid ips. These should all return True and set the 'bind' value to whatever `ip_to_set` is. @@ -143,11 +144,14 @@ def test_set_ip_address_valid( assert f"Ipaddress is set to {ip_to_set}" in caplog.text, "Missing expected log message" assert redis_config.get_ip_address() == ip_to_set - @pytest.mark.parametrize("ip_to_set, expected_log", [ - (None, None), # No IP - ("0.0.0", "Invalid IPv4 address given."), # Invalid IPv4 - ("bind-unset", "Unable to set ip address for redis config"), # Special invalid case where bind doesn't exist - ]) + @pytest.mark.parametrize( + "ip_to_set, expected_log", + [ + (None, None), # No IP + ("0.0.0", "Invalid IPv4 address given."), # Invalid IPv4 + ("bind-unset", "Unable to set ip address for redis config"), # Special invalid case where bind doesn't exist + ], + ) def test_set_ip_address_invalid( self, caplog: "Fixture", # noqa: F821 @@ -174,12 +178,15 @@ def test_set_ip_address_invalid( if expected_log is not None: assert expected_log in caplog.text, "Missing expected log message" - @pytest.mark.parametrize("port_to_set", [ - 6379, # Most common port - 1, # Edge case (low) - 65535, # Edge case (high) - 12345, # Random valid port - ]) + @pytest.mark.parametrize( + "port_to_set", + [ + 6379, # Most common port + 1, # Edge case (low) + 65535, # Edge case (high) + 12345, # Random valid port + ], + ) def test_set_port_valid( self, caplog: "Fixture", # noqa: F821 @@ -200,12 +207,15 @@ def test_set_port_valid( assert redis_config.get_port() == port_to_set assert f"Port is set to {port_to_set}" in caplog.text, "Missing expected log message" - @pytest.mark.parametrize("port_to_set, expected_log", [ - (None, None), # No port - (0, "Invalid port given."), # Edge case (low) - (65536, "Invalid port given."), # Edge case (high) - ("port-unset", "Unable to set port for redis config"), # Special invalid case where port doesn't exist - ]) + @pytest.mark.parametrize( + "port_to_set, expected_log", + [ + (None, None), # No port + (0, "Invalid port given."), # Edge case (low) + (65536, "Invalid port given."), # Edge case (high) + ("port-unset", "Unable to set port for redis config"), # Special invalid case where port doesn't exist + ], + ) def test_set_port_invalid( self, caplog: "Fixture", # noqa: F821 @@ -232,10 +242,13 @@ def test_set_port_invalid( if expected_log is not None: assert expected_log in caplog.text, "Missing expected log message" - @pytest.mark.parametrize("pass_to_set, expected_return", [ - ("valid_password", True), # Valid password - (None, False), # Invalid password - ]) + @pytest.mark.parametrize( + "pass_to_set, expected_return", + [ + ("valid_password", True), # Valid password + (None, False), # Invalid password + ], + ) def test_set_password( self, caplog: "Fixture", # noqa: F821 @@ -289,7 +302,7 @@ def test_set_directory_none(self, server_redis_conf_file: str): """ redis_config = RedisConfig(server_redis_conf_file) assert not redis_config.set_directory(None) - assert redis_config.get_config_value("dir") != None + assert redis_config.get_config_value("dir") is not None def test_set_directory_dir_unset( self, @@ -327,8 +340,10 @@ def test_set_snapshot_valid(self, caplog: "Fixture", server_redis_conf_file: str save_val = redis_conf.get_config_value("save").split() assert save_val[0] == str(snap_sec_to_set) assert save_val[1] == str(snap_changes_to_set) - expected_log = f"Snapshot wait time is set to {snap_sec_to_set} seconds. " \ - f"Snapshot threshold is set to {snap_changes_to_set} changes" + expected_log = ( + f"Snapshot wait time is set to {snap_sec_to_set} seconds. " + f"Snapshot threshold is set to {snap_changes_to_set} changes" + ) assert expected_log in caplog.text, "Missing expected log message" def test_set_snapshot_just_seconds(self, caplog: "Fixture", server_redis_conf_file: str): # noqa: F821 @@ -432,11 +447,14 @@ def test_set_snapshot_file_dbfilename_unset(self, caplog: "Fixture", server_redi assert redis_conf.get_config_value("dbfilename") != filename assert "Unable to set snapshot_file name" in caplog.text, "Missing expected log message" - @pytest.mark.parametrize("mode_to_set", [ - "always", - "everysec", - "no", - ]) + @pytest.mark.parametrize( + "mode_to_set", + [ + "always", + "everysec", + "no", + ], + ) def test_set_append_mode_valid( self, caplog: "Fixture", # noqa: F821 diff --git a/tests/unit/server/test_server_util.py b/tests/unit/server/test_server_util.py index c9b59e83e..909cb7cdf 100644 --- a/tests/unit/server/test_server_util.py +++ b/tests/unit/server/test_server_util.py @@ -1,12 +1,14 @@ """ Tests for the `server_util.py` module. """ + import filecmp import hashlib import os -import pytest from typing import Dict, Union +import pytest + from merlin.server.server_util import ( AppYaml, ContainerConfig, @@ -16,15 +18,19 @@ RedisUsers, ServerConfig, valid_ipv4, - valid_port + valid_port, ) -@pytest.mark.parametrize("valid_ip", [ - "0.0.0.0", - "127.0.0.1", - "14.105.200.58", - "255.255.255.255", -]) + +@pytest.mark.parametrize( + "valid_ip", + [ + "0.0.0.0", + "127.0.0.1", + "14.105.200.58", + "255.255.255.255", + ], +) def test_valid_ipv4_valid_ip(valid_ip: str): """ Test the `valid_ipv4` function with valid IPs. @@ -35,12 +41,16 @@ def test_valid_ipv4_valid_ip(valid_ip: str): """ assert valid_ipv4(valid_ip) -@pytest.mark.parametrize("invalid_ip", [ - "256.0.0.1", - "-1.0.0.1", - None, - "127.0.01", -]) + +@pytest.mark.parametrize( + "invalid_ip", + [ + "256.0.0.1", + "-1.0.0.1", + None, + "127.0.01", + ], +) def test_valid_ipv4_invalid_ip(invalid_ip: Union[str, None]): """ Test the `valid_ipv4` function with invalid IPs. @@ -52,11 +62,15 @@ def test_valid_ipv4_invalid_ip(invalid_ip: Union[str, None]): """ assert not valid_ipv4(invalid_ip) -@pytest.mark.parametrize("valid_input", [ - 1, - 433, - 65535, -]) + +@pytest.mark.parametrize( + "valid_input", + [ + 1, + 433, + 65535, + ], +) def test_valid_port_valid_input(valid_input: int): """ Test the `valid_port` function with valid port numbers. @@ -68,11 +82,15 @@ def test_valid_port_valid_input(valid_input: int): """ assert valid_port(valid_input) -@pytest.mark.parametrize("invalid_input", [ - -1, - 0, - 65536, -]) + +@pytest.mark.parametrize( + "invalid_input", + [ + -1, + 0, + 65536, + ], +) def test_valid_port_invalid_input(invalid_input: int): """ Test the `valid_port` function with invalid inputs. @@ -121,13 +139,16 @@ def test_init_with_missing_data(self): assert config.pass_file == ContainerConfig.PASSWORD_FILE assert config.user_file == ContainerConfig.USERS_FILE - @pytest.mark.parametrize("attr_name", [ - "image", - "config", - "pfile", - "pass_file", - "user_file", - ]) + @pytest.mark.parametrize( + "attr_name", + [ + "image", + "config", + "pfile", + "pass_file", + "user_file", + ], + ) def test_get_path_methods(self, server_container_config_data: Dict[str, str], attr_name: str): """ Tests that get_*_path methods construct the correct path. @@ -140,17 +161,20 @@ def test_get_path_methods(self, server_container_config_data: Dict[str, str], at expected_path = os.path.join(server_container_config_data["config_dir"], server_container_config_data[attr_name]) assert get_path_method() == expected_path - @pytest.mark.parametrize("getter_name, expected_attr", [ - ("get_format", "format"), - ("get_image_type", "image_type"), - ("get_image_name", "image"), - ("get_image_url", "url"), - ("get_config_name", "config"), - ("get_config_dir", "config_dir"), - ("get_pfile_name", "pfile"), - ("get_pass_file_name", "pass_file"), - ("get_user_file_name", "user_file"), - ]) + @pytest.mark.parametrize( + "getter_name, expected_attr", + [ + ("get_format", "format"), + ("get_image_type", "image_type"), + ("get_image_name", "image"), + ("get_image_url", "url"), + ("get_config_name", "config"), + ("get_config_dir", "config_dir"), + ("get_pfile_name", "pfile"), + ("get_pass_file_name", "pass_file"), + ("get_user_file_name", "user_file"), + ], + ) def test_getter_methods(self, server_container_config_data: Dict[str, str], getter_name: str, expected_attr: str): """ Tests that all getter methods return the correct attribute values. @@ -216,12 +240,15 @@ def test_init_with_missing_data(self): assert config.stop_command == config.STOP_COMMAND assert config.pull_command == config.PULL_COMMAND - @pytest.mark.parametrize("getter_name, expected_attr", [ - ("get_command", "command"), - ("get_run_command", "run_command"), - ("get_stop_command", "stop_command"), - ("get_pull_command", "pull_command"), - ]) + @pytest.mark.parametrize( + "getter_name, expected_attr", + [ + ("get_command", "command"), + ("get_run_command", "run_command"), + ("get_stop_command", "stop_command"), + ("get_pull_command", "pull_command"), + ], + ) def test_getter_methods(self, server_container_format_config_data: Dict[str, str], getter_name: str, expected_attr: str): """ Tests that all getter methods return the correct attribute values. @@ -257,10 +284,13 @@ def test_init_with_missing_data(self): assert config.status == incomplete_data["status"] assert config.kill == config.KILL_COMMAND - @pytest.mark.parametrize("getter_name, expected_attr", [ - ("get_status_command", "status"), - ("get_kill_command", "kill"), - ]) + @pytest.mark.parametrize( + "getter_name, expected_attr", + [ + ("get_status_command", "status"), + ("get_kill_command", "kill"), + ], + ) def test_getter_methods(self, server_process_config_data: Dict[str, str], getter_name: str, expected_attr: str): """ Tests that all getter methods return the correct attribute values. @@ -304,7 +334,7 @@ def test_init_with_missing_data(self, server_process_config_data: Dict[str, str] class TestRedisUsers: """ Tests for the RedisUsers class. - + TODO add integration test(s) for `apply_to_redis` method of this class. """ @@ -358,7 +388,7 @@ def test_get_user_dict(self): assert val == "on" else: assert val == test_dict[key] - + def test_set_password(self): """Test the `set_password` method of the User class.""" user = RedisUsers.User() diff --git a/tests/unit/test_examples_generator.py b/tests/unit/test_examples_generator.py index 5a05e3599..fe7378540 100644 --- a/tests/unit/test_examples_generator.py +++ b/tests/unit/test_examples_generator.py @@ -1,6 +1,7 @@ """ Tests for the `merlin/examples/generator.py` module. """ + import os from typing import List diff --git a/tests/utils.py b/tests/utils.py index d883b83cd..0b408db54 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,7 @@ """ Utility functions for our test suite. """ + import os from typing import Dict From 2997de6624d1cebeaa98929e3f315948ee0721f9 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 6 Jun 2024 14:09:18 -0700 Subject: [PATCH 39/44] parametrize setup examples tests --- tests/fixtures/examples.py | 20 + tests/unit/test_examples_generator.py | 572 ++++++++++---------------- 2 files changed, 248 insertions(+), 344 deletions(-) create mode 100644 tests/fixtures/examples.py diff --git a/tests/fixtures/examples.py b/tests/fixtures/examples.py new file mode 100644 index 000000000..16a2f576d --- /dev/null +++ b/tests/fixtures/examples.py @@ -0,0 +1,20 @@ +""" +Fixtures specifically for help testing the modules in the examples/ directory. +""" + +import os +import pytest + +@pytest.fixture(scope="session") +def examples_testing_dir(temp_output_dir: str) -> str: + """ + Fixture to create a temporary output directory for tests related to the examples functionality. + + :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :returns: The path to the temporary testing directory for examples tests + """ + testing_dir = f"{temp_output_dir}/examples_testing" + if not os.path.exists(testing_dir): + os.mkdir(testing_dir) + + return testing_dir \ No newline at end of file diff --git a/tests/unit/test_examples_generator.py b/tests/unit/test_examples_generator.py index fe7378540..3f0f2df9d 100644 --- a/tests/unit/test_examples_generator.py +++ b/tests/unit/test_examples_generator.py @@ -3,6 +3,7 @@ """ import os +import pytest from typing import List from tabulate import tabulate @@ -83,32 +84,27 @@ def test_gather_all_examples(): assert sorted(actual) == sorted(expected) -def test_write_example_dir(temp_output_dir: str): +def test_write_example_dir(examples_testing_dir: str): """ Test the `write_example` function with the src_path as a directory. - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :param examples_testing_dir: The path to the the temp output directory for examples tests """ - generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) dir_to_copy = f"{EXAMPLES_DIR}/feature_demo/" + dst_dir = f"{examples_testing_dir}/write_example_dir" + write_example(dir_to_copy, dst_dir) + assert sorted(os.listdir(dir_to_copy)) == sorted(os.listdir(dst_dir)) - write_example(dir_to_copy, generator_dir) - assert sorted(os.listdir(dir_to_copy)) == sorted(os.listdir(generator_dir)) - -def test_write_example_file(temp_output_dir: str): +def test_write_example_file(examples_testing_dir: str): """ Test the `write_example` function with the src_path as a file. - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :param examples_testing_dir: The path to the the temp output directory for examples tests """ - generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) - create_dir(generator_dir) - - dst_path = f"{generator_dir}/flux_par.yaml" file_to_copy = f"{EXAMPLES_DIR}/flux/flux_par.yaml" - - write_example(file_to_copy, generator_dir) + dst_path = f"{examples_testing_dir}/flux_par.yaml" + write_example(file_to_copy, dst_path) assert os.path.exists(dst_path) @@ -174,6 +170,8 @@ def test_list_examples(): ] expected = "\n" + tabulate(expected_rows, expected_headers) + "\n" actual = list_examples() + print(f"expected:\n{expected}") + print(f"actual:\n{actual}") assert actual == expected @@ -185,7 +183,7 @@ def test_setup_example_invalid_name(): assert setup_example("invalid_example_name", None) is None -def test_setup_example_no_outdir(temp_output_dir: str): +def test_setup_example_no_outdir(examples_testing_dir: str): """ Test the `setup_example` function with an invalid example name. This should create a directory with the example name (in this case hello) @@ -194,14 +192,12 @@ def test_setup_example_no_outdir(temp_output_dir: str): the `setup_example` function creates the hello/ subdirectory in a directory with the name of this test (setup_no_outdir). - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :param examples_testing_dir: The path to the the temp output directory for examples tests """ cwd = os.getcwd() # Create the temp path to store this setup and move into that directory - generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) - create_dir(generator_dir) - setup_example_dir = os.path.join(generator_dir, "setup_no_outdir") + setup_example_dir = os.path.join(examples_testing_dir, "setup_no_outdir") create_dir(setup_example_dir) os.chdir(setup_example_dir) @@ -229,37 +225,226 @@ def test_setup_example_no_outdir(temp_output_dir: str): raise AssertionError from exc -def test_setup_example_outdir_exists(temp_output_dir: str): +def test_setup_example_outdir_exists(examples_testing_dir: str): """ Test the `setup_example` function with an output directory that already exists. This should just return None. - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) - create_dir(generator_dir) - - assert setup_example("hello", generator_dir) is None - - -##################################### -# Tests for setting up each example # -##################################### - - -def run_setup_example(temp_output_dir: str, example_name: str, example_files: List[str], expected_return: str): + :param examples_testing_dir: The path to the the temp output directory for examples tests + """ + assert setup_example("hello", examples_testing_dir) is None + + +@pytest.mark.parametrize( + "example_name, example_files, expected_return", + [ + ( + "feature_demo", + [ + ".gitignore", + "feature_demo.yaml", + "requirements.txt", + "scripts/features.json", + "scripts/hello_world.py", + "scripts/pgen.py", + ], + "feature_demo", + ), + ( + "flux_local", + [ + "flux_local.yaml", + "flux_par_restart.yaml", + "flux_par.yaml", + "paper.yaml", + "requirements.txt", + "scripts/flux_info.py", + "scripts/hello_sleep.c", + "scripts/hello.c", + "scripts/make_samples.py", + "scripts/paper_workers.sbatch", + "scripts/test_workers.sbatch", + "scripts/workers.sbatch", + "scripts/workers.bsub", + ], + "flux", + ), + ( + "lsf_par", + [ + "lsf_par_srun.yaml", + "lsf_par.yaml", + "scripts/hello.c", + "scripts/make_samples.py", + ], + "lsf", + ), + ( + "slurm_par", + [ + "slurm_par.yaml", + "slurm_par_restart.yaml", + "requirements.txt", + "scripts/hello.c", + "scripts/make_samples.py", + "scripts/test_workers.sbatch", + "scripts/workers.sbatch", + ], + "slurm", + ), + ( + "hello", + [ + "hello_samples.yaml", + "hello.yaml", + "my_hello.yaml", + "requirements.txt", + "make_samples.py", + ], + "hello", + ), + ( + "hpc_demo", + [ + "hpc_demo.yaml", + "cumulative_sample_processor.py", + "faker_sample.py", + "sample_collector.py", + "sample_processor.py", + "requirements.txt", + ], + "hpc_demo", + ), + ( + "iterative_demo", + [ + "iterative_demo.yaml", + "cumulative_sample_processor.py", + "faker_sample.py", + "sample_collector.py", + "sample_processor.py", + "requirements.txt", + ], + "iterative_demo", + ), + ( + "null_spec", + [ + "null_spec.yaml", + "null_chain.yaml", + ".gitignore", + "Makefile", + "requirements.txt", + "scripts/aggregate_chain_output.sh", + "scripts/aggregate_output.sh", + "scripts/check_completion.sh", + "scripts/kill_all.sh", + "scripts/launch_chain_job.py", + "scripts/launch_jobs.py", + "scripts/make_samples.py", + "scripts/read_output_chain.py", + "scripts/read_output.py", + "scripts/search.sh", + "scripts/submit_chain.sbatch", + "scripts/submit.sbatch", + ], + "null_spec", + ), + ( + "openfoam_wf", + [ + "openfoam_wf.yaml", + "openfoam_wf_docker_template.yaml", + "README.md", + "requirements.txt", + "scripts/make_samples.py", + "scripts/blockMesh_template.txt", + "scripts/cavity_setup.sh", + "scripts/combine_outputs.py", + "scripts/learn.py", + "scripts/mesh_param_script.py", + "scripts/run_openfoam", + ], + "openfoam_wf", + ), + ( + "openfoam_wf_no_docker", + [ + "openfoam_wf_no_docker.yaml", + "openfoam_wf_no_docker_template.yaml", + "requirements.txt", + "scripts/make_samples.py", + "scripts/blockMesh_template.txt", + "scripts/cavity_setup.sh", + "scripts/combine_outputs.py", + "scripts/learn.py", + "scripts/mesh_param_script.py", + "scripts/run_openfoam", + ], + "openfoam_wf_no_docker", + ), + ( + "openfoam_wf_singularity", + [ + "openfoam_wf_singularity.yaml", + "openfoam_wf_singularity_template.yaml", + "requirements.txt", + "scripts/make_samples.py", + "scripts/blockMesh_template.txt", + "scripts/cavity_setup.sh", + "scripts/combine_outputs.py", + "scripts/learn.py", + "scripts/mesh_param_script.py", + "scripts/run_openfoam", + ], + "openfoam_wf_singularity", + ), + ( + "optimization_basic", + [ + "optimization_basic.yaml", + "requirements.txt", + "template_config.py", + "template_optimization.temp", + "scripts/collector.py", + "scripts/optimizer.py", + "scripts/test_functions.py", + "scripts/visualizer.py", + ], + "optimization", + ), + ( + "remote_feature_demo", + [ + ".gitignore", + "remote_feature_demo.yaml", + "requirements.txt", + "scripts/features.json", + "scripts/hello_world.py", + "scripts/pgen.py", + ], + "remote_feature_demo", + ), + ("restart", ["restart.yaml", "scripts/make_samples.py"], "restart"), + ("restart_delay", ["restart_delay.yaml", "scripts/make_samples.py"], "restart_delay"), + ], +) +def test_setup_example(examples_testing_dir: str, example_name: str, example_files: List[str], expected_return: str): """ - Helper function to run tests for the `setup_example` function. + Run tests for the `setup_example` function. + Each test will consist of: + 1. The name of the example to setup + 2. A list of files that we're expecting to be setup + 3. The expected return value + Each test is a tuple in the parametrize decorator above this test function. - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :param examples_testing_dir: The path to the the temp output directory for examples tests :param example_name: The name of the example to setup :param example_files: A list of filenames that should be copied by setup_example :param expected_return: The expected return value from `setup_example` """ # Create the temp path to store this setup - generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) - create_dir(generator_dir) - setup_example_dir = os.path.join(generator_dir, f"setup_{example_name}") + setup_example_dir = os.path.join(examples_testing_dir, f"setup_{example_name}") # Ensure that the example name is returned actual = setup_example(example_name, setup_example_dir) @@ -271,317 +456,16 @@ def run_setup_example(temp_output_dir: str, example_name: str, example_files: Li assert os.path.exists(file) -def test_setup_example_feature_demo(temp_output_dir: str): - """ - Test the `setup_example` function for the feature_demo example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "feature_demo" - example_files = [ - ".gitignore", - "feature_demo.yaml", - "requirements.txt", - "scripts/features.json", - "scripts/hello_world.py", - "scripts/pgen.py", - ] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_flux(temp_output_dir: str): - """ - Test the `setup_example` function for the flux example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_files = [ - "flux_local.yaml", - "flux_par_restart.yaml", - "flux_par.yaml", - "paper.yaml", - "requirements.txt", - "scripts/flux_info.py", - "scripts/hello_sleep.c", - "scripts/hello.c", - "scripts/make_samples.py", - "scripts/paper_workers.sbatch", - "scripts/test_workers.sbatch", - "scripts/workers.sbatch", - "scripts/workers.bsub", - ] - - run_setup_example(temp_output_dir, "flux_local", example_files, "flux") - - -def test_setup_example_lsf(temp_output_dir: str): - """ - Test the `setup_example` function for the lsf example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - - # TODO should there be a workers.bsub for this example? - example_files = [ - "lsf_par_srun.yaml", - "lsf_par.yaml", - "scripts/hello.c", - "scripts/make_samples.py", - ] - - run_setup_example(temp_output_dir, "lsf_par", example_files, "lsf") - - -def test_setup_example_slurm(temp_output_dir: str): - """ - Test the `setup_example` function for the slurm example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_files = [ - "slurm_par.yaml", - "slurm_par_restart.yaml", - "requirements.txt", - "scripts/hello.c", - "scripts/make_samples.py", - "scripts/test_workers.sbatch", - "scripts/workers.sbatch", - ] - - run_setup_example(temp_output_dir, "slurm_par", example_files, "slurm") - - -def test_setup_example_hello(temp_output_dir: str): - """ - Test the `setup_example` function for the hello example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "hello" - example_files = [ - "hello_samples.yaml", - "hello.yaml", - "my_hello.yaml", - "requirements.txt", - "make_samples.py", - ] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_hpc(temp_output_dir: str): - """ - Test the `setup_example` function for the hpc_demo example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "hpc_demo" - example_files = [ - "hpc_demo.yaml", - "cumulative_sample_processor.py", - "faker_sample.py", - "sample_collector.py", - "sample_processor.py", - "requirements.txt", - ] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_iterative(temp_output_dir: str): - """ - Test the `setup_example` function for the iterative_demo example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "iterative_demo" - example_files = [ - "iterative_demo.yaml", - "cumulative_sample_processor.py", - "faker_sample.py", - "sample_collector.py", - "sample_processor.py", - "requirements.txt", - ] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_null(temp_output_dir: str): - """ - Test the `setup_example` function for the null_spec example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "null_spec" - example_files = [ - "null_spec.yaml", - "null_chain.yaml", - ".gitignore", - "Makefile", - "requirements.txt", - "scripts/aggregate_chain_output.sh", - "scripts/aggregate_output.sh", - "scripts/check_completion.sh", - "scripts/kill_all.sh", - "scripts/launch_chain_job.py", - "scripts/launch_jobs.py", - "scripts/make_samples.py", - "scripts/read_output_chain.py", - "scripts/read_output.py", - "scripts/search.sh", - "scripts/submit_chain.sbatch", - "scripts/submit.sbatch", - ] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_openfoam(temp_output_dir: str): - """ - Test the `setup_example` function for the openfoam_wf example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "openfoam_wf" - example_files = [ - "openfoam_wf.yaml", - "openfoam_wf_docker_template.yaml", - "README.md", - "requirements.txt", - "scripts/make_samples.py", - "scripts/blockMesh_template.txt", - "scripts/cavity_setup.sh", - "scripts/combine_outputs.py", - "scripts/learn.py", - "scripts/mesh_param_script.py", - "scripts/run_openfoam", - ] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_openfoam_no_docker(temp_output_dir: str): - """ - Test the `setup_example` function for the openfoam_wf_no_docker example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "openfoam_wf_no_docker" - example_files = [ - "openfoam_wf_no_docker.yaml", - "openfoam_wf_no_docker_template.yaml", - "requirements.txt", - "scripts/make_samples.py", - "scripts/blockMesh_template.txt", - "scripts/cavity_setup.sh", - "scripts/combine_outputs.py", - "scripts/learn.py", - "scripts/mesh_param_script.py", - "scripts/run_openfoam", - ] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_openfoam_singularity(temp_output_dir: str): - """ - Test the `setup_example` function for the openfoam_wf_singularity example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "openfoam_wf_singularity" - example_files = [ - "openfoam_wf_singularity.yaml", - "openfoam_wf_singularity_template.yaml", - "requirements.txt", - "scripts/make_samples.py", - "scripts/blockMesh_template.txt", - "scripts/cavity_setup.sh", - "scripts/combine_outputs.py", - "scripts/learn.py", - "scripts/mesh_param_script.py", - "scripts/run_openfoam", - ] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_optimization(temp_output_dir: str): - """ - Test the `setup_example` function for the optimization example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_files = [ - "optimization_basic.yaml", - "requirements.txt", - "template_config.py", - "template_optimization.temp", - "scripts/collector.py", - "scripts/optimizer.py", - "scripts/test_functions.py", - "scripts/visualizer.py", - ] - - run_setup_example(temp_output_dir, "optimization_basic", example_files, "optimization") - - -def test_setup_example_remote_feature_demo(temp_output_dir: str): - """ - Test the `setup_example` function for the remote_feature_demo example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "remote_feature_demo" - example_files = [ - ".gitignore", - "remote_feature_demo.yaml", - "requirements.txt", - "scripts/features.json", - "scripts/hello_world.py", - "scripts/pgen.py", - ] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_restart(temp_output_dir: str): - """ - Test the `setup_example` function for the restart example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "restart" - example_files = ["restart.yaml", "scripts/make_samples.py"] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_restart_delay(temp_output_dir: str): - """ - Test the `setup_example` function for the restart_delay example. - - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run - """ - example_name = "restart_delay" - example_files = ["restart_delay.yaml", "scripts/make_samples.py"] - - run_setup_example(temp_output_dir, example_name, example_files, example_name) - - -def test_setup_example_simple_chain(temp_output_dir: str): +def test_setup_example_simple_chain(examples_testing_dir: str): """ Test the `setup_example` function for the simple_chain example. + This example just writes a single file so we can't run it in the `test_setup_example` test. - :param temp_output_dir: The path to the temporary output directory we'll be using for this test run + :param examples_testing_dir: The path to the the temp output directory for examples tests """ # Create the temp path to store this setup - generator_dir = EXAMPLES_GENERATOR_DIR.format(temp_output_dir=temp_output_dir) - create_dir(generator_dir) - output_file = os.path.join(generator_dir, "simple_chain.yaml") + output_file = os.path.join(examples_testing_dir, "simple_chain.yaml") # Ensure that the example name is returned actual = setup_example("simple_chain", output_file) From 2f24577cdc53f40a270ab40b098f80e85cabb1fb Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 6 Jun 2024 14:24:51 -0700 Subject: [PATCH 40/44] sort example output --- merlin/examples/generator.py | 2 +- tests/unit/test_examples_generator.py | 68 +++++++++++++-------------- 2 files changed, 34 insertions(+), 36 deletions(-) diff --git a/merlin/examples/generator.py b/merlin/examples/generator.py index 285b946d8..63da74d78 100644 --- a/merlin/examples/generator.py +++ b/merlin/examples/generator.py @@ -60,7 +60,7 @@ def gather_example_dirs(): """Get all the example directories""" result = {} - for directory in os.listdir(EXAMPLES_DIR): + for directory in sorted(os.listdir(EXAMPLES_DIR)): result[directory] = directory return result diff --git a/tests/unit/test_examples_generator.py b/tests/unit/test_examples_generator.py index 3f0f2df9d..7548c8a49 100644 --- a/tests/unit/test_examples_generator.py +++ b/tests/unit/test_examples_generator.py @@ -112,44 +112,39 @@ def test_list_examples(): """Test the `list_examples` function to see if it gives us all of the examples that we want.""" expected_headers = ["name", "description"] expected_rows = [ - [ - "openfoam_wf_no_docker", - "A parameter study that includes initializing, running,\n" - "post-processing, collecting, learning and vizualizing OpenFOAM runs\n" - "without using docker.", - ], - [ - "optimization_basic", - "Design Optimization Template\n" - "To use,\n" - "1. Specify the first three variables here (N_DIMS, TEST_FUNCTION, DEBUG)\n" - "2. Run the template_config file in current directory using `python template_config.py`\n" - "3. Merlin run as usual (merlin run optimization.yaml)\n" - "* MAX_ITER and the N_SAMPLES options use default values unless using DEBUG mode\n" - "* BOUNDS_X and UNCERTS_X are configured using the template_config.py scripts", - ], ["feature_demo", "Run 10 hello worlds."], ["flux_local", "Run a scan through Merlin/Maestro"], ["flux_par", "A simple ensemble of parallel MPI jobs run by flux."], ["flux_par_restart", "A simple ensemble of parallel MPI jobs run by flux."], ["paper_flux", "Use flux to run single core MPI jobs and record timings."], - ["lsf_par", "A simple ensemble of parallel MPI jobs run by lsf (jsrun)."], - ["lsf_par_srun", "A simple ensemble of parallel MPI jobs run by lsf using the srun wrapper (srun)."], - ["restart", "A simple ensemble of with restarts."], - ["restart_delay", "A simple ensemble of with restart delay times."], - ["simple_chain", "test to see that chains are not run in parallel"], - ["slurm_par", "A simple ensemble of parallel MPI jobs run by slurm (srun)."], - ["slurm_par_restart", "A simple ensemble of parallel MPI jobs run by slurm (srun)."], - ["remote_feature_demo", "Run 10 hello worlds."], ["hello", "a very simple merlin workflow"], ["hello_samples", "a very simple merlin workflow, with samples"], ["hpc_demo", "Demo running a workflow on HPC machines"], + ["iterative_demo", "Demo of a workflow with self driven iteration/looping"], + ["lsf_par", "A simple ensemble of parallel MPI jobs run by lsf (jsrun)."], + ["lsf_par_srun", "A simple ensemble of parallel MPI jobs run by lsf using the srun wrapper (srun)."], + [ + "null_chain", + "Run N_SAMPLES steps of TIME seconds each at CONC concurrency.\n" + "May be used to measure overhead in merlin.\n" + "Iterates thru a chain of workflows.", + ], + [ + "null_spec", + "run N_SAMPLES null steps at CONC concurrency for TIME seconds each. May be used to measure overhead in merlin.", + ], [ "openfoam_wf", "A parameter study that includes initializing, running,\n" "post-processing, collecting, learning and visualizing OpenFOAM runs\n" "using docker.", ], + [ + "openfoam_wf_no_docker", + "A parameter study that includes initializing, running,\n" + "post-processing, collecting, learning and vizualizing OpenFOAM runs\n" + "without using docker.", + ], [ "openfoam_wf_singularity", "A parameter study that includes initializing, running,\n" @@ -157,21 +152,24 @@ def test_list_examples(): "using singularity.", ], [ - "null_chain", - "Run N_SAMPLES steps of TIME seconds each at CONC concurrency.\n" - "May be used to measure overhead in merlin.\n" - "Iterates thru a chain of workflows.", - ], - [ - "null_spec", - "run N_SAMPLES null steps at CONC concurrency for TIME seconds each. May be used to measure overhead in merlin.", + "optimization_basic", + "Design Optimization Template\n" + "To use,\n" + "1. Specify the first three variables here (N_DIMS, TEST_FUNCTION, DEBUG)\n" + "2. Run the template_config file in current directory using `python template_config.py`\n" + "3. Merlin run as usual (merlin run optimization.yaml)\n" + "* MAX_ITER and the N_SAMPLES options use default values unless using DEBUG mode\n" + "* BOUNDS_X and UNCERTS_X are configured using the template_config.py scripts", ], - ["iterative_demo", "Demo of a workflow with self driven iteration/looping"], + ["remote_feature_demo", "Run 10 hello worlds."], + ["restart", "A simple ensemble of with restarts."], + ["restart_delay", "A simple ensemble of with restart delay times."], + ["simple_chain", "test to see that chains are not run in parallel"], + ["slurm_par", "A simple ensemble of parallel MPI jobs run by slurm (srun)."], + ["slurm_par_restart", "A simple ensemble of parallel MPI jobs run by slurm (srun)."], ] expected = "\n" + tabulate(expected_rows, expected_headers) + "\n" actual = list_examples() - print(f"expected:\n{expected}") - print(f"actual:\n{actual}") assert actual == expected From 5e0a5f78903a371e5dcf93175a3a6ccd296bf016 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 6 Jun 2024 15:27:09 -0700 Subject: [PATCH 41/44] ensure directory is changed back on no outdir test --- tests/unit/test_examples_generator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/test_examples_generator.py b/tests/unit/test_examples_generator.py index 7548c8a49..25432ffca 100644 --- a/tests/unit/test_examples_generator.py +++ b/tests/unit/test_examples_generator.py @@ -170,6 +170,8 @@ def test_list_examples(): ] expected = "\n" + tabulate(expected_rows, expected_headers) + "\n" actual = list_examples() + print(f"actual:\n{actual}") + print(f"expected:\n{expected}") assert actual == expected @@ -221,6 +223,8 @@ def test_setup_example_no_outdir(examples_testing_dir: str): except AssertionError as exc: os.chdir(cwd) raise AssertionError from exc + finally: + os.chdir(cwd) def test_setup_example_outdir_exists(examples_testing_dir: str): From d8fa77c268c25a240900ab020063d6d05803ce3e Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 6 Jun 2024 16:54:24 -0700 Subject: [PATCH 42/44] sort the specs in examples output --- merlin/examples/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/examples/generator.py b/merlin/examples/generator.py index 63da74d78..c45cfb9ce 100644 --- a/merlin/examples/generator.py +++ b/merlin/examples/generator.py @@ -90,7 +90,7 @@ def list_examples(): for example_dir in gather_example_dirs(): directory = os.path.join(os.path.join(EXAMPLES_DIR, example_dir), "") specs = glob.glob(directory + "*.yaml") - for spec in specs: + for spec in sorted(specs): if "template" in spec: continue with open(spec) as f: # pylint: disable=C0103 From 8421d74ff33c094fdee94152ae5a7f4de6539304 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Thu, 6 Jun 2024 16:56:45 -0700 Subject: [PATCH 43/44] fix lint issues --- tests/fixtures/examples.py | 4 +++- tests/unit/test_examples_generator.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/fixtures/examples.py b/tests/fixtures/examples.py index 16a2f576d..7c4626e3e 100644 --- a/tests/fixtures/examples.py +++ b/tests/fixtures/examples.py @@ -3,8 +3,10 @@ """ import os + import pytest + @pytest.fixture(scope="session") def examples_testing_dir(temp_output_dir: str) -> str: """ @@ -17,4 +19,4 @@ def examples_testing_dir(temp_output_dir: str) -> str: if not os.path.exists(testing_dir): os.mkdir(testing_dir) - return testing_dir \ No newline at end of file + return testing_dir diff --git a/tests/unit/test_examples_generator.py b/tests/unit/test_examples_generator.py index 25432ffca..7d4d879fb 100644 --- a/tests/unit/test_examples_generator.py +++ b/tests/unit/test_examples_generator.py @@ -3,9 +3,9 @@ """ import os -import pytest from typing import List +import pytest from tabulate import tabulate from merlin.examples.generator import ( From 0543ae48623128efaba0de00f83ed90a27c73b9b Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Mon, 10 Jun 2024 09:47:35 -0700 Subject: [PATCH 44/44] start writing tests for server config --- merlin/examples/generator.py | 1 - merlin/server/server_config.py | 4 +-- tests/unit/server/test_server_config.py | 43 +++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 tests/unit/server/test_server_config.py diff --git a/merlin/examples/generator.py b/merlin/examples/generator.py index c45cfb9ce..725448bec 100644 --- a/merlin/examples/generator.py +++ b/merlin/examples/generator.py @@ -146,5 +146,4 @@ def setup_example(name, outdir): LOG.info(f"Copying example '{name}' to {outdir}") write_example(src_path, outdir) - print(f"example: {example}") return example diff --git a/merlin/server/server_config.py b/merlin/server/server_config.py index f4d5d5174..b0b91f892 100644 --- a/merlin/server/server_config.py +++ b/merlin/server/server_config.py @@ -92,8 +92,8 @@ def generate_password(length, pass_command: str = None) -> str: :return:: string value with given length """ if pass_command: - process = subprocess.run(pass_command.split(), shell=True, stdout=subprocess.PIPE) - return process.stdout + process = subprocess.run(pass_command, shell=True, capture_output=True, text=True) + return process.stdout.strip() characters = list(string.ascii_letters + string.digits + "!@#$%^&*()") diff --git a/tests/unit/server/test_server_config.py b/tests/unit/server/test_server_config.py new file mode 100644 index 000000000..058e77fcf --- /dev/null +++ b/tests/unit/server/test_server_config.py @@ -0,0 +1,43 @@ +""" +Tests for the `server_config.py` module. +""" + +import string + +from merlin.server.server_config import ( + PASSWORD_LENGTH, + check_process_file_format, + config_merlin_server, + create_server_config, + dump_process_file, + generate_password, + get_server_status, + parse_redis_output, + pull_process_file, + pull_server_config, + pull_server_image, +) + + +def test_generate_password_no_pass_command(): + """ + Test the `generate_password` function with no password command. + This should generate a password of 256 (PASSWORD_LENGTH) random ASCII characters. + """ + generated_password = generate_password(PASSWORD_LENGTH) + assert len(generated_password) == PASSWORD_LENGTH + valid_ascii_chars = string.ascii_letters + string.digits + "!@#$%^&*()" + for ch in generated_password: + assert ch in valid_ascii_chars + + +def test_generate_password_with_pass_command(): + """ + Test the `generate_password` function with no password command. + This should generate a password of 256 (PASSWORD_LENGTH) random ASCII characters. + """ + test_pass = "test-password" + generated_password = generate_password(0, pass_command=f"echo {test_pass}") + assert generated_password == test_pass + +