完成线程启动训练并以websocket的方式进行发送shell
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
|
||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
|
||||
from app.model.schemas.project_train_schemas import ProjectTrainIn
|
||||
from app.model.crud import project_info_crud as pic
|
||||
from app.model.crud import project_image_crud as pimc
|
||||
from app.model.crud import project_label_crud as plc
|
||||
@ -9,12 +10,14 @@ from app.model.crud import project_img_leafer_label_crud as pillc
|
||||
from app.util import os_utils as os
|
||||
from app.util import random_utils as ru
|
||||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
import yaml
|
||||
import subprocess
|
||||
import asyncio
|
||||
|
||||
|
||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||
@ -119,9 +122,10 @@ def get_img_leafer(image_id: int, session: Session):
|
||||
return img_leafer_out
|
||||
|
||||
|
||||
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, session: Session):
|
||||
"""
|
||||
yolov5执行训练任务
|
||||
:param train_in: 训练参数
|
||||
:param project_info: 项目信息
|
||||
:param session: 数据库session
|
||||
:return:
|
||||
@ -168,29 +172,44 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
# 打包完成开始训练,训练前,更改项目的训练状态
|
||||
pic.update_project_status(project_info.id, '1', session)
|
||||
|
||||
# 开始训练
|
||||
# 开始执行异步训练
|
||||
data = yaml_file
|
||||
project = os.file_path(runs_url, project_info.project_no, 'train')
|
||||
project = os.file_path(runs_url, project_info.project_no)
|
||||
name = version_path
|
||||
|
||||
# thread_train = threading.Thread(target=ps.run_commend, args=(data, project, name, train_in.epochs,
|
||||
# train_in.patience, train_in.weights_id,
|
||||
# train_in.project_id, session,))
|
||||
# thread_train.start();
|
||||
return data, project, name
|
||||
|
||||
|
||||
def run_commend(data: str, project: str,
|
||||
name: str, epochs: int,
|
||||
async def run_commend(data: str, project: str,
|
||||
name: str, epochs: int, patience: int, weights: str,
|
||||
project_id: int, session: Session):
|
||||
"""
|
||||
执行训练
|
||||
:param data: 训练数据集
|
||||
:param project: 训练结果的项目目录
|
||||
:param name: 实验名称
|
||||
:param epochs: 训练轮数
|
||||
:param patience: 早停耐心值
|
||||
:param weights: 权重文件
|
||||
:param project_id: 项目id
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
|
||||
yield f"stdout: 模型训练开始,请稍等。。。\n"
|
||||
room = 'train_' + str(project_id)
|
||||
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
|
||||
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)]
|
||||
if weights != None and weights != '':
|
||||
train_info = ptc.get_train(weights, session)
|
||||
if train_info != None:
|
||||
commend.append("--weights=" + train_info.best_pt)
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
["python", '-u', yolo_path,
|
||||
"--data=" + data,
|
||||
"--project=" + project,
|
||||
"--name=" + name,
|
||||
"--epochs=" + str(epochs),
|
||||
"--batch-size=4",
|
||||
"--exist-ok"],
|
||||
commend,
|
||||
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||
shell=False,
|
||||
stdout=subprocess.PIPE,
|
||||
@ -202,13 +221,14 @@ def run_commend(data: str, project: str,
|
||||
line = process.stdout.readline()
|
||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||
if line != '\n':
|
||||
yield line
|
||||
await room_manager.send_to_room(room, line)
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
pic.update_project_status(project_id, '-1', session)
|
||||
else:
|
||||
await room_manager.send_to_room(room, 'success')
|
||||
pic.update_project_status(project_id, '2', session)
|
||||
# 然后保存版本训练信息
|
||||
train = ProjectTrain()
|
||||
@ -218,6 +238,11 @@ def run_commend(data: str, project: str,
|
||||
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
|
||||
train.best_pt = bast_pt_path
|
||||
train.last_pt = last_pt_path
|
||||
if weights != None and weights != '':
|
||||
train.weights_id = weights
|
||||
train.weights_name = train_info.train_version
|
||||
train.patience = patience
|
||||
train.epochs = epochs
|
||||
ptc.add_train(train, session)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user