优化训练过程

This commit is contained in:
2025-03-20 16:41:51 +08:00
parent 358bb40a2a
commit bba39adcfc
7 changed files with 16 additions and 12 deletions

View File

@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks):
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
# Forward
with torch.cuda.amp.autocast(amp):
with torch.amp.autocast(device_type='cuda', enabled=amp):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1: