完成线程启动训练并以websocket的方式进行发送shell

This commit is contained in:
2025-03-10 17:42:56 +08:00
parent b4b1085403
commit 7d736c4ac4
9 changed files with 97 additions and 36 deletions

View File

@ -1,6 +1,7 @@
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.schemas.project_train_schemas import ProjectTrainIn
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc
@ -9,12 +10,14 @@ from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from app.websocket.web_socket_server import room_manager
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml
import subprocess
import asyncio
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@ -119,9 +122,10 @@ def get_img_leafer(image_id: int, session: Session):
return img_leafer_out
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, session: Session):
"""
yolov5执行训练任务
:param train_in: 训练参数
:param project_info: 项目信息
:param session: 数据库session
:return:
@ -168,29 +172,44 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
# 打包完成开始训练,训练前,更改项目的训练状态
pic.update_project_status(project_info.id, '1', session)
# 开始训练
# 开始执行异步训练
data = yaml_file
project = os.file_path(runs_url, project_info.project_no, 'train')
project = os.file_path(runs_url, project_info.project_no)
name = version_path
# thread_train = threading.Thread(target=ps.run_commend, args=(data, project, name, train_in.epochs,
# train_in.patience, train_in.weights_id,
# train_in.project_id, session,))
# thread_train.start();
return data, project, name
def run_commend(data: str, project: str,
name: str, epochs: int,
async def run_commend(data: str, project: str,
name: str, epochs: int, patience: int, weights: str,
project_id: int, session: Session):
"""
执行训练
:param data: 训练数据集
:param project: 训练结果的项目目录
:param name: 实验名称
:param epochs: 训练轮数
:param patience: 早停耐心值
:param weights: 权重文件
:param project_id: 项目id
:param session:
:return:
"""
yolo_path = os.file_path(yolo_url, 'train.py')
yield f"stdout: 模型训练开始,请稍等。。。\n"
room = 'train_' + str(project_id)
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)]
if weights != None and weights != '':
train_info = ptc.get_train(weights, session)
if train_info != None:
commend.append("--weights=" + train_info.best_pt)
# 启动子进程
with subprocess.Popen(
["python", '-u', yolo_path,
"--data=" + data,
"--project=" + project,
"--name=" + name,
"--epochs=" + str(epochs),
"--batch-size=4",
"--exist-ok"],
commend,
bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存
shell=False,
stdout=subprocess.PIPE,
@ -202,13 +221,14 @@ def run_commend(data: str, project: str,
line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n':
yield line
await room_manager.send_to_room(room, line)
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
pic.update_project_status(project_id, '-1', session)
else:
await room_manager.send_to_room(room, 'success')
pic.update_project_status(project_id, '2', session)
# 然后保存版本训练信息
train = ProjectTrain()
@ -218,6 +238,11 @@ def run_commend(data: str, project: str,
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
train.best_pt = bast_pt_path
train.last_pt = last_pt_path
if weights != None and weights != '':
train.weights_id = weights
train.weights_name = train_info.train_version
train.patience = patience
train.epochs = epochs
ptc.add_train(train, session)