Improve instructions and assertions
lbugnon committed Jun 14, 2022
1 parent dd5b04d; commit 1cd41fa
Showing 5 changed files with 14 additions and 3 deletions.
README.md (4 additions, 1 deletion)
@@ -81,19 +81,22 @@ new_model = MiRe2e(mfe_model_file='trained_mfe_predictor.pkl',
```
These model files are optional and you may specify any subset of them; those not specified are loaded from the pre-trained defaults.
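For example, to replace only the MFE predictor and keep the pre-trained defaults for the rest, a minimal sketch (it assumes the package exposes MiRe2e at the top level):

```python
from miRe2e import MiRe2e

# Only the MFE model file is given; the structure and predictor
# models fall back to the pre-trained defaults.
new_model = MiRe2e(mfe_model_file='trained_mfe_predictor.pkl')
```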

Check the code documentation for advanced options.

## Training the models

Training the models may take several hours and requires GPU processing capabilities beyond those provided freely by Google Colab. The following sections give instructions for training each stage of miRe2e. New models are saved as pickle files (*.pkl).

The training scripts were written for a 12 GB GPU. You can adjust batch_size according to your hardware setup.


### Structure prediction model

To train the Structure prediction model, run:
```python
-model.fit_structure('hairpin_examples.fa')
+model.fit_structure('hairpin_examples.fa', batch_size=512)
```
The fasta file should contain hairpin sequences and their secondary structures. For example, the file [hairpin_examples.fa](https://sourceforge.net/projects/sourcesinc/files/mire2e/data/hairpin_examples.zip/download) can be used. The new model is saved in the root directory as "trained_structure_predictor.pkl".
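As an illustration of the expected contents, each record pairs a hairpin sequence with a dot-bracket secondary structure; the exact layout below is an assumption, so refer to the linked file for the authoritative format:

```
>hairpin_example_1
GUGCCAGUAGUUUCGCCUACUGGCAC
((((((((((......))))))))))
```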

miRe2e/mfe.py (3 additions, 0 deletions)
@@ -47,6 +47,9 @@ def fit(self, input_fasta, structure_model, batch_size=512,
        if verbose:
            print("Loading sequences...")
        seq_fasta, _, mfe_fasta = load_seq_struct_mfe(input_fasta)

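        # Guard: the training set must cover at least ten full batches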
        assert len(seq_fasta) >= 10 * batch_size, f"batch_size should be between 1 and 1/10 of the number of sequences; got batch_size={batch_size} for {len(seq_fasta)} sequences"

        if verbose:
            print(f"Done ({len(seq_fasta)} sequences)")
        structure_model.eval()
miRe2e/mire2e.py (1 addition, 0 deletions)
@@ -53,6 +53,7 @@ def __init__(self, device="cpu", pretrained="hsa", mfe_model_file=None,

        if structure_model_file is None:
            if pretrained != "no":
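                # Fail fast if no pretrained weights are registered under this name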
assert f"{pretrained}-structure" in PRETRAINED, f'pretrained model {pretrained} is not recognized'
                state_dict = load_state_dict_from_url(
                    PRETRAINED[f"{pretrained}-structure"], map_location=device)
                self._structure.load_state_dict(state_dict)
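A hypothetical call that would trip this new check (it assumes no weights are registered under the name "mmu"):

```python
model = MiRe2e(pretrained="mmu")
# AssertionError: pretrained model mmu is not recognized
```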
miRe2e/predictor.py (2 additions, 0 deletions)
@@ -139,6 +139,8 @@ def fit(self, structure_model, mfe_model, pos_fname, neg_fname,
        train_seq, train_labels, valid_seq, valid_labels = \
            load_train_valid_data(pos_fname, neg_fname, val_pos_fname,
                                  val_neg_fname, length=length)

        assert len(train_seq) >= 10 * batch_size, f"batch_size should be between 1 and 1/10 of the number of sequences; got batch_size={batch_size} for {len(train_seq)} sequences"

        if verbose:
            print(f"Training sequences {len(train_seq)} ("
miRe2e/structure.py (4 additions, 2 deletions)
@@ -77,6 +77,8 @@ def fit(self, input_fasta, batch_size=512, max_epochs=200,
        if verbose:
            print(f"Done ({len(sequence)} sequences)")

        assert len(sequence) >= 10 * batch_size, f"batch_size should be between 1 and 1/10 of the number of sequences; got batch_size={batch_size} for {len(sequence)} sequences"

        ind = np.arange(len(sequence))
        np.random.shuffle(ind)
        L = int(len(ind) * .8)
@@ -87,11 +89,11 @@
            tr.utils.data.BatchSampler(
                tr.utils.data.RandomSampler(range(len(train_ind)),
                                            replacement=True), batch_size,
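                # drop_last=False keeps the final (possibly smaller) batch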
-               drop_last=True))
+               drop_last=False))

        sampler_test = list(tr.utils.data.BatchSampler(
            tr.utils.data.SequentialSampler(range(len(valid_ind))), batch_size,
-           drop_last=True))
+           drop_last=False))

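        # Plain SGD at lr=1; StepLR decays it by gamma=0.95 at each scheduler step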
        optimizer = optim.SGD(self.parameters(), lr=1)
        scheduler = tr.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
