yolo中断训练
This commit is contained in:
@ -61,6 +61,8 @@ from app.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
|
||||
smart_resume, torch_distributed_zero_first)
|
||||
from app.schemas.TrainResult import Report, ProcessValueList
|
||||
from app.controller.AlgorithmController import algorithm_process_value_websocket
|
||||
from app.controller.AlgorithmController import ifKillDict
|
||||
from app.utils.websocket_tool import manager
|
||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
||||
@ -304,6 +306,15 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
||||
num_train_img=train_num,
|
||||
train_mod_savepath=best)
|
||||
|
||||
def kill_return():
|
||||
"""
|
||||
算法中断,返回
|
||||
"""
|
||||
id = report.id
|
||||
data = report.dict()
|
||||
data_res = {'code': 1, "type": 'kill', 'msg': 'fail', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
|
||||
@algorithm_process_value_websocket()
|
||||
def report_cellback(i, num_epochs, reportAccu):
|
||||
report.rate_of_progess = ((i + 1) / num_epochs) * 100
|
||||
@ -314,6 +325,11 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
||||
###################结束#######################
|
||||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||
#callbacks.run('on_train_epoch_start')
|
||||
global ifKillDict
|
||||
ifkill = ifKillDict['id']
|
||||
if ifkill:
|
||||
kill_return()
|
||||
break
|
||||
model.train()
|
||||
|
||||
# Update image weights (optional, single-GPU only)
|
||||
|
Reference in New Issue
Block a user