Skip to content

Commit

Permalink
Fix preloading options
Browse files Browse the repository at this point in the history
  • Loading branch information
lbugnon committed May 10, 2021
1 parent be66e13 commit cd0f997
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
27 changes: 16 additions & 11 deletions miRe2e/mire2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,27 +51,32 @@ def __init__(self, device="cpu", pretrained="hsa", mfe_model_file=None,

self.preprocessor = Preprocessor(0, device)

if pretrained != "no":

if structure_model_file is None:
if structure_model_file is None:
if pretrained != "no":
state_dict = load_state_dict_from_url(
PRETRAINED[f"{pretrained}-structure"], map_location=device)
else:
state_dict = tr.load(structure_model_file, map_location=device)
self._structure.load_state_dict(state_dict)
else:
state_dict = tr.load(structure_model_file, map_location=device)
self._structure.load_state_dict(state_dict)

if mfe_model_file is None:
if mfe_model_file is None:
if pretrained != "no":
state_dict = load_state_dict_from_url(
PRETRAINED[f"{pretrained}-mfe"], map_location=device)
else:
state_dict = tr.load(mfe_model_file, map_location=device)
self._mfe.load_state_dict(state_dict)
else:
state_dict = tr.load(mfe_model_file, map_location=device)
self._mfe.load_state_dict(state_dict)

if predictor_model_file is None:

if predictor_model_file is None:
if pretrained != "no":
state_dict = load_state_dict_from_url(
PRETRAINED[f"{pretrained}-predictor"], map_location=device)
else:
state_dict = tr.load(predictor_model_file, map_location=device)
self._predictor.load_state_dict(state_dict)
else:
state_dict = tr.load(predictor_model_file, map_location=device)
self._predictor.load_state_dict(state_dict)

def _eval(self):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="miRe2e",
version="0.16",
version="0.17",
author="Jonathan Raad",
author_email="[email protected]",
description="An end-to-end deep neural network based on Transformers for "
Expand Down

0 comments on commit cd0f997

Please sign in to comment.