完成线程启动训练并以websocket的方式进行发送shell

This commit is contained in:
2025-03-10 17:42:56 +08:00
parent b4b1085403
commit 7d736c4ac4
9 changed files with 97 additions and 36 deletions

View File

@ -3,18 +3,20 @@ from app.model.crud import project_label_crud as plc
from app.model.crud import project_info_crud as pic from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc from app.model.crud import project_image_crud as pimc
from app.model.crud import project_train_crud as ptnc from app.model.crud import project_train_crud as ptnc
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
from app.model.bussiness_model import ProjectLabel as pl from app.model.bussiness_model import ProjectLabel as pl
from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
from app.model.schemas.project_train_schemas import ProjectTrainIn
from app.common.jwt_check import get_user_id from app.common.jwt_check import get_user_id
from app.common import reponse_code as rc from app.common import reponse_code as rc
from app.service import project_train_service as ps from app.service import project_train_service as ps
from app.db.db_session import get_db from app.db.db_session import get_db
import threading
import asyncio
from typing import List from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -182,7 +184,7 @@ def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db))
:return: :return:
""" """
if image.pagerNum is None and image.pagerSize is None: if image.pagerNum is None and image.pagerSize is None:
image_list = pimc.get_image_list(image.project_id, session) image_list = pimc.get_image_list(image.project_id, image.img_type, session)
result = jsonable_encoder(image_list) result = jsonable_encoder(image_list)
return rc.response_success(data=result) return rc.response_success(data=result)
else: else:
@ -217,14 +219,15 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
return rc.response_success(data=img_leafer_out['leafer']) return rc.response_success(data=img_leafer_out['leafer'])
@project.get("/run_train/{project_id}") @project.post("/run_train")
async def run_train(project_id: int, session: Session = Depends(get_db)): async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)):
""" """
执行项目训练方法 执行项目训练方法
:param project_id: :param train_in:
:param session: :param session:
:return: :return:
""" """
project_id = train_in.project_id
project_info = pic.get_project_by_id(project_id, session) project_info = pic.get_project_by_id(project_id, session)
if project_info is None: if project_info is None:
return rc.response_error("项目查询错误") return rc.response_error("项目查询错误")
@ -240,10 +243,24 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
return rc.response_error("请先上传验证图片") return rc.response_error("请先上传验证图片")
if val_img_count < 5: if val_img_count < 5:
return rc.response_error("验证图片少于5张请继续上传验证图片") return rc.response_error("验证图片少于5张请继续上传验证图片")
data, project_name, name = ps.run_train_yolo(project_info, session) data, project, name = ps.run_train_yolo(project_info, train_in, session)
return StreamingResponse( thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
ps.run_commend(data, project_name, name, 50, project_id, session), project_id, session,))
media_type="text/plain") thread_train.start();
return rc.response_success(msg="执行成功")
def run_event_loop(data: str, project: str,
name: str, train_in: ProjectTrainIn,
project_id: int, session: Session):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
train_in.patience, train_in.weights_id,
project_id, session))
# 可选: 关闭循环
loop.close()
@project.get("/get_train_list/{project_id}") @project.get("/get_train_list/{project_id}")

View File

@ -1,4 +1,5 @@
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
from starlette.websockets import WebSocketState
from app.websocket.web_socket_server import room_manager from app.websocket.web_socket_server import room_manager
@ -19,9 +20,11 @@ async def websocket_room(websocket: WebSocket, room: str):
try: try:
while True: while True:
data = await websocket.receive_text() data = await websocket.receive_text()
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket) room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
except Exception as e: except Exception as e:
print(f"连接关闭: {e}") if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close(code=1000)
finally: finally:
await room_manager.remove_from_room(room, websocket) await room_manager.remove_from_room(room, websocket)
await websocket.close() if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close(code=1001)

View File

