This commit is contained in:
552068321@qq.com 2022-11-07 20:49:02 +08:00
parent 2af24e2f03
commit e0d7d37b9e

View File

@ -161,8 +161,8 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
check_suffix(weights, '.pt') # check weights
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(LOCAL_RANK):
weights = attempt_download(weights) # download if not found locally
# with torch_distributed_zero_first(LOCAL_RANK):
# weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys