Skip to content

Commit 21e0844

Browse files
authored
Fix: Add optimizer adamw
1 parent 472a012 commit 21e0844

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

AutoML_NAS.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
import optuna
1111
import copy
1212
import os
13+
import torch.nn as nn
1314
from datetime import datetime
1415
import wandb
1516
import torch
1617
from src.model import Model
1718
from src.utils.torch_utils import model_info, check_runtime
1819
from src.utils.common import read_yaml
19-
from train import train
20+
from src.dataloader import TorchTrainer, create_dataloader
2021
from src.trainer import count_model_params
2122
from typing import Any, Dict, List, Tuple
2223
import argparse
@@ -422,15 +423,32 @@ def objective(trial: optuna.trial.Trial, device) -> Tuple[float, int, float]:
422423
if params_nums >= 500000:
423424
print(f' trial: {trial.number}, This model has too many param:{params_nums}')
424425
raise optuna.structs.TrialPruned()
426+
427+
train_loader, val_loader, test_loader = create_dataloader(data_config)
425428

426-
# train current model
427-
test_loss, test_f1, test_acc = train(
428-
model_config=model_config,
429-
data_config=data_config,
430-
log_dir=log_dir,
431-
fp16=data_config["FP16"],
429+
criterion = nn.CrossEntropyLoss()
430+
optimizer = torch.optim.AdamW(
431+
model.parameters(), lr=data_config["INIT_LR"]
432+
)
433+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
434+
optimizer,
435+
max_lr=0.1,
436+
steps_per_epoch=len(train_loader),
437+
epochs=data_config["EPOCHS"],
438+
pct_start=0.05,
439+
)
440+
441+
trainer = TorchTrainer(
442+
model,
443+
criterion,
444+
optimizer,
445+
scheduler,
432446
device=device,
447+
verbose=1,
448+
model_path=log_dir,
433449
)
450+
trainer.train(train_loader, data_config["EPOCHS"], val_dataloader=val_loader)
451+
loss, test_f1, acc_percent = trainer.test(model, test_dataloader=val_loader)
434452

435453
wandb.log({'f1':test_f1,'params_nums':params_nums, 'mean_time':mean_time})
436454

0 commit comments

Comments
 (0)