Skip to content

Commit 0c71f5a

Browse files
committed
updated torch.load params
1 parent 9615326 commit 0c71f5a

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

model_coverage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def mask_text(self, text_tokenized):
9191
return masked
9292

9393
def reload_model(self, model_file):
94-
print(self.model.load_state_dict(torch.load(model_file), strict=False))
94+
print(self.model.load_state_dict(torch.load(model_file, map_location=torch.device(self.device)), strict=False))
9595

9696
def save_model(self, model_file):
9797
torch.save(self.model.state_dict(), model_file)

model_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, max_output_length=25, max_input_length=300, device='cpu', tok
3636
self.mode = "train"
3737

3838
def reload(self, from_file):
39-
print(self.model.load_state_dict(torch.load(from_file), strict=False))
39+
print(self.model.load_state_dict(torch.load(from_file, map_location=torch.device(self.device)), strict=False))
4040

4141
def save(self, to_file):
4242
torch.save(self.model.state_dict(), to_file)

0 commit comments

Comments
 (0)