diff --git a/tests/unit_tests/test_pytorch_model.py b/tests/unit_tests/test_pytorch_model.py index 0ebfa17..147ec00 100644 --- a/tests/unit_tests/test_pytorch_model.py +++ b/tests/unit_tests/test_pytorch_model.py @@ -102,34 +102,33 @@ def test_save_load(remove_checkpoint): sim1 = Simulator() sim2 = Simulator() - fe = learn_to_pick.PyTorchFeatureEmbedder() first_model_path = f"{CHECKPOINT_DIR}/first.checkpoint" torch.manual_seed(0) - first_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe) - second_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe) + first_policy = learn_to_pick.PyTorchPolicy() + other_policy = learn_to_pick.PyTorchPolicy() torch.manual_seed(0) first_picker = learn_to_pick.PickBest.create( - policy=first_byom, selection_scorer=CustomSelectionScorer() + policy=first_policy, selection_scorer=CustomSelectionScorer() ) sim1.run(first_picker, 5) - first_byom.save(first_model_path) + first_policy.save(first_model_path) - second_byom.load(first_model_path) - second_picker = learn_to_pick.PickBest.create( - policy=second_byom, selection_scorer=CustomSelectionScorer() + other_policy.load(first_model_path) + other_picker = learn_to_pick.PickBest.create( + policy=other_policy, selection_scorer=CustomSelectionScorer() ) - sim1.run(second_picker, 5) + sim1.run(other_picker, 5) torch.manual_seed(0) - all_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe) + all_policy = learn_to_pick.PyTorchPolicy() torch.manual_seed(0) all_picker = learn_to_pick.PickBest.create( - policy=all_byom, selection_scorer=CustomSelectionScorer() + policy=all_policy, selection_scorer=CustomSelectionScorer() ) sim2.run(all_picker, 10) - verify_same_models(second_byom.workspace, all_byom.workspace) - verify_same_optimizers(second_byom.workspace.optim, all_byom.workspace.optim) + verify_same_models(other_policy.workspace, all_policy.workspace) + verify_same_optimizers(other_policy.workspace.optim, all_policy.workspace.optim)