yolo中断训练功能添加
This commit is contained in:
@ -61,7 +61,7 @@ from app.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
|
||||
smart_resume, torch_distributed_zero_first)
|
||||
from app.schemas.TrainResult import Report, ProcessValueList
|
||||
from app.controller.AlgorithmController import algorithm_process_value_websocket
|
||||
from app.controller.AlgorithmController import ifKillDict
|
||||
from app.configs import global_var
|
||||
from app.utils.websocket_tool import manager
|
||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
@ -304,7 +304,8 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
||||
report = Report(rate_of_progess=0, precision=[process_value_list],
|
||||
id=id, sum=epochs, progress=0,
|
||||
num_train_img=train_num,
|
||||
train_mod_savepath=best)
|
||||
train_mod_savepath=best,
|
||||
alg_code="R-ODY")
|
||||
|
||||
def kill_return():
|
||||
"""
|
||||
@ -325,8 +326,9 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
||||
###################结束#######################
|
||||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||
#callbacks.run('on_train_epoch_start')
|
||||
global ifKillDict
|
||||
ifkill = ifKillDict['id']
|
||||
print("start get global_var")
|
||||
ifkill = global_var.get_value(report.id)
|
||||
print("get global_var down:",ifkill)
|
||||
if ifkill:
|
||||
kill_return()
|
||||
break
|
||||
@ -350,6 +352,12 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
|
||||
optimizer.zero_grad()
|
||||
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
||||
#callbacks.run('on_train_batch_start')
|
||||
print("start get global_var")
|
||||
ifkill = global_var.get_value(report.id)
|
||||
print("get global_var down:",ifkill)
|
||||
if ifkill:
|
||||
kill_return()
|
||||
break
|
||||
if targets.shape[0] == 0:
|
||||
targets = [[0.00000, 5.00000, 0.97002, 0.24679, 0.05995, 0.05553],
|
||||
[0.00000, 7.00000, 0.95097, 0.32007, 0.04188, 0.02549],
|
||||
|
Reference in New Issue
Block a user