|
10 | 10 | import optuna
|
11 | 11 | import copy
|
12 | 12 | import os
|
| 13 | +import torch.nn as nn |
13 | 14 | from datetime import datetime
|
14 | 15 | import wandb
|
15 | 16 | import torch
|
16 | 17 | from src.model import Model
|
17 | 18 | from src.utils.torch_utils import model_info, check_runtime
|
18 | 19 | from src.utils.common import read_yaml
|
19 |
| -from train import train |
| 20 | +from src.dataloader import TorchTrainer, create_dataloader |
20 | 21 | from src.trainer import count_model_params
|
21 | 22 | from typing import Any, Dict, List, Tuple
|
22 | 23 | import argparse
|
@@ -422,15 +423,32 @@ def objective(trial: optuna.trial.Trial, device) -> Tuple[float, int, float]:
|
422 | 423 | if params_nums >= 500000:
|
423 | 424 | print(f' trial: {trial.number}, This model has too many param:{params_nums}')
|
424 | 425 | raise optuna.structs.TrialPruned()
|
| 426 | + |
| 427 | + train_loader, val_loader, test_loader = create_dataloader(data_config) |
425 | 428 |
|
426 |
| - # train current model |
427 |
| - test_loss, test_f1, test_acc = train( |
428 |
| - model_config=model_config, |
429 |
| - data_config=data_config, |
430 |
| - log_dir=log_dir, |
431 |
| - fp16=data_config["FP16"], |
| 429 | + criterion = nn.CrossEntropyLoss() |
| 430 | + optimizer = torch.optim.AdamW( |
| 431 | + model.parameters(), lr=data_config["INIT_LR"] |
| 432 | + ) |
| 433 | + scheduler = torch.optim.lr_scheduler.OneCycleLR( |
| 434 | + optimizer, |
| 435 | + max_lr=0.1, |
| 436 | + steps_per_epoch=len(train_loader), |
| 437 | + epochs=data_config["EPOCHS"], |
| 438 | + pct_start=0.05, |
| 439 | + ) |
| 440 | + |
| 441 | + trainer = TorchTrainer( |
| 442 | + model, |
| 443 | + criterion, |
| 444 | + optimizer, |
| 445 | + scheduler, |
432 | 446 | device=device,
|
| 447 | + verbose=1, |
| 448 | + model_path=log_dir, |
433 | 449 | )
|
| 450 | + trainer.train(train_loader, data_config["EPOCHS"], val_dataloader=val_loader) |
| 451 | + loss, test_f1, acc_percent = trainer.test(model, test_dataloader=val_loader) |
434 | 452 |
|
435 | 453 | wandb.log({'f1':test_f1,'params_nums':params_nums, 'mean_time':mean_time})
|
436 | 454 |
|
|
0 commit comments