完成线程启动训练并以websocket的方式进行发送shell

This commit is contained in:
2025-03-10 17:42:56 +08:00
parent b4b1085403
commit 7d736c4ac4
9 changed files with 97 additions and 36 deletions

View File

@ -3,18 +3,20 @@ from app.model.crud import project_label_crud as plc
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_train_crud as ptnc
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
from app.model.bussiness_model import ProjectLabel as pl
from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
from app.model.schemas.project_train_schemas import ProjectTrainIn
from app.common.jwt_check import get_user_id
from app.common import reponse_code as rc
from app.service import project_train_service as ps
from app.db.db_session import get_db
import threading
import asyncio
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
@ -182,7 +184,7 @@ def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db))
:return:
"""
if image.pagerNum is None and image.pagerSize is None:
image_list = pimc.get_image_list(image.project_id, session)
image_list = pimc.get_image_list(image.project_id, image.img_type, session)
result = jsonable_encoder(image_list)
return rc.response_success(data=result)
else:
@ -217,14 +219,15 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
return rc.response_success(data=img_leafer_out['leafer'])
@project.get("/run_train/{project_id}")
async def run_train(project_id: int, session: Session = Depends(get_db)):
@project.post("/run_train")
async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)):
"""
执行项目训练方法
:param project_id:
:param train_in:
:param session:
:return:
"""
project_id = train_in.project_id
project_info = pic.get_project_by_id(project_id, session)
if project_info is None:
return rc.response_error("项目查询错误")
@ -240,10 +243,24 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
return rc.response_error("请先上传验证图片")
if val_img_count < 5:
return rc.response_error("验证图片少于5张请继续上传验证图片")
data, project_name, name = ps.run_train_yolo(project_info, session)
return StreamingResponse(
ps.run_commend(data, project_name, name, 50, project_id, session),
media_type="text/plain")
data, project, name = ps.run_train_yolo(project_info, train_in, session)
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
project_id, session,))
thread_train.start();
return rc.response_success(msg="执行成功")
def run_event_loop(data: str, project: str,
name: str, train_in: ProjectTrainIn,
project_id: int, session: Session):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
train_in.patience, train_in.weights_id,
project_id, session))
# 可选: 关闭循环
loop.close()
@project.get("/get_train_list/{project_id}")

View File

@ -1,4 +1,5 @@
from fastapi import APIRouter, WebSocket
from starlette.websockets import WebSocketState
from app.websocket.web_socket_server import room_manager
@ -19,9 +20,11 @@ async def websocket_room(websocket: WebSocket, room: str):
try:
while True:
data = await websocket.receive_text()
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
except Exception as e:
print(f"连接关闭: {e}")
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close(code=1000)
finally:
await room_manager.remove_from_room(room, websocket)
await websocket.close()
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close(code=1001)

View File

@ -44,17 +44,17 @@ my_app.include_router(web_socket, prefix="/ws", tags=["websocket管理"])
# fastapi定时执行任务
async def periodic_task():
async def serve_info_task():
while True:
server_json = get_server_json()
await room_manager.send_to_room('server-info', server_json)
await asyncio.sleep(2) # 每 10 秒执行
await asyncio.sleep(5) # 每 5 秒执行
@my_app.on_event("startup")
async def start_periodic_task():
# 在后台启动异步任务
asyncio.create_task(periodic_task())
asyncio.create_task(serve_info_task())

View File

@ -81,6 +81,10 @@ class ProjectTrain(DbCommon):
__tablename__ = "project_train"
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
weights_id: Mapped[int] = mapped_column(Integer)
weights_name: Mapped[str] = mapped_column(String(32))
epochs: Mapped[int] = mapped_column(Integer)
patience: Mapped[int] = mapped_column(Integer)
best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
last_pt: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -3,10 +3,20 @@ from pydantic import BaseModel, Field
from typing import Optional
class ProjectTrainIn(BaseModel):
project_id: Optional[int] = Field(..., description="项目id")
weights_id: Optional[int] = Field(None, description="权重文件")
epochs: Optional[int] = Field(50, description="训练轮数")
patience: Optional[int] = Field(20, description="早停的耐心值")
class ProjectTrainOut(BaseModel):
"""项目训练版本信息表"""
id: Optional[int] = Field(None, description="训练id")
train_version: Optional[str] = Field(None, description="训练版本号")
weights_name: Optional[str] = Field(None, description="权重名称")
epochs: Optional[int] = Field(None, description="训练轮数")
patience: Optional[int] = Field(None, description="早停的耐心值")
create_time: Optional[datetime] = Field(None, description="训练时间")
class Config:

View File

@ -1,6 +1,7 @@
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.schemas.project_train_schemas import ProjectTrainIn
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc
@ -9,12 +10,14 @@ from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from app.websocket.web_socket_server import room_manager
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml
import subprocess
import asyncio
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@ -119,9 +122,10 @@ def get_img_leafer(image_id: int, session: Session):
return img_leafer_out
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, session: Session):
"""
yolov5执行训练任务
:param train_in: 训练参数
:param project_info: 项目信息
:param session: 数据库session
:return:
@ -168,29 +172,44 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
# 打包完成开始训练,训练前,更改项目的训练状态
pic.update_project_status(project_info.id, '1', session)
# 开始训练
# 开始执行异步训练
data = yaml_file
project = os.file_path(runs_url, project_info.project_no, 'train')
project = os.file_path(runs_url, project_info.project_no)
name = version_path
# thread_train = threading.Thread(target=ps.run_commend, args=(data, project, name, train_in.epochs,
# train_in.patience, train_in.weights_id,
# train_in.project_id, session,))
# thread_train.start();
return data, project, name
def run_commend(data: str, project: str,
name: str, epochs: int,
async def run_commend(data: str, project: str,
name: str, epochs: int, patience: int, weights: str,
project_id: int, session: Session):
"""
执行训练
:param data: 训练数据集
:param project: 训练结果的项目目录
:param name: 实验名称
:param epochs: 训练轮数
:param patience: 早停耐心值
:param weights: 权重文件
:param project_id: 项目id
:param session:
:return:
"""
yolo_path = os.file_path(yolo_url, 'train.py')
yield f"stdout: 模型训练开始,请稍等。。。\n"
room = 'train_' + str(project_id)
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)]
if weights != None and weights != '':
train_info = ptc.get_train(weights, session)
if train_info != None:
commend.append("--weights=" + train_info.best_pt)
# 启动子进程
with subprocess.Popen(
["python", '-u', yolo_path,
"--data=" + data,
"--project=" + project,
"--name=" + name,
"--epochs=" + str(epochs),
"--batch-size=4",
"--exist-ok"],
commend,
bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存
shell=False,
stdout=subprocess.PIPE,
@ -202,13 +221,14 @@ def run_commend(data: str, project: str,
line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n':
yield line
await room_manager.send_to_room(room, line)
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
pic.update_project_status(project_id, '-1', session)
else:
await room_manager.send_to_room(room, 'success')
pic.update_project_status(project_id, '2', session)
# 然后保存版本训练信息
train = ProjectTrain()
@ -218,6 +238,11 @@ def run_commend(data: str, project: str,
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
train.best_pt = bast_pt_path
train.last_pt = last_pt_path
if weights != None and weights != '':
train.weights_id = weights
train.weights_name = train_info.train_version
train.patience = patience
train.epochs = epochs
ptc.add_train(train, session)

View File

@ -2,6 +2,7 @@ import psutil
import platform
import json
from datetime import datetime
# 获取服务器运行状态
def get_server_info():

View File

@ -10,6 +10,7 @@ python-multipart==0.0.5
redis~=4.1.4
SQLAlchemy~=2.0.34
uvicorn~=0.17.5
uvicorn[standard]
loguru~=0.6.0
xlrd~=2.0.1
bcrypt==3.2.0