yolo中断训练功能添加
This commit is contained in:
parent
b70ba8a431
commit
cdb081654a
27
app/configs/global_var.py
Normal file
27
app/configs/global_var.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
"""
|
||||||
|
@Time : 2022/11/15 10:13
|
||||||
|
@Auth : 东
|
||||||
|
@File :global_var.py
|
||||||
|
@IDE :PyCharm
|
||||||
|
@Motto:ABC(Always Be Coding)
|
||||||
|
@Desc:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _init(): # 初始化
|
||||||
|
global _global_dict
|
||||||
|
_global_dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def set_value(key, value):
|
||||||
|
# 定义一个全局变量
|
||||||
|
_global_dict[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def get_value(key):
|
||||||
|
# 获得一个全局变量,不存在则提示读取对应变量失败
|
||||||
|
try:
|
||||||
|
return _global_dict[key]
|
||||||
|
except:
|
||||||
|
print('读取' + key + '失败\r\n')
|
@ -287,11 +287,11 @@ from app.schemas.TrainResult import DetectProcessValueDice, DetectReport
|
|||||||
from app import file_tool
|
from app import file_tool
|
||||||
|
|
||||||
|
|
||||||
def error_return(id: str):
|
def error_return(id: str, data):
|
||||||
"""
|
"""
|
||||||
算法出错,返回
|
算法出错,返回
|
||||||
"""
|
"""
|
||||||
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': None}
|
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
|
||||||
manager.send_message_proj_json(message=data_res, id=id)
|
manager.send_message_proj_json(message=data_res, id=id)
|
||||||
|
|
||||||
# 启动训练
|
# 启动训练
|
||||||
@ -310,8 +310,10 @@ def train_R0DY(params_str, id):
|
|||||||
device = params.get('device').value
|
device = params.get('device').value
|
||||||
try:
|
try:
|
||||||
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
|
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
|
||||||
except:
|
print("train down!")
|
||||||
error_return(id=id)
|
except Exception as e:
|
||||||
|
print(repr(e))
|
||||||
|
error_return(id=id,data=repr(e))
|
||||||
|
|
||||||
|
|
||||||
# 启动验证程序
|
# 启动验证程序
|
||||||
|
@ -32,6 +32,7 @@ class Report(BaseModel):
|
|||||||
train_mod_savepath: str = Field(..., description="模型保存路径")
|
train_mod_savepath: str = Field(..., description="模型保存路径")
|
||||||
start_time: datetime.date = Field(datetime.datetime.now(), description="开始时间")
|
start_time: datetime.date = Field(datetime.datetime.now(), description="开始时间")
|
||||||
end_time: datetime.date = Field(datetime.datetime.now(), description="结束时间")
|
end_time: datetime.date = Field(datetime.datetime.now(), description="结束时间")
|
||||||
|
alg_code: str = Field(..., description="模型编码")
|
||||||
|
|
||||||
|
|
||||||
class ReportDict(BaseModel):
|
class ReportDict(BaseModel):
|
||||||
|
@ -61,7 +61,7 @@ from app.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
|
|||||||
smart_resume, torch_distributed_zero_first)
|
smart_resume, torch_distributed_zero_first)
|
||||||
from app.schemas.TrainResult import Report, ProcessValueList
|
from app.schemas.TrainResult import Report, ProcessValueList
|
||||||
from app.controller.AlgorithmController import algorithm_process_value_websocket
|
from app.controller.AlgorithmController import algorithm_process_value_websocket
|
||||||
from app.controller.AlgorithmController import ifKillDict
|
from app.configs import global_var
|
||||||
from app.utils.websocket_tool import manager
|
from app.utils.websocket_tool import manager
|
||||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
RANK = int(os.getenv('RANK', -1))
|
RANK = int(os.getenv('RANK', -1))
|
||||||
@ -304,7 +304,8 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
|||||||
report = Report(rate_of_progess=0, precision=[process_value_list],
|
report = Report(rate_of_progess=0, precision=[process_value_list],
|
||||||
id=id, sum=epochs, progress=0,
|
id=id, sum=epochs, progress=0,
|
||||||
num_train_img=train_num,
|
num_train_img=train_num,
|
||||||
train_mod_savepath=best)
|
train_mod_savepath=best,
|
||||||
|
alg_code="R-ODY")
|
||||||
|
|
||||||
def kill_return():
|
def kill_return():
|
||||||
"""
|
"""
|
||||||
@ -325,8 +326,9 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
|||||||
###################结束#######################
|
###################结束#######################
|
||||||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||||
#callbacks.run('on_train_epoch_start')
|
#callbacks.run('on_train_epoch_start')
|
||||||
global ifKillDict
|
print("start get global_var")
|
||||||
ifkill = ifKillDict['id']
|
ifkill = global_var.get_value(report.id)
|
||||||
|
print("get global_var down:",ifkill)
|
||||||
if ifkill:
|
if ifkill:
|
||||||
kill_return()
|
kill_return()
|
||||||
break
|
break
|
||||||
@ -350,6 +352,12 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
||||||
#callbacks.run('on_train_batch_start')
|
#callbacks.run('on_train_batch_start')
|
||||||
|
print("start get global_var")
|
||||||
|
ifkill = global_var.get_value(report.id)
|
||||||
|
print("get global_var down:",ifkill)
|
||||||
|
if ifkill:
|
||||||
|
kill_return()
|
||||||
|
break
|
||||||
if targets.shape[0] == 0:
|
if targets.shape[0] == 0:
|
||||||
targets = [[0.00000, 5.00000, 0.97002, 0.24679, 0.05995, 0.05553],
|
targets = [[0.00000, 5.00000, 0.97002, 0.24679, 0.05995, 0.05553],
|
||||||
[0.00000, 7.00000, 0.95097, 0.32007, 0.04188, 0.02549],
|
[0.00000, 7.00000, 0.95097, 0.32007, 0.04188, 0.02549],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user