From b70ba8a4313359fcfd0afcaa242e3d6b7445005f Mon Sep 17 00:00:00 2001
From: jiakunhao
Date: Mon, 14 Nov 2022 17:44:14 +0800
Subject: [PATCH] =?UTF-8?q?yolo=E4=B8=AD=E6=96=AD=E8=AE=AD=E7=BB=83?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/controller/AlgorithmController.py | 61 +++++++++++++++++++++++++--
 app/yolov5/train_server.py            | 16 +++++++
 2 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/app/controller/AlgorithmController.py b/app/controller/AlgorithmController.py
index 121afbc..ddabca1 100644
--- a/app/controller/AlgorithmController.py
+++ b/app/controller/AlgorithmController.py
@@ -31,6 +31,7 @@ from pathlib import Path
 
 bp = Blueprint('AlgorithmController', __name__)
 
+ifKillDict = {}  # id -> kill flag; written by /change_ifKillDIct, polled each epoch by the yolov5 train loop
 
 def start_train_algorithm():
     """
@@ -147,6 +148,42 @@ def algorithm_process_value_websocket():
 
     return wrapTheFunction
 
+def algorithm_kill_value_websocket():
+    """
+    Decorator factory: send the wrapped function's result to the websocket manager as a 'kill' message.
+    """
+
+    def wrapTheFunction(func):
+        @wraps(func)
+        def wrapped_function(*args, **kwargs):
+            data = func(*args, **kwargs)
+            id = data["id"]
+            data_res = {'code': 1, "type": 'kill', 'msg': 'success', 'data': data}
+            manager.send_message_proj_json(message=data_res, id=id)
+            return data
+
+        return wrapped_function
+
+    return wrapTheFunction
+
+
+def algorithm_error_value_websocket():
+    """
+    Decorator factory: send the wrapped function's result to the websocket manager as an 'error' message.
+    """
+
+    def wrapTheFunction(func):
+        @wraps(func)
+        def wrapped_function(*args, **kwargs):
+            data = func(*args, **kwargs)
+            id = data["id"]
+            data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
+            manager.send_message_proj_json(message=data_res, id=id)
+            return data
+
+        return wrapped_function
+
+    return wrapTheFunction
 
 def obtain_train_param():
     """
@@ -164,7 +201,6 @@ def obtain_train_param():
 
     return wrapTheFunction
 
-
 def obtain_test_param():
     """
     获取验证参数
@@ -215,6 +251,16 @@ def obtain_download_pt_param():
 
     return wrapTheFunction
 
+@bp.route('/change_ifKillDIct', methods=['get'])
+def change_ifKillDIct():
+    """
+    Set the ifKillDict entry for the `id` query parameter to False and return a success response.
+    """
+    id = request.args.get('id')
+    type = request.args.get('type')
+    # NOTE(review): `type` is read but never used and the flag is always stored as False - confirm intent.
+    ifKillDict[id] = False
+    return output_wrapped(0, 'success')
 
 # @start_train_algorithm()
 # def start(param: str):
@@ -241,6 +287,13 @@ from app.schemas.TrainResult import DetectProcessValueDice, DetectReport
 from app import file_tool
 
 
+def error_return(id: str):
+    """
+    Send an 'error' message (code 2, no data) for `id` through the websocket manager.
+    """
+    data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': None}
+    manager.send_message_proj_json(message=data_res, id=id)
+
 # 启动训练
 @start_train_algorithm()
 def train_R0DY(params_str, id):
@@ -255,8 +308,10 @@ def train_R0DY(params_str, id):
     epoches = params.get('epochnum').value
     batch_size = params.get('batch_size').value
     device = params.get('device').value
-
-    train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
+    try:
+        train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
+    except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
+        error_return(id=id)
 
 
 # 启动验证程序
diff --git a/app/yolov5/train_server.py b/app/yolov5/train_server.py
index c122f9d..aafb52f 100644
--- a/app/yolov5/train_server.py
+++ b/app/yolov5/train_server.py
@@ -61,6 +61,8 @@ from app.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
                                smart_resume, torch_distributed_zero_first)
 from app.schemas.TrainResult import Report, ProcessValueList
 from app.controller.AlgorithmController import algorithm_process_value_websocket
+from app.controller.AlgorithmController import ifKillDict
+from app.utils.websocket_tool import manager
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
 WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
@@ -304,6 +306,15 @@ def train(hyp, opt, device, data_list,id,callbacks):  # hyp is path/to/hyp.yaml
                     num_train_img=train_num,
                     train_mod_savepath=best)
 
+    def kill_return():
+        """
+        Send the current report as a 'kill' message through the websocket manager.
+        """
+        id = report.id
+        data = report.dict()
+        data_res = {'code': 1, "type": 'kill', 'msg': 'fail', 'data': data}
+        manager.send_message_proj_json(message=data_res, id=id)
+
     @algorithm_process_value_websocket()
     def report_cellback(i, num_epochs, reportAccu):
         report.rate_of_progess = ((i + 1) / num_epochs) * 100
@@ -314,6 +325,11 @@ def train(hyp, opt, device, data_list,id,callbacks):  # hyp is path/to/hyp.yaml
     ###################结束#######################
     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
         #callbacks.run('on_train_epoch_start')
+        # Look the flag up by this run's id (not the literal 'id'); a missing entry means "keep training".
+        ifkill = ifKillDict.get(id, False)
+        if ifkill:
+            kill_return()
+            break
         model.train()
 
         # Update image weights (optional, single-GPU only)