Skip to content

Commit

Permalink
rename variables
Browse files Browse the repository at this point in the history
  • Loading branch information
cheng-tan committed Nov 28, 2023
1 parent d75f88a commit 7f30c85
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions tests/unit_tests/test_pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,34 +102,33 @@ def test_save_load(remove_checkpoint):
sim1 = Simulator()
sim2 = Simulator()

fe = learn_to_pick.PyTorchFeatureEmbedder()
first_model_path = f"{CHECKPOINT_DIR}/first.checkpoint"

torch.manual_seed(0)
first_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe)
second_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe)
first_policy = learn_to_pick.PyTorchPolicy()
other_policy = learn_to_pick.PyTorchPolicy()

torch.manual_seed(0)

first_picker = learn_to_pick.PickBest.create(
policy=first_byom, selection_scorer=CustomSelectionScorer()
policy=first_policy, selection_scorer=CustomSelectionScorer()
)
sim1.run(first_picker, 5)
first_byom.save(first_model_path)
first_policy.save(first_model_path)

second_byom.load(first_model_path)
second_picker = learn_to_pick.PickBest.create(
policy=second_byom, selection_scorer=CustomSelectionScorer()
other_policy.load(first_model_path)
other_picker = learn_to_pick.PickBest.create(
policy=other_policy, selection_scorer=CustomSelectionScorer()
)
sim1.run(second_picker, 5)
sim1.run(other_picker, 5)

torch.manual_seed(0)
all_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe)
all_policy = learn_to_pick.PyTorchPolicy()
torch.manual_seed(0)
all_picker = learn_to_pick.PickBest.create(
policy=all_byom, selection_scorer=CustomSelectionScorer()
policy=all_policy, selection_scorer=CustomSelectionScorer()
)
sim2.run(all_picker, 10)

verify_same_models(second_byom.workspace, all_byom.workspace)
verify_same_optimizers(second_byom.workspace.optim, all_byom.workspace.optim)
verify_same_models(other_policy.workspace, all_policy.workspace)
verify_same_optimizers(other_policy.workspace.optim, all_policy.workspace.optim)

0 comments on commit 7f30c85

Please sign in to comment.