共享传参

This commit is contained in:
2022-11-25 17:07:45 +08:00
parent 9a4b11b054
commit 00dc9b02eb
2 changed files with 8 additions and 8 deletions

View File

@ -89,7 +89,7 @@ def yaml_rewrite(file='data.yaml',data_list=[]):
with open(file, 'w') as f:
yaml.safe_dump(coco_dict, f, sort_keys=False)
def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
def train(hyp, opt, device, data_list,id,getsomething,callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
#将数据路径写到yaml文件中
#data_list = file_tool.get_file(proj_no=pro)
# print(data_list)
@ -574,7 +574,7 @@ def parse_opt(weights,savemodel,epoches,img_size,batch_size,device,known=False):
return parser.parse_known_args()[0] if known else parser.parse_args()
def main(opt,data_list,id,callbacks=Callbacks()):
def main(opt,data_list,id,getsomething,callbacks=Callbacks()):
# Checks
if RANK in {-1, 0}:
print_args(vars(opt))
@ -624,7 +624,7 @@ def main(opt,data_list,id,callbacks=Callbacks()):
# Train
if not opt.evolve:
train(opt.hyp, opt, device, data_list,id,callbacks)
train(opt.hyp, opt, device, data_list,id,getsomething,callbacks)
# Evolve hyperparameters (optional)
else:
@ -705,7 +705,7 @@ def main(opt,data_list,id,callbacks=Callbacks()):
hyp[k] = round(hyp[k], 5) # significant digits
# Train mutation
results = train(hyp.copy(), opt, device, callbacks)
results = train(hyp.copy(), opt, device, getsomething=getsomething,callbacks=callbacks)
callbacks = Callbacks()
# Write mutation results
print_mutation(results, hyp.copy(), save_dir, opt.bucket)
@ -725,9 +725,9 @@ def run(**kwargs):
main(opt)
return opt
def train_start(weights,savemodel,epoches,img_size,batch_size,device,data_list,id):
def train_start(weights,savemodel,epoches,img_size,batch_size,device,data_list,id,getsomething):
opt = parse_opt(weights,savemodel,epoches,img_size,batch_size,device)
main(opt,data_list,id)
main(opt,data_list,id,getsomething)
if __name__ == "__main__":