yolo中断训练
This commit is contained in:
@ -31,6 +31,7 @@ from pathlib import Path
|
||||
|
||||
bp = Blueprint('AlgorithmController', __name__)
|
||||
|
||||
ifKillDict = {}
|
||||
|
||||
def start_train_algorithm():
|
||||
"""
|
||||
@ -147,6 +148,42 @@ def algorithm_process_value_websocket():
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def algorithm_kill_value_websocket():
|
||||
"""
|
||||
获取kill值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 1, "type": 'kill', 'msg': 'success', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def algorithm_error_value_websocket():
|
||||
"""
|
||||
获取error值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def obtain_train_param():
|
||||
"""
|
||||
@ -164,7 +201,6 @@ def obtain_train_param():
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def obtain_test_param():
|
||||
"""
|
||||
获取验证参数
|
||||
@ -215,6 +251,16 @@ def obtain_download_pt_param():
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
@bp.route('/change_ifKillDIct', methods=['get'])
|
||||
def change_ifKillDIct():
|
||||
"""
|
||||
修改全局变量
|
||||
"""
|
||||
id = request.args.get('id')
|
||||
type = request.args.get('type')
|
||||
global ifKillDict
|
||||
ifKillDict[id] = False
|
||||
return output_wrapped(0, 'success')
|
||||
|
||||
# @start_train_algorithm()
|
||||
# def start(param: str):
|
||||
@ -241,6 +287,13 @@ from app.schemas.TrainResult import DetectProcessValueDice, DetectReport
|
||||
from app import file_tool
|
||||
|
||||
|
||||
def error_return(id: str):
|
||||
"""
|
||||
算法出错,返回
|
||||
"""
|
||||
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': None}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
|
||||
# 启动训练
|
||||
@start_train_algorithm()
|
||||
def train_R0DY(params_str, id):
|
||||
@ -255,8 +308,10 @@ def train_R0DY(params_str, id):
|
||||
epoches = params.get('epochnum').value
|
||||
batch_size = params.get('batch_size').value
|
||||
device = params.get('device').value
|
||||
|
||||
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
|
||||
try:
|
||||
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
|
||||
except:
|
||||
error_return(id=id)
|
||||
|
||||
|
||||
# 启动验证程序
|
||||
|
Reference in New Issue
Block a user