优化训练过程
This commit is contained in:
@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks):
|
||||
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||
|
||||
# Forward
|
||||
with torch.cuda.amp.autocast(amp):
|
||||
with torch.amp.autocast(device_type='cuda', enabled=amp):
|
||||
pred = model(imgs) # forward
|
||||
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
||||
if RANK != -1:
|
||||
|
@ -513,7 +513,7 @@ def check_font(font=FONT, progress=False):
|
||||
font = Path(font)
|
||||
file = CONFIG_DIR / font.name
|
||||
if not font.exists() and not file.exists():
|
||||
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}"
|
||||
url = f"https://ultralytics.com/assets/{font.name}"
|
||||
LOGGER.info(f"Downloading {url} to {file}...")
|
||||
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
||||
|
||||
|
Reference in New Issue
Block a user