62 changes: 62 additions & 0 deletions inference.py
@@ -0,0 +1,62 @@
import argparse

import pandas as pd

from tqdm.auto import tqdm

import transformers
import torch
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import seed_everything


### TODO: update once the paths and filenames are decided
# from PATH import CUSTOMDATAMODULE as DataModule
# from PATH import CUSTOMMODEL as Model
# from PATH import CUSTOMTRAINER as Trainer
# from PATH import CUSTOMPREPROCESS as preprocess


# fix random seeds for reproducibility
SEED = 123
seed_everything(SEED, workers=True) # pl seed



def inference(args):
    ### TODO: it would be nice to pass the preprocess function in as an argument too
datamodule = DataModule(args.model_name,
args.batch_size, args.shuffle,
args.train_path, args.dev_path,
args.test_path, args.predict_path)

model = Model(args.model_name,
args.learning_rate)

trainer = Trainer(gpus=1,
                      max_epochs=args.max_epoch,
log_every_n_steps=1)

    # Inference: load the full pickled model saved by train.py
    # (the Model class definition must be importable for torch.load to work)
    model = torch.load(args.best_model_path)
predictions = trainer.predict(model=model,
datamodule=datamodule)

predictions = list(round(float(i), 1) for i in torch.cat(predictions))

output = pd.read_csv('./data/sample_submission.csv')
output['target'] = predictions

    output_name = args.best_model_path[:-3] + '.csv'   # e.g. model.pt -> model.csv
    output.to_csv(output_name, index=False)


if __name__ == '__main__':
    ### TODO: load the best model's config values and use them as the arguments (a rough sketch follows after this diff)
    parser = argparse.ArgumentParser()
    # For now the defaults below simply mirror train.py; the TODO above is to
    # load them from the saved best-model config instead of repeating them here.
    parser.add_argument('--model_name', default='klue/roberta-small', type=str)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--max_epoch', default=1, type=int)
    parser.add_argument('--shuffle', default=True)
    parser.add_argument('--learning_rate', default=1e-5, type=float)
    parser.add_argument('--train_path', default='../data/train.csv')
    parser.add_argument('--dev_path', default='../data/dev.csv')
    parser.add_argument('--test_path', default='../data/dev.csv')
    parser.add_argument('--predict_path', default='../data/test.csv')
    parser.add_argument('--best_model_path', default='model.pt')

    args = parser.parse_args()

inference(args)
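
The TODO in the `__main__` block above (loading the best model's config and using it as the arguments) could be resolved roughly as sketched below. This is only a sketch, assuming the `save_config` util planned in train.py writes the argparse namespace as YAML to ./save/config/<model name>.config; the helper name `load_config_as_args` is made up for illustration.

# Rough sketch only; the config path layout and YAML format are assumptions.
import argparse
import yaml

def load_config_as_args(best_model_path):
    # './save/model/<name>.pt' -> './save/config/<name>.config'
    config_path = best_model_path.replace('/model/', '/config/')[:-3] + '.config'
    with open(config_path) as f:
        config = yaml.safe_load(f)        # dict of the arguments train.py was run with
    config['best_model_path'] = best_model_path
    return argparse.Namespace(**config)   # behaves like the result of parse_args()

With something like this, `inference(load_config_as_args(args.best_model_path))` would replace the hand-copied defaults above.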

8 changes: 8 additions & 0 deletions requirements.txt
@@ -0,0 +1,8 @@
pyYAML==5.4.1
pytorch-lightning==1.7.3
tqdm==4.64.1
torchmetrics==0.10.0
torch==1.12.1
transformers==4.21.2
numpy==1.20.3
pandas==1.5.0
88 changes: 88 additions & 0 deletions train.py
@@ -10,6 +10,24 @@
from trainer import Trainer
from utils import prepare_device

import os
from datetime import datetime

import pytorch_lightning as pl
from pytorch_lightning import seed_everything

### TODO: update once the paths and filenames are decided
# from PATH import CUSTOMDATAMODULE as DataModule
# from PATH import CUSTOMMODEL as Model
# from PATH import CUSTOMTRAINER as Trainer
# from PATH import CUSTOMPREPROCESS as preprocess

### TODO: implement scatterplot in utils
# "For example, the plots above all have a correlation coefficient of 0.82, yet they look completely different.
#  So when analyzing results, always draw a scatter plot to probe the model's weak points."
# from utils import scatterplot

### TODO: implement save_config in utils
# from utils import save_config


# fix random seeds for reproducibility
SEED = 123
@@ -18,6 +36,75 @@
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)

seed_everything(SEED, workers=True) # pl seed



def train(args):
Contributor (Author):
I tweaked our baseline train.py a bit and added it here.
Models are saved to ./save/model/ as .pt files, and the model name is
set to model performance + base model name + save time + experimenter.

    ### TODO: it would be nice to pass the preprocess function in as an argument too
datamodule = DataModule(args.model_name,
args.batch_size, args.shuffle,
args.train_path, args.dev_path,
args.test_path, args.predict_path)

model = Model(args.model_name,
args.learning_rate)

trainer = Trainer(gpus=1,
                      max_epochs=args.max_epoch,
log_every_n_steps=1)

trainer.fit(model=model, datamodule=datamodule)
    test_dict = trainer.test(model=model, datamodule=datamodule)  # evaluated on the dev (validation) set


    # Build the model file name automatically
    # Format: val performance + model name + save time + experimenter, with a .pt extension
saved_model_path = './save/model/'
test_pearson = test_dict[0]['test_pearson']

now = datetime.now()
now_str = now.strftime('%Y-%m-%d_%H:%M:%S')
    saved_model_name = '+'.join([str(test_pearson),
                                 args.model_name.replace('/', '-'),  # 'klue/roberta-small' would otherwise break the path
                                 now_str,
                                 args.experimenter])

    os.makedirs(saved_model_path, exist_ok=True)  # make sure ./save/model/ exists
    torch.save(model, saved_model_path + saved_model_name + '.pt')


    ### TODO: function that draws and saves the scatter plot (sketched after this diff)
# scatterplot(data, './save/plot/' + saved_model_name + '.png')


    ### TODO: function that saves the config (also sketched after this diff)
# save_config(args, './save/config/' + saved_model_name + '.config')


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', default='klue/roberta-small', type=str)
parser.add_argument('--batch_size', default=16, type=int)
parser.add_argument('--max_epoch', default=1, type=int)
parser.add_argument('--shuffle', default=True)
parser.add_argument('--learning_rate', default=1e-5, type=float)
parser.add_argument('--train_path', default='../data/train.csv')
parser.add_argument('--dev_path', default='../data/dev.csv')
parser.add_argument('--test_path', default='../data/dev.csv')
parser.add_argument('--predict_path', default='../data/test.csv')

### TODO : custom argument
parser.add_argument('--preprocess', default="")
parser.add_argument('--experimenter', default='bc-nlp-03')
parser.add_argument('--optimizer', default='AdamW')
parser.add_argument('--loss_func', default='L1Loss')
parser.add_argument('--num_workers', default=4)

    args = parser.parse_args()

train(args)

"""
def main(config):
logger = config.get_logger('train')

@@ -71,3 +158,4 @@ def main(config):
]
config = ConfigParser.from_args(args, options)
main(config)
"""