From 7f30c8506136d60ece5127f0ecfd09ae9258684e Mon Sep 17 00:00:00 2001 From: cheng Date: Tue, 28 Nov 2023 21:51:53 +0000 Subject: [PATCH] rename variables --- tests/unit_tests/test_pytorch_model.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/unit_tests/test_pytorch_model.py b/tests/unit_tests/test_pytorch_model.py index 0ebfa17..147ec00 100644 --- a/tests/unit_tests/test_pytorch_model.py +++ b/tests/unit_tests/test_pytorch_model.py @@ -102,34 +102,33 @@ def test_save_load(remove_checkpoint): sim1 = Simulator() sim2 = Simulator() - fe = learn_to_pick.PyTorchFeatureEmbedder() first_model_path = f"{CHECKPOINT_DIR}/first.checkpoint" torch.manual_seed(0) - first_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe) - second_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe) + first_policy = learn_to_pick.PyTorchPolicy() + other_policy = learn_to_pick.PyTorchPolicy() torch.manual_seed(0) first_picker = learn_to_pick.PickBest.create( - policy=first_byom, selection_scorer=CustomSelectionScorer() + policy=first_policy, selection_scorer=CustomSelectionScorer() ) sim1.run(first_picker, 5) - first_byom.save(first_model_path) + first_policy.save(first_model_path) - second_byom.load(first_model_path) - second_picker = learn_to_pick.PickBest.create( - policy=second_byom, selection_scorer=CustomSelectionScorer() + other_policy.load(first_model_path) + other_picker = learn_to_pick.PickBest.create( + policy=other_policy, selection_scorer=CustomSelectionScorer() ) - sim1.run(second_picker, 5) + sim1.run(other_picker, 5) torch.manual_seed(0) - all_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe) + all_policy = learn_to_pick.PyTorchPolicy() torch.manual_seed(0) all_picker = learn_to_pick.PickBest.create( - policy=all_byom, selection_scorer=CustomSelectionScorer() + policy=all_policy, selection_scorer=CustomSelectionScorer() ) sim2.run(all_picker, 10) - verify_same_models(second_byom.workspace, all_byom.workspace) - verify_same_optimizers(second_byom.workspace.optim, all_byom.workspace.optim) + verify_same_models(other_policy.workspace, all_policy.workspace) + verify_same_optimizers(other_policy.workspace.optim, all_policy.workspace.optim)