complete the evaluation and metric logging
sherryzyh committed Dec 10, 2022
1 parent 08cb29e commit e827eba
Showing 3 changed files with 244 additions and 285 deletions.
code/train_cl_discriminate.py (30 changes: 13 additions & 17 deletions)
@@ -1,6 +1,6 @@
 import argparse
 from utils.utils import parse_hps, get_exp_name, get_exp_path, load_data, quick_tokenize, contrastive_tokenize, load_loss_function, \
-    evaluation, cl_evaluation, define_logger, save_model, save_metric_log
+    cl_evaluation, define_logger, save_model, save_metric_log
 import random
 import numpy as np
 import torch
@@ -16,20 +16,20 @@
 import logging


-def CL_evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps, exp_name, mode="dev"):
+def evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps, eval_step, exp_name, exp_path, mode="dev"):
     model.eval()
     stop_train = False

     with torch.no_grad():
         print('\n')
         logger.info("[Dev Evaluation] Start Evaluation on Dev Set")
         if hps.loss_func == 'CrossEntropy':
-            dev_accu, dev_exact_accu, dev_loss = cl_evaluation(hps, dev_dataloader, model, loss_function)
+            dev_accu, dev_exact_accu, dev_loss = cl_evaluation(hps, dev_dataloader, model, loss_function, eval_step, exp_path)
             print('\n')
             logger.info("[Dev Metrics] Dev Soft Accuracy: \t{}".format(dev_accu))
             logger.info("[Dev Metrics] Dev Exact Accuracy: \t{}".format(dev_exact_accu))
         else:
-            dev_accu, dev_loss = cl_evaluation(hps, dev_dataloader, model, loss_function)
+            dev_accu, dev_loss = cl_evaluation(hps, dev_dataloader, model, loss_function, eval_step, exp_path)
             print('\n')
             logger.info("[Dev Metrics] Dev Accuracy: \t{}".format(dev_accu))
         logger.info("[Dev Metrics] Dev Loss: \t{}".format(dev_loss))
@@ -53,7 +53,7 @@ def CL_evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps, exp_name, mode="dev"):
     return patient, stop_train, dev_accu, dev_loss


-def CL_train(model, optimizer, train_dataloader, dev_dataloader, loss_function, logger, hps, exp_name):
+def train(model, optimizer, train_dataloader, dev_dataloader, loss_function, logger, hps, exp_name, exp_path):
     logger.info("[INFO] Start Training")
     step = 0
     patient = 0
@@ -87,24 +87,21 @@ def CL_train(model, optimizer, train_dataloader, dev_dataloader, loss_function, logger, hps, exp_name):
             optimizer.step()

             if hps.evaluation_strategy == "step" and step % hps.evaluation_step == 0 and step != 0:
-                patient, stop_train, dev_accu, dev_loss = CL_evaluate(model, dev_dataloader, patient, best_accuracy,
-                                                                      loss_function, logger,
-                                                                      hps, exp_name)
+                patient, stop_train, dev_accu, dev_loss = evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps, step, exp_name, exp_path)
                 metric_log[f'step_{step}']['dev_accu'] = dev_accu
                 metric_log[f'step_{step}']['dev_loss'] = dev_loss
                 if stop_train:
                     return
             step += 1

         if hps.loss_func == 'BCE':
-            train_accu, train_loss = cl_evaluation(hps, train_dataloader, model, loss_function)
+            train_accu, train_loss = cl_evaluation(hps, train_dataloader, model, loss_function, epoch, exp_path, print_pred=False)
             logger.info("[Train Metrics] Train Accuracy: \t{}".format(train_accu))
             logger.info("[Train Metrics] Train Loss: \t{}".format(train_loss))
             metric_log[f'epoch_{epoch}']['train_accu'] = train_accu
             metric_log[f'epoch_{epoch}']['train_loss'] = train_loss

         if hps.evaluation_strategy == "epoch":
-            patient, stop_train, dev_accu, dev_loss = CL_evaluate(model, dev_dataloader, patient, best_accuracy,
-                                                                  loss_function, logger, hps,
-                                                                  exp_name)
+            patient, stop_train, dev_accu, dev_loss = evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps, epoch, exp_name, exp_path)
             metric_log[f'epoch_{epoch}']['dev_accu'] = dev_accu
             metric_log[f'epoch_{epoch}']['dev_loss'] = dev_loss

@@ -118,6 +115,7 @@ def main():
     # parse hyper parameters
     hps = parse_hps()
     exp_name = get_exp_name(hps, "discriminate")
+    exp_path = get_exp_path(hps, exp_name)

     # fix random seed
     if hps.set_seed:
@@ -128,16 +126,14 @@ def main():

     # prepare logger
     logger, formatter = define_logger()
-    log_path = os.path.join(hps.log_dir, exp_name + ".txt")
+    log_path = os.path.join(exp_path, exp_name + ".txt")

     file_handler = logging.FileHandler(log_path)
     file_handler.setFormatter(formatter)
     logger.addHandler(file_handler)

     # logging all the hyper parameters
     logger.info(f"=== hps ===\n{hps}")
-
-    exp_path = get_exp_path(hps, exp_name)
     logger.info(f"[INFO] Experiment Path: {exp_path}")

     # load data
@@ -188,7 +184,7 @@ def main():
     # model = nn.parallel.DistributedDataParallel(model, device_ids=gpu_ids)

     # contrastive training
-    CL_train(model, optimizer, train_dataloader, dev_dataloader, loss_function, logger, hps, exp_name)
+    train(model, optimizer, train_dataloader, dev_dataloader, loss_function, logger, hps, exp_name, exp_path)


 if __name__ == '__main__':
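The main change in this file is threading eval_step and exp_path into cl_evaluation so that evaluation artifacts land in the per-experiment directory. The helper itself lives in utils/utils.py and is not part of this diff; the following is only a minimal sketch of a function matching the call sites above, where the batch layout, the model interface, and the prediction file name are assumptions rather than the repository's actual code:

import os
import torch

def cl_evaluation(hps, dataloader, model, loss_function, eval_step, exp_path, print_pred=True):
    # Illustrative sketch only. It mirrors the call sites in train_cl_discriminate.py:
    # returns (accuracy, loss) as in the BCE branch; the CrossEntropy branch in the
    # diff additionally expects an exact-accuracy value, omitted here for brevity.
    correct, total, loss_sum = 0, 0, 0.0
    predictions = []
    with torch.no_grad():
        for inputs, labels in dataloader:            # assumed batch layout
            logits = model(inputs)                   # assumed model interface
            loss_sum += loss_function(logits, labels).item()
            preds = logits.argmax(dim=-1)
            predictions.extend(preds.tolist())
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    if print_pred:
        # one prediction dump per evaluation step, written under the experiment path
        with open(os.path.join(exp_path, f"pred_{eval_step}.txt"), "w") as f:
            f.write("\n".join(str(p) for p in predictions))
    return correct / max(total, 1), loss_sum / max(total, 1)

This also explains the print_pred=False at the train-metrics call site: per-step prediction dumps are wanted for dev evaluation but would be redundant for the training set.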
code/train_discriminate.py (38 changes: 19 additions & 19 deletions)
@@ -1,6 +1,6 @@
 import argparse
 from utils.utils import parse_hps, get_exp_name, get_exp_path, load_data, quick_tokenize, load_loss_function, \
-    evaluation, define_logger, save_model
+    cr_evaluation, define_logger, save_model
 import random
 import numpy as np
 import torch
@@ -19,25 +19,20 @@
 from utils.kb_dataset import MyDataLoader


-def evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps, epoch, metric_log, exp_name, exp_path):
+def evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps, eval_step, metric_log, exp_name, exp_path):
     model.eval()
     stop_train = False

     with torch.no_grad():
         print('\n')
-        logger.info("[Dev Evaluation] Strain Evaluation on Dev Set")
+        logger.info("[Dev Evaluation] Start Evaluation on Dev Set")
         if hps.loss_func == 'CrossEntropy':
-            dev_accu, dev_exact_accu, dev_loss = evaluation(hps, dev_dataloader, model, loss_function, epoch, exp_path)
-            metric_log[f'epoch_{epoch}']['dev_accuarcy'] = dev_accu
-            metric_log[f'epoch_{epoch}']['dev_exact_accuracy'] = dev_exact_accu
-            metric_log[f'epoch_{epoch}']['dev_loss'] = dev_loss
+            dev_accu, dev_exact_accu, dev_loss = cr_evaluation(hps, dev_dataloader, model, loss_function, eval_step, exp_path)
             print('\n')
             logger.info("[Dev Metrics] Dev Soft Accuracy: \t{}".format(dev_accu))
             logger.info("[Dev Metrics] Dev Exact Accuracy: \t{}".format(dev_exact_accu))
         else:
-            dev_accu, dev_loss = evaluation(hps, dev_dataloader, model, loss_function, epoch, exp_path)
-            metric_log[f'epoch_{epoch}']['dev_accuarcy'] = dev_accu
-            metric_log[f'epoch_{epoch}']['dev_loss'] = dev_loss
+            dev_accu, dev_loss = cr_evaluation(hps, dev_dataloader, model, loss_function, eval_step, exp_path)
             print('\n')
             logger.info("[Dev Metrics] Dev Accuracy: \t{}".format(dev_accu))
             logger.info("[Dev Metrics] Dev Loss: \t{}".format(dev_loss))
@@ -58,7 +53,7 @@ def evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps, epoch, metric_log, exp_name, exp_path):
     if patient >= hps.patient:
         logger.info("[INFO] Stopping Training by Early Stopping")
         stop_train = True
-    return patient, stop_train
+    return patient, stop_train, dev_accu, dev_loss


 def train(model, optimizer, train_dataloader, dev_dataloader, loss_function, logger, hps, exp_name, exp_path):
@@ -99,29 +94,34 @@ def train(model, optimizer, train_dataloader, dev_dataloader, loss_function, logger, hps, exp_name, exp_path):
             optimizer.step()

             if hps.evaluation_strategy == "step" and step % hps.evaluation_step == 0 and step != 0:
-                patient, stop_train = evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger,
-                                               hps, epoch, metric_log, exp_name, exp_path)
+                patient, stop_train, dev_accu, dev_loss = evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger,
+                                                                   hps, step, metric_log, exp_name, exp_path)
+                metric_log[f'step_{step}']['dev_accu'] = dev_accu
+                metric_log[f'step_{step}']['dev_loss'] = dev_loss
                 if stop_train:
                     return
             step += 1

-        train_loss = total_loss / (epoch_step * hps.batch_size) * 100
+        train_accu, train_loss = cr_evaluation(hps, train_dataloader, model, loss_function, epoch, exp_path, print_pred=False)
+        logger.info("[Train Metrics] Train Accuracy: \t{}".format(train_accu))
         logger.info("[Train Metrics] Train Loss: \t{}".format(train_loss))
+        metric_log[f'epoch_{epoch}']['train_accu'] = train_accu
+        metric_log[f'epoch_{epoch}']['train_loss'] = train_loss

         if hps.evaluation_strategy == "epoch":
-            patient, stop_train = evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps,
-                                           epoch, metric_log, exp_name, exp_path)
+            patient, stop_train, dev_accu, dev_loss = evaluate(model, dev_dataloader, patient, best_accuracy, loss_function, logger, hps,
+                                                               epoch, metric_log, exp_name, exp_path)
+            metric_log[f'epoch_{epoch}']['dev_accu'] = dev_accu
+            metric_log[f'epoch_{epoch}']['dev_loss'] = dev_loss
             if stop_train:
                 return


 def main():
     # parse hyper parameters
     hps = parse_hps()
     exp_name = get_exp_name(hps, "discriminate")
+    exp_path = get_exp_path(hps, exp_name)
+    if not os.path.exists(exp_path):
+        os.mkdir(exp_path)

     # fix random seed
     if hps.set_seed:
[Diff for the third changed file not shown]
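Both scripts populate metric_log with nested keys such as metric_log[f'step_{step}']['dev_accu'], and the contrastive script imports save_metric_log from utils.utils. Neither the initialization of metric_log nor the saver appears in this diff; a minimal sketch consistent with those call sites might look like the following, where the defaultdict choice and the JSON file name are assumptions:

import json
import os
from collections import defaultdict

# Writing metric_log[f'step_{step}']['dev_accu'] without initializing the inner
# dict first suggests a defaultdict(dict); the real definition is outside this diff.
metric_log = defaultdict(dict)
metric_log['step_100']['dev_accu'] = 0.71      # same shape as the entries in the diff
metric_log['epoch_0']['train_loss'] = 0.48

def save_metric_log(metric_log, exp_path, exp_name):
    # Hypothetical helper mirroring the imported save_metric_log; a JSON dump
    # under the experiment directory is one plausible implementation.
    with open(os.path.join(exp_path, exp_name + "_metric_log.json"), "w") as f:
        json.dump(metric_log, f, indent=2)

Keeping the log keyed by step or epoch this way means a single file records the full dev/train metric history for an experiment, alongside the prediction dumps written under the same exp_path.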