@ -44,17 +44,17 @@ my_app.include_router(web_socket, prefix="/ws", tags=["websocket管理"])
# fastapi定时执行任务 # fastapi定时执行任务
async def periodic_task(): async def serve_info_task():
while True: while True:
server_json = get_server_json() server_json = get_server_json()
await room_manager.send_to_room('server-info', server_json) await room_manager.send_to_room('server-info', server_json)
await asyncio.sleep(2) # 每 10 秒执行 await asyncio.sleep(5) # 每 5 秒执行
@my_app.on_event("startup") @my_app.on_event("startup")
async def start_periodic_task(): async def start_periodic_task():
# 在后台启动异步任务 # 在后台启动异步任务
asyncio.create_task(periodic_task()) asyncio.create_task(serve_info_task())

View File

@ -3,7 +3,7 @@ from fastapi import status
from app.db.page_util import Pager from app.db.page_util import Pager
def response_code_view(code: int,msg: str) -> Response: def response_code_view(code: int, msg: str) -> Response:
return JSONResponse( return JSONResponse(
status_code=code, status_code=code,
content={ content={

View File

@ -81,6 +81,10 @@ class ProjectTrain(DbCommon):
__tablename__ = "project_train" __tablename__ = "project_train"
project_id: Mapped[int] = mapped_column(Integer, nullable=False) project_id: Mapped[int] = mapped_column(Integer, nullable=False)
train_version: Mapped[str] = mapped_column(String(32), nullable=False) train_version: Mapped[str] = mapped_column(String(32), nullable=False)
weights_id: Mapped[int] = mapped_column(Integer)
weights_name: Mapped[str] = mapped_column(String(32))
epochs: Mapped[int] = mapped_column(Integer)
patience: Mapped[int] = mapped_column(Integer)
best_pt: Mapped[str] = mapped_column(String(255), nullable=False) best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
last_pt: Mapped[str] = mapped_column(String(255), nullable=False) last_pt: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -3,10 +3,20 @@ from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
class ProjectTrainIn(BaseModel):
project_id: Optional[int] = Field(..., description="项目id")
weights_id: Optional[int] = Field(None, description="权重文件")
epochs: Optional[int] = Field(50, description="训练轮数")
patience: Optional[int] = Field(20, description="早停的耐心值")
class ProjectTrainOut(BaseModel): class ProjectTrainOut(BaseModel):
"""项目训练版本信息表""" """项目训练版本信息表"""
id: Optional[int] = Field(None, description="训练id") id: Optional[int] = Field(None, description="训练id")
train_version: Optional[str] = Field(None, description="训练版本号") train_version: Optional[str] = Field(None, description="训练版本号")
weights_name: Optional[str] = Field(None, description="权重名称")
epochs: Optional[int] = Field(None, description="训练轮数")
patience: Optional[int] = Field(None, description="早停的耐心值")
create_time: Optional[datetime] = Field(None, description="训练时间") create_time: Optional[datetime] = Field(None, description="训练时间")
class Config: class Config:

View File

@ -1,6 +1,7 @@
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.schemas.project_train_schemas import ProjectTrainIn
from app.model.crud import project_info_crud as pic from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc from app.model.crud import project_label_crud as plc
@ -9,12 +10,14 @@ from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os from app.util import os_utils as os
from app.util import random_utils as ru from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from app.websocket.web_socket_server import room_manager
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from fastapi import UploadFile from fastapi import UploadFile
import yaml import yaml
import subprocess import subprocess
import asyncio
def add_project(info: ProjectInfoIn, session: Session, user_id: int): def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@ -119,9 +122,10 @@ def get_img_leafer(image_id: int, session: Session):
return img_leafer_out return img_leafer_out
def run_train_yolo(project_info: ProjectInfoOut, session: Session): def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, session: Session):
""" """
yolov5执行训练任务 yolov5执行训练任务
:param train_in: 训练参数
:param project_info: 项目信息 :param project_info: 项目信息
:param session: 数据库session :param session: 数据库session
:return: :return:
@ -168,29 +172,44 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
# 打包完成开始训练,训练前,更改项目的训练状态 # 打包完成开始训练,训练前,更改项目的训练状态
pic.update_project_status(project_info.id, '1', session) pic.update_project_status(project_info.id, '1', session)
# 开始训练 # 开始执行异步训练
data = yaml_file data = yaml_file
project = os.file_path(runs_url, project_info.project_no, 'train') project = os.file_path(runs_url, project_info.project_no)
name = version_path name = version_path
# thread_train = threading.Thread(target=ps.run_commend, args=(data, project, name, train_in.epochs,
# train_in.patience, train_in.weights_id,
# train_in.project_id, session,))
# thread_train.start();
return data, project, name return data, project, name
def run_commend(data: str, project: str, async def run_commend(data: str, project: str,
name: str, epochs: int, name: str, epochs: int, patience: int, weights: str,
project_id: int, session: Session): project_id: int, session: Session):
"""
执行训练
:param data: 训练数据集
:param project: 训练结果的项目目录
:param name: 实验名称
:param epochs: 训练轮数
:param patience: 早停耐心值
:param weights: 权重文件
:param project_id: 项目id
:param session:
:return:
"""
yolo_path = os.file_path(yolo_url, 'train.py') yolo_path = os.file_path(yolo_url, 'train.py')
room = 'train_' + str(project_id)
yield f"stdout: 模型训练开始,请稍等。。。\n" await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)]
if weights != None and weights != '':
train_info = ptc.get_train(weights, session)
if train_info != None:
commend.append("--weights=" + train_info.best_pt)
# 启动子进程 # 启动子进程
with subprocess.Popen( with subprocess.Popen(
["python", '-u', yolo_path, commend,
"--data=" + data,
"--project=" + project,
"--name=" + name,
"--epochs=" + str(epochs),
"--batch-size=4",
"--exist-ok"],
bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存 bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存
shell=False, shell=False,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
@ -202,13 +221,14 @@ def run_commend(data: str, project: str,
line = process.stdout.readline() line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死 process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n': if line != '\n':
yield line await room_manager.send_to_room(room, line)
# 等待进程结束并获取返回码 # 等待进程结束并获取返回码
return_code = process.wait() return_code = process.wait()
if return_code != 0: if return_code != 0:
pic.update_project_status(project_id, '-1', session) pic.update_project_status(project_id, '-1', session)
else: else:
await room_manager.send_to_room(room, 'success')
pic.update_project_status(project_id, '2', session) pic.update_project_status(project_id, '2', session)
# 然后保存版本训练信息 # 然后保存版本训练信息
train = ProjectTrain() train = ProjectTrain()
@ -218,6 +238,11 @@ def run_commend(data: str, project: str,
last_pt_path = os.file_path(project, name, 'weights', 'last.pt') last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
train.best_pt = bast_pt_path train.best_pt = bast_pt_path
train.last_pt = last_pt_path train.last_pt = last_pt_path
if weights != None and weights != '':
train.weights_id = weights
train.weights_name = train_info.train_version
train.patience = patience
train.epochs = epochs
ptc.add_train(train, session) ptc.add_train(train, session)

View File

@ -2,6 +2,7 @@ import psutil
import platform import platform
import json import json
from datetime import datetime from datetime import datetime
# 获取服务器运行状态
def get_server_info(): def get_server_info():

View File

@ -10,6 +10,7 @@ python-multipart==0.0.5
redis~=4.1.4 redis~=4.1.4
SQLAlchemy~=2.0.34 SQLAlchemy~=2.0.34
uvicorn~=0.17.5 uvicorn~=0.17.5
uvicorn[standard]
loguru~=0.6.0 loguru~=0.6.0
xlrd~=2.0.1 xlrd~=2.0.1
bcrypt==3.2.0 bcrypt==3.2.0