训练卡住debug

This commit is contained in:
2022-11-18 15:44:10 +08:00
parent 3dad06f3d0
commit 6a8a491d3a
2 changed files with 11 additions and 6 deletions

View File

@@ -93,6 +93,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
#将数据路径写到yaml文件中
#data_list = file_tool.get_file(proj_no=pro)
# print(data_list)
print("get in train()")
yaml_rewrite(file=opt.data, data_list=data_list)
save_dir, epochs,batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
@@ -209,6 +210,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
ema = ModelEMA(model) if RANK in {-1, 0} else None
# Resume
print("Resume")
best_fitness, start_epoch = 0.0, 0
if pretrained:
if resume:
@@ -226,6 +228,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
LOGGER.info('Using SyncBatchNorm()')
print("Trainloader")
# Trainloader
train_loader, dataset = create_dataloader(train_path,
imgsz,
@@ -283,6 +286,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
model.names = names
print("Start training")
# Start training
t0 = time.time()
nb = len(train_loader) # number of batches