Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
elttaes authored Feb 26, 2023
1 parent 993fa02 commit 8714cdb
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions Function/metal/esm/1b/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys
sys.path.append("..")
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import model_down
from Bio import SeqIO
Expand Down Expand Up @@ -61,7 +60,7 @@ def collate_fn(self, batch: List[Tuple[Any, ...]]):
label.append(int(filename[0]))
return msa_batch_token,label

a3m_dir='/home/public/bigdata/my/datasets/metal/train'
a3m_dir='datasets/metal/train'
filenames = [
os.path.join(a3m_dir,name) for name in os.listdir(a3m_dir)
if os.path.splitext(name)[-1] == '.a3m'
Expand All @@ -75,11 +74,11 @@ def collate_fn(self, batch: List[Tuple[Any, ...]]):
num_workers=8,
collate_fn=dataset.collate_fn)

test_dir='/home/public/bigdata/my/datasets/metal/test'
test_dir='datasets/metal/test'
testnames = [
os.path.join(test_dir,name) for name in os.listdir(test_dir)
if os.path.splitext(name)[-1] == '.a3m'
] #选择指定目录下的.png图片
]
test_dataset = Dataset(testnames)
test_dataloader = torch.utils.data.DataLoader(test_dataset,
batch_size=batch_size,
Expand Down Expand Up @@ -159,4 +158,4 @@ def collate_fn(self, batch: List[Tuple[Any, ...]]):
# best_loss = val_loss
print(
"\nEpoch: {} / {} finish. Training Loss: {:.8f}. Validating Loss: {:.8f}.\n"
.format(epoch + 1, epochs, train_loss / train_step, val_loss))
.format(epoch + 1, epochs, train_loss / train_step, val_loss))

0 comments on commit 8714cdb

Please sign in to comment.