From 7d736c4ac439c944afc3f6fcc7bc949e7b147fdb Mon Sep 17 00:00:00 2001 From: sunyugang Date: Mon, 10 Mar 2025 17:42:56 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E7=BA=BF=E7=A8=8B=E5=90=AF?= =?UTF-8?q?=E5=8A=A8=E8=AE=AD=E7=BB=83=E5=B9=B6=E4=BB=A5websocket=E7=9A=84?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=E8=BF=9B=E8=A1=8C=E5=8F=91=E9=80=81shell?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/business/project_train_api.py | 41 +++++++++++----- app/api/business/websocket_api.py | 9 ++-- app/application/app.py | 6 +-- app/common/reponse_code.py | 2 +- app/model/bussiness_model.py | 4 ++ app/model/schemas/project_train_schemas.py | 10 ++++ app/service/project_train_service.py | 57 ++++++++++++++++------ app/util/ps_util.py | 3 +- requirements.txt | 1 + 9 files changed, 97 insertions(+), 36 deletions(-) diff --git a/app/api/business/project_train_api.py b/app/api/business/project_train_api.py index 9ebc0b2..d3e1b9f 100644 --- a/app/api/business/project_train_api.py +++ b/app/api/business/project_train_api.py @@ -3,18 +3,20 @@ from app.model.crud import project_label_crud as plc from app.model.crud import project_info_crud as pic from app.model.crud import project_image_crud as pimc from app.model.crud import project_train_crud as ptnc -from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager -from app.model.schemas.project_label_schemas import ProjectLabel -from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager from app.model.bussiness_model import ProjectLabel as pl +from app.model.schemas.project_label_schemas import ProjectLabel +from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager +from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager +from app.model.schemas.project_train_schemas import ProjectTrainIn from app.common.jwt_check import get_user_id from app.common import reponse_code as rc from app.service import project_train_service as ps from app.db.db_session import get_db +import threading +import asyncio from typing import List from fastapi import APIRouter, Depends, Request, UploadFile, File, Form -from fastapi.responses import StreamingResponse from fastapi.encoders import jsonable_encoder from sqlalchemy.orm import Session @@ -182,7 +184,7 @@ def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db)) :return: """ if image.pagerNum is None and image.pagerSize is None: - image_list = pimc.get_image_list(image.project_id, session) + image_list = pimc.get_image_list(image.project_id, image.img_type, session) result = jsonable_encoder(image_list) return rc.response_success(data=result) else: @@ -217,14 +219,15 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)): return rc.response_success(data=img_leafer_out['leafer']) -@project.get("/run_train/{project_id}") -async def run_train(project_id: int, session: Session = Depends(get_db)): +@project.post("/run_train") +async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)): """ 执行项目训练方法 - :param project_id: + :param train_in: :param session: :return: """ + project_id = train_in.project_id project_info = pic.get_project_by_id(project_id, session) if project_info is None: return rc.response_error("项目查询错误") @@ -240,10 +243,24 @@ async def run_train(project_id: int, session: Session = Depends(get_db)): return rc.response_error("请先上传验证图片") if val_img_count < 5: return rc.response_error("验证图片少于5张,请继续上传验证图片") - data, project_name, name = ps.run_train_yolo(project_info, session) - return StreamingResponse( - ps.run_commend(data, project_name, name, 50, project_id, session), - media_type="text/plain") + data, project, name = ps.run_train_yolo(project_info, train_in, session) + thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in, + project_id, session,)) + thread_train.start(); + return rc.response_success(msg="执行成功") + + +def run_event_loop(data: str, project: str, + name: str, train_in: ProjectTrainIn, + project_id: int, session: Session): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + # 运行异步函数 + loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs, + train_in.patience, train_in.weights_id, + project_id, session)) + # 可选: 关闭循环 + loop.close() @project.get("/get_train_list/{project_id}") diff --git a/app/api/business/websocket_api.py b/app/api/business/websocket_api.py index 006ad0d..740f97f 100644 --- a/app/api/business/websocket_api.py +++ b/app/api/business/websocket_api.py @@ -1,4 +1,5 @@ from fastapi import APIRouter, WebSocket +from starlette.websockets import WebSocketState from app.websocket.web_socket_server import room_manager @@ -19,9 +20,11 @@ async def websocket_room(websocket: WebSocket, room: str): try: while True: data = await websocket.receive_text() - await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket) + room_manager.broadcast_to_room(room, data, exclude_websocket=websocket) except Exception as e: - print(f"连接关闭: {e}") + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close(code=1000) finally: await room_manager.remove_from_room(room, websocket) - await websocket.close() + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close(code=1001) diff --git a/app/application/app.py b/app/application/app.py index 2cd6cf9..4f8c1aa 100644 --- a/app/application/app.py +++ b/app/application/app.py @@ -44,17 +44,17 @@ my_app.include_router(web_socket, prefix="/ws", tags=["websocket管理"]) # fastapi定时执行任务 -async def periodic_task(): +async def serve_info_task(): while True: server_json = get_server_json() await room_manager.send_to_room('server-info', server_json) - await asyncio.sleep(2) # 每 10 秒执行 + await asyncio.sleep(5) # 每 5 秒执行 @my_app.on_event("startup") async def start_periodic_task(): # 在后台启动异步任务 - asyncio.create_task(periodic_task()) + asyncio.create_task(serve_info_task()) diff --git a/app/common/reponse_code.py b/app/common/reponse_code.py index ebaa3ee..37adee5 100644 --- a/app/common/reponse_code.py +++ b/app/common/reponse_code.py @@ -3,7 +3,7 @@ from fastapi import status from app.db.page_util import Pager -def response_code_view(code: int,msg: str) -> Response: +def response_code_view(code: int, msg: str) -> Response: return JSONResponse( status_code=code, content={ diff --git a/app/model/bussiness_model.py b/app/model/bussiness_model.py index 512d2fa..771c302 100644 --- a/app/model/bussiness_model.py +++ b/app/model/bussiness_model.py @@ -81,6 +81,10 @@ class ProjectTrain(DbCommon): __tablename__ = "project_train" project_id: Mapped[int] = mapped_column(Integer, nullable=False) train_version: Mapped[str] = mapped_column(String(32), nullable=False) + weights_id: Mapped[int] = mapped_column(Integer) + weights_name: Mapped[str] = mapped_column(String(32)) + epochs: Mapped[int] = mapped_column(Integer) + patience: Mapped[int] = mapped_column(Integer) best_pt: Mapped[str] = mapped_column(String(255), nullable=False) last_pt: Mapped[str] = mapped_column(String(255), nullable=False) diff --git a/app/model/schemas/project_train_schemas.py b/app/model/schemas/project_train_schemas.py index e2a736a..7cf9678 100644 --- a/app/model/schemas/project_train_schemas.py +++ b/app/model/schemas/project_train_schemas.py @@ -3,10 +3,20 @@ from pydantic import BaseModel, Field from typing import Optional +class ProjectTrainIn(BaseModel): + project_id: Optional[int] = Field(..., description="项目id") + weights_id: Optional[int] = Field(None, description="权重文件") + epochs: Optional[int] = Field(50, description="训练轮数") + patience: Optional[int] = Field(20, description="早停的耐心值") + + class ProjectTrainOut(BaseModel): """项目训练版本信息表""" id: Optional[int] = Field(None, description="训练id") train_version: Optional[str] = Field(None, description="训练版本号") + weights_name: Optional[str] = Field(None, description="权重名称") + epochs: Optional[int] = Field(None, description="训练轮数") + patience: Optional[int] = Field(None, description="早停的耐心值") create_time: Optional[datetime] = Field(None, description="训练时间") class Config: diff --git a/app/service/project_train_service.py b/app/service/project_train_service.py index 9acaf02..7530617 100644 --- a/app/service/project_train_service.py +++ b/app/service/project_train_service.py @@ -1,6 +1,7 @@ from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut +from app.model.schemas.project_train_schemas import ProjectTrainIn from app.model.crud import project_info_crud as pic from app.model.crud import project_image_crud as pimc from app.model.crud import project_label_crud as plc @@ -9,12 +10,14 @@ from app.model.crud import project_img_leafer_label_crud as pillc from app.util import os_utils as os from app.util import random_utils as ru from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url +from app.websocket.web_socket_server import room_manager from sqlalchemy.orm import Session from typing import List from fastapi import UploadFile import yaml import subprocess +import asyncio def add_project(info: ProjectInfoIn, session: Session, user_id: int): @@ -119,9 +122,10 @@ def get_img_leafer(image_id: int, session: Session): return img_leafer_out -def run_train_yolo(project_info: ProjectInfoOut, session: Session): +def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, session: Session): """ yolov5执行训练任务 + :param train_in: 训练参数 :param project_info: 项目信息 :param session: 数据库session :return: @@ -168,29 +172,44 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session): # 打包完成开始训练,训练前,更改项目的训练状态 pic.update_project_status(project_info.id, '1', session) - # 开始训练 + # 开始执行异步训练 data = yaml_file - project = os.file_path(runs_url, project_info.project_no, 'train') + project = os.file_path(runs_url, project_info.project_no) name = version_path - + # thread_train = threading.Thread(target=ps.run_commend, args=(data, project, name, train_in.epochs, + # train_in.patience, train_in.weights_id, + # train_in.project_id, session,)) + # thread_train.start(); return data, project, name -def run_commend(data: str, project: str, - name: str, epochs: int, +async def run_commend(data: str, project: str, + name: str, epochs: int, patience: int, weights: str, project_id: int, session: Session): + """ + 执行训练 + :param data: 训练数据集 + :param project: 训练结果的项目目录 + :param name: 实验名称 + :param epochs: 训练轮数 + :param patience: 早停耐心值 + :param weights: 权重文件 + :param project_id: 项目id + :param session: + :return: + """ yolo_path = os.file_path(yolo_url, 'train.py') - - yield f"stdout: 模型训练开始,请稍等。。。\n" + room = 'train_' + str(project_id) + await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n") + commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, + "--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)] + if weights != None and weights != '': + train_info = ptc.get_train(weights, session) + if train_info != None: + commend.append("--weights=" + train_info.best_pt) # 启动子进程 with subprocess.Popen( - ["python", '-u', yolo_path, - "--data=" + data, - "--project=" + project, - "--name=" + name, - "--epochs=" + str(epochs), - "--batch-size=4", - "--exist-ok"], + commend, bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存 shell=False, stdout=subprocess.PIPE, @@ -202,13 +221,14 @@ def run_commend(data: str, project: str, line = process.stdout.readline() process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死 if line != '\n': - yield line + await room_manager.send_to_room(room, line) # 等待进程结束并获取返回码 return_code = process.wait() if return_code != 0: pic.update_project_status(project_id, '-1', session) else: + await room_manager.send_to_room(room, 'success') pic.update_project_status(project_id, '2', session) # 然后保存版本训练信息 train = ProjectTrain() @@ -218,6 +238,11 @@ def run_commend(data: str, project: str, last_pt_path = os.file_path(project, name, 'weights', 'last.pt') train.best_pt = bast_pt_path train.last_pt = last_pt_path + if weights != None and weights != '': + train.weights_id = weights + train.weights_name = train_info.train_version + train.patience = patience + train.epochs = epochs ptc.add_train(train, session) diff --git a/app/util/ps_util.py b/app/util/ps_util.py index 64a2651..be17760 100644 --- a/app/util/ps_util.py +++ b/app/util/ps_util.py @@ -2,6 +2,7 @@ import psutil import platform import json from datetime import datetime +# 获取服务器运行状态 def get_server_info(): @@ -74,4 +75,4 @@ def get_server_info(): def get_server_json(): server_info = get_server_info() - return json.dumps(server_info, indent=2, ensure_ascii=False) \ No newline at end of file + return json.dumps(server_info, indent=2, ensure_ascii=False) diff --git a/requirements.txt b/requirements.txt index dca5744..42452b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ python-multipart==0.0.5 redis~=4.1.4 SQLAlchemy~=2.0.34 uvicorn~=0.17.5 +uvicorn[standard] loguru~=0.6.0 xlrd~=2.0.1 bcrypt==3.2.0