From 00dc9b02eb293ffa1cf5d1bc96cef066d1e2903b Mon Sep 17 00:00:00 2001 From: JIAKUNHAO Date: Fri, 25 Nov 2022 17:07:45 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=B1=E4=BA=AB=E4=BC=A0=E5=8F=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/controller/AlgorithmController.py | 4 ++-- app/yolov5/train_server.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/app/controller/AlgorithmController.py b/app/controller/AlgorithmController.py index ef2a4a6..34ebe9b 100644 --- a/app/controller/AlgorithmController.py +++ b/app/controller/AlgorithmController.py @@ -301,7 +301,7 @@ def error_return(id: str, data): # 启动训练 @start_train_algorithm() -def train_R0DY(params_str, id): +def train_R0DY(params_str, id, getsomething): print('**********************************') print(params_str) print('**********************************') @@ -317,7 +317,7 @@ def train_R0DY(params_str, id): batch_size = params.get('batch_size').value device = params.get('device').value #try: - train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id) + train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id, getsomething) print("train down!") # except Exception as e: # print(repr(e)) diff --git a/app/yolov5/train_server.py b/app/yolov5/train_server.py index 0d50d5f..9358eb2 100644 --- a/app/yolov5/train_server.py +++ b/app/yolov5/train_server.py @@ -89,7 +89,7 @@ def yaml_rewrite(file='data.yaml',data_list=[]): with open(file, 'w') as f: yaml.safe_dump(coco_dict, f, sort_keys=False) -def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml or hyp dictionary +def train(hyp, opt, device, data_list,id,getsomething,callbacks): # hyp is path/to/hyp.yaml or hyp dictionary #将数据路径写到yaml文件中 #data_list = file_tool.get_file(proj_no=pro) # print(data_list) @@ -574,7 +574,7 @@ def parse_opt(weights,savemodel,epoches,img_size,batch_size,device,known=False): return parser.parse_known_args()[0] if known else parser.parse_args() -def main(opt,data_list,id,callbacks=Callbacks()): +def main(opt,data_list,id,getsomething,callbacks=Callbacks()): # Checks if RANK in {-1, 0}: print_args(vars(opt)) @@ -624,7 +624,7 @@ def main(opt,data_list,id,callbacks=Callbacks()): # Train if not opt.evolve: - train(opt.hyp, opt, device, data_list,id,callbacks) + train(opt.hyp, opt, device, data_list,id,getsomething,callbacks) # Evolve hyperparameters (optional) else: @@ -705,7 +705,7 @@ def main(opt,data_list,id,callbacks=Callbacks()): hyp[k] = round(hyp[k], 5) # significant digits # Train mutation - results = train(hyp.copy(), opt, device, callbacks) + results = train(hyp.copy(), opt, device, getsomething=getsomething,callbacks=callbacks) callbacks = Callbacks() # Write mutation results print_mutation(results, hyp.copy(), save_dir, opt.bucket) @@ -725,9 +725,9 @@ def run(**kwargs): main(opt) return opt -def train_start(weights,savemodel,epoches,img_size,batch_size,device,data_list,id): +def train_start(weights,savemodel,epoches,img_size,batch_size,device,data_list,id,getsomething): opt = parse_opt(weights,savemodel,epoches,img_size,batch_size,device) - main(opt,data_list,id) + main(opt,data_list,id,getsomething) if __name__ == "__main__":