Fix: Manage the torch.amp import for compatibility with all PyTorch versions #13487

Open
wants to merge 4 commits into base: master
18 changes: 13 additions & 5 deletions train.py
@@ -36,6 +36,12 @@
 import torch.nn as nn
 import yaml
 from torch.optim import lr_scheduler
+
+try:
+    import torch.amp as amp
+except ImportError:
+    import torch.cuda.amp as amp
+
 from tqdm import tqdm
 
 FILE = Path(__file__).resolve()
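For context: recent PyTorch releases deprecate torch.cuda.amp in favor of the device-agnostic torch.amp namespace, so the try/except above binds `amp` to whichever module is available. A minimal sketch (not part of this PR) showing how to check at runtime which namespace was resolved:

import torch

try:
    import torch.amp as amp       # newer PyTorch: device-agnostic AMP namespace
except ImportError:
    import torch.cuda.amp as amp  # older PyTorch: CUDA-only AMP namespace

# The resolved module name tells us which API surface downstream calls will see.
print(amp.__name__)  # "torch.amp" on recent releases, "torch.cuda.amp" on older ones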
@@ -221,7 +227,7 @@ def train(hyp, opt, device, callbacks):
         LOGGER.info(f"Transferred {len(csd)}/{len(model.state_dict())} items from {weights}") # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device) # create
-    amp = check_amp(model) # check AMP
+    use_amp = check_amp(model) # check AMP
 
     # Freeze
     freeze = [f"model.{x}." for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
@@ -238,7 +244,7 @@ def train(hyp, opt, device, callbacks):
 
     # Batch size
     if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
-        batch_size = check_train_batch_size(model, imgsz, amp)
+        batch_size = check_train_batch_size(model, imgsz, use_amp)
         loggers.on_params_update({"batch_size": batch_size})
 
     # Optimizer
@@ -352,7 +358,8 @@ def lf(x):
     maps = np.zeros(nc) # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1 # do not move
-    scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    # scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    scaler = amp.GradScaler(enabled=use_amp)
     stopper, stop = EarlyStopping(patience=opt.patience), False
     compute_loss = ComputeLoss(model) # init loss class
     callbacks.run("on_train_start")
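The scaler is now built from the resolved `amp` module instead of the hard-coded torch.cuda.amp path, with `enabled=use_amp` making it a no-op when AMP is off. A self-contained sketch (not part of the PR) of the scaled training step such a scaler drives; the toy model and data are placeholders, and `amp.GradScaler` on the torch.amp branch assumes a PyTorch version where that namespace exposes GradScaler (roughly 2.3+):

import torch
import torch.nn as nn

try:
    import torch.amp as amp       # torch.amp.GradScaler assumed available (approx. PyTorch 2.3+)
except ImportError:
    import torch.cuda.amp as amp  # older PyTorch fallback

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"

model = nn.Linear(8, 1).to(device)                     # toy model for illustration only
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = amp.GradScaler(enabled=use_amp)               # no-op scaler when enabled=False

x, y = torch.randn(4, 8, device=device), torch.randn(4, 1, device=device)
optimizer.zero_grad()
with torch.autocast(device_type=device.type, enabled=use_amp):
    loss = nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then runs optimizer.step()
scaler.update()                # adjust the scale factor for the next iteration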
@@ -409,7 +416,8 @@ def lf(x):
                     imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
 
             # Forward
-            with torch.cuda.amp.autocast(amp):
+            # with torch.cuda.amp.autocast(amp):
+            with amp.autocast(enabled=use_amp, device_type=device.type):
                 pred = model(imgs) # forward
                 loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
                 if RANK != -1:
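One caveat on the autocast call above: torch.amp.autocast takes a device_type argument, while torch.cuda.amp.autocast in older releases does not accept device_type at all, so the fallback branch of the import shim may need its own call form. A sketch of a version-safe wrapper (names such as `_HAS_TORCH_AMP` and `amp_autocast` are illustrative, not from the PR):

import torch

try:
    import torch.amp as amp
    _HAS_TORCH_AMP = True
except ImportError:
    import torch.cuda.amp as amp
    _HAS_TORCH_AMP = False


def amp_autocast(device_type="cuda", enabled=True):
    """Return an autocast context that works with either AMP namespace."""
    if _HAS_TORCH_AMP:
        # torch.amp.autocast expects a device_type ("cuda", "cpu", ...)
        return amp.autocast(device_type=device_type, enabled=enabled)
    # torch.cuda.amp.autocast has no device_type parameter
    return amp.autocast(enabled=enabled)


# Usage (model and imgs defined elsewhere):
# with amp_autocast(device_type=device.type, enabled=use_amp):
#     pred = model(imgs)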
@@ -458,7 +466,7 @@ def lf(x):
                     data_dict,
                     batch_size=batch_size // WORLD_SIZE * 2,
                     imgsz=imgsz,
-                    half=amp,
+                    half=use_amp,
                     model=ema.ema,
                     single_cls=single_cls,
                     dataloader=val_loader,