yolo中断训练功能添加

This commit is contained in:
jiakunhao 2022-11-15 17:59:06 +08:00
parent b70ba8a431
commit cdb081654a
4 changed files with 46 additions and 8 deletions

27
app/configs/global_var.py Normal file
View File

@ -0,0 +1,27 @@
"""
@Time 2022/11/15 10:13
@Auth
@File global_var.py
@IDE PyCharm
@MottoABC(Always Be Coding)
@Desc
"""
def _init():
    """Create the module-level key/value store, or reset it to empty.

    Rebinds the module global ``_global_dict`` to a brand-new empty dict, so
    any values stored before this call are discarded.
    """
    globals()["_global_dict"] = dict()
def set_value(key, value):
    """Store ``value`` under ``key`` in the module-level store.

    The store is created on first use if ``_init()`` has not been called yet,
    so an early call no longer fails with ``NameError``. When the store
    already exists it is updated in place.

    Args:
        key: Hashable key identifying the value.
        value: Object to store; replaces any existing value for ``key``.
    """
    # setdefault returns the existing store, or installs and returns a new one.
    globals().setdefault("_global_dict", {})[key] = value
def get_value(key):
    """Return the value stored under ``key``, or ``None`` if it cannot be read.

    A lookup fails when ``key`` was never stored, or when the store itself
    does not exist yet (``_init()`` has not been called). In both cases a
    failure message is printed and ``None`` is returned instead of raising.

    Args:
        key: Hashable key to look up.

    Returns:
        The stored value, or ``None`` when the lookup fails.
    """
    try:
        return _global_dict[key]
    except (KeyError, NameError):
        # KeyError: key is absent; NameError: the store was never created.
        # An f-string is used so that non-string keys (e.g. integer ids)
        # cannot raise TypeError while the failure is being reported.
        print(f'读取{key}失败\r\n')
        return None

View File

@ -287,11 +287,11 @@ from app.schemas.TrainResult import DetectProcessValueDice, DetectReport
from app import file_tool from app import file_tool
def error_return(id: str): def error_return(id: str, data):
""" """
算法出错返回 算法出错返回
""" """
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': None} data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
manager.send_message_proj_json(message=data_res, id=id) manager.send_message_proj_json(message=data_res, id=id)
# 启动训练 # 启动训练
@ -310,8 +310,10 @@ def train_R0DY(params_str, id):
device = params.get('device').value device = params.get('device').value
try: try:
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id) train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
except: print("train down!")
error_return(id=id) except Exception as e:
print(repr(e))
error_return(id=id,data=repr(e))
# 启动验证程序 # 启动验证程序

View File

@ -32,6 +32,7 @@ class Report(BaseModel):
train_mod_savepath: str = Field(..., description="模型保存路径") train_mod_savepath: str = Field(..., description="模型保存路径")
start_time: datetime.date = Field(datetime.datetime.now(), description="开始时间") start_time: datetime.date = Field(datetime.datetime.now(), description="开始时间")
end_time: datetime.date = Field(datetime.datetime.now(), description="结束时间") end_time: datetime.date = Field(datetime.datetime.now(), description="结束时间")
alg_code: str = Field(..., description="模型编码")
class ReportDict(BaseModel): class ReportDict(BaseModel):

View File

@ -61,7 +61,7 @@ from app.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
smart_resume, torch_distributed_zero_first) smart_resume, torch_distributed_zero_first)
from app.schemas.TrainResult import Report, ProcessValueList from app.schemas.TrainResult import Report, ProcessValueList
from app.controller.AlgorithmController import algorithm_process_value_websocket from app.controller.AlgorithmController import algorithm_process_value_websocket
from app.controller.AlgorithmController import ifKillDict from app.configs import global_var
from app.utils.websocket_tool import manager from app.utils.websocket_tool import manager
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
@ -304,7 +304,8 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
report = Report(rate_of_progess=0, precision=[process_value_list], report = Report(rate_of_progess=0, precision=[process_value_list],
id=id, sum=epochs, progress=0, id=id, sum=epochs, progress=0,
num_train_img=train_num, num_train_img=train_num,
train_mod_savepath=best) train_mod_savepath=best,
alg_code="R-ODY")
def kill_return(): def kill_return():
""" """
@ -325,8 +326,9 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
###################结束####################### ###################结束#######################
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
#callbacks.run('on_train_epoch_start') #callbacks.run('on_train_epoch_start')
global ifKillDict print("start get global_var")
ifkill = ifKillDict['id'] ifkill = global_var.get_value(report.id)
print("get global_var down:",ifkill)
if ifkill: if ifkill:
kill_return() kill_return()
break break
@ -350,6 +352,12 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
optimizer.zero_grad() optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch ------------------------------------------------------------- for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
#callbacks.run('on_train_batch_start') #callbacks.run('on_train_batch_start')
print("start get global_var")
ifkill = global_var.get_value(report.id)
print("get global_var down:",ifkill)
if ifkill:
kill_return()
break
if targets.shape[0] == 0: if targets.shape[0] == 0:
targets = [[0.00000, 5.00000, 0.97002, 0.24679, 0.05995, 0.05553], targets = [[0.00000, 5.00000, 0.97002, 0.24679, 0.05995, 0.05553],
[0.00000, 7.00000, 0.95097, 0.32007, 0.04188, 0.02549], [0.00000, 7.00000, 0.95097, 0.32007, 0.04188, 0.02549],