完成线程启动训练并以websocket的方式进行发送shell
This commit is contained in:
@ -3,18 +3,20 @@ from app.model.crud import project_label_crud as plc
|
||||
from app.model.crud import project_info_crud as pic
|
||||
from app.model.crud import project_image_crud as pimc
|
||||
from app.model.crud import project_train_crud as ptnc
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
|
||||
from app.model.schemas.project_label_schemas import ProjectLabel
|
||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
|
||||
from app.model.bussiness_model import ProjectLabel as pl
|
||||
from app.model.schemas.project_label_schemas import ProjectLabel
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
|
||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
|
||||
from app.model.schemas.project_train_schemas import ProjectTrainIn
|
||||
from app.common.jwt_check import get_user_id
|
||||
from app.common import reponse_code as rc
|
||||
from app.service import project_train_service as ps
|
||||
from app.db.db_session import get_db
|
||||
|
||||
import threading
|
||||
import asyncio
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -182,7 +184,7 @@ def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db))
|
||||
:return:
|
||||
"""
|
||||
if image.pagerNum is None and image.pagerSize is None:
|
||||
image_list = pimc.get_image_list(image.project_id, session)
|
||||
image_list = pimc.get_image_list(image.project_id, image.img_type, session)
|
||||
result = jsonable_encoder(image_list)
|
||||
return rc.response_success(data=result)
|
||||
else:
|
||||
@ -217,14 +219,15 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
|
||||
return rc.response_success(data=img_leafer_out['leafer'])
|
||||
|
||||
|
||||
@project.get("/run_train/{project_id}")
|
||||
async def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
@project.post("/run_train")
|
||||
async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)):
|
||||
"""
|
||||
执行项目训练方法
|
||||
:param project_id:
|
||||
:param train_in:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
project_id = train_in.project_id
|
||||
project_info = pic.get_project_by_id(project_id, session)
|
||||
if project_info is None:
|
||||
return rc.response_error("项目查询错误")
|
||||
@ -240,10 +243,24 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
return rc.response_error("请先上传验证图片")
|
||||
if val_img_count < 5:
|
||||
return rc.response_error("验证图片少于5张,请继续上传验证图片")
|
||||
data, project_name, name = ps.run_train_yolo(project_info, session)
|
||||
return StreamingResponse(
|
||||
ps.run_commend(data, project_name, name, 50, project_id, session),
|
||||
media_type="text/plain")
|
||||
data, project, name = ps.run_train_yolo(project_info, train_in, session)
|
||||
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
|
||||
project_id, session,))
|
||||
thread_train.start();
|
||||
return rc.response_success(msg="执行成功")
|
||||
|
||||
|
||||
def run_event_loop(data: str, project: str,
|
||||
name: str, train_in: ProjectTrainIn,
|
||||
project_id: int, session: Session):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
|
||||
train_in.patience, train_in.weights_id,
|
||||
project_id, session))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
|
||||
@project.get("/get_train_list/{project_id}")
|
||||
|
@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
|
||||
@ -19,9 +20,11 @@ async def websocket_room(websocket: WebSocket, room: str):
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
||||
room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
||||
except Exception as e:
|
||||
print(f"连接关闭: {e}")
|
||||
if websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await websocket.close(code=1000)
|
||||
finally:
|
||||
await room_manager.remove_from_room(room, websocket)
|
||||
await websocket.close()
|
||||
if websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await websocket.close(code=1001)
|
||||
|
@ -44,17 +44,17 @@ my_app.include_router(web_socket, prefix="/ws", tags=["websocket管理"])
|
||||
|
||||
|
||||
# fastapi定时执行任务
|
||||
async def periodic_task():
|
||||
async def serve_info_task():
|
||||
while True:
|
||||
server_json = get_server_json()
|
||||
await room_manager.send_to_room('server-info', server_json)
|
||||
await asyncio.sleep(2) # 每 10 秒执行
|
||||
await asyncio.sleep(5) # 每 5 秒执行
|
||||
|
||||
|
||||
@my_app.on_event("startup")
|
||||
async def start_periodic_task():
|
||||
# 在后台启动异步任务
|
||||
asyncio.create_task(periodic_task())
|
||||
asyncio.create_task(serve_info_task())
|
||||
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@ from fastapi import status
|
||||
from app.db.page_util import Pager
|
||||
|
||||
|
||||
def response_code_view(code: int,msg: str) -> Response:
|
||||
def response_code_view(code: int, msg: str) -> Response:
|
||||
return JSONResponse(
|
||||
status_code=code,
|
||||
content={
|
||||
|
@ -81,6 +81,10 @@ class ProjectTrain(DbCommon):
|
||||
__tablename__ = "project_train"
|
||||
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
weights_id: Mapped[int] = mapped_column(Integer)
|
||||
weights_name: Mapped[str] = mapped_column(String(32))
|
||||
epochs: Mapped[int] = mapped_column(Integer)
|
||||
patience: Mapped[int] = mapped_column(Integer)
|
||||
best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
last_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
|
@ -3,10 +3,20 @@ from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ProjectTrainIn(BaseModel):
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
weights_id: Optional[int] = Field(None, description="权重文件")
|
||||
epochs: Optional[int] = Field(50, description="训练轮数")
|
||||
patience: Optional[int] = Field(20, description="早停的耐心值")
|
||||
|
||||
|
||||
class ProjectTrainOut(BaseModel):
|
||||
"""项目训练版本信息表"""
|
||||
id: Optional[int] = Field(None, description="训练id")
|
||||
train_version: Optional[str] = Field(None, description="训练版本号")
|
||||
weights_name: Optional[str] = Field(None, description="权重名称")
|
||||
epochs: Optional[int] = Field(None, description="训练轮数")
|
||||
patience: Optional[int] = Field(None, description="早停的耐心值")
|
||||
create_time: Optional[datetime] = Field(None, description="训练时间")
|
||||
|
||||
class Config:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
|
||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
|
||||
from app.model.schemas.project_train_schemas import ProjectTrainIn
|
||||
from app.model.crud import project_info_crud as pic
|
||||
from app.model.crud import project_image_crud as pimc
|
||||
from app.model.crud import project_label_crud as plc
|
||||
@ -9,12 +10,14 @@ from app.model.crud import project_img_leafer_label_crud as pillc
|
||||
from app.util import os_utils as os
|
||||
from app.util import random_utils as ru
|
||||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
import yaml
|
||||
import subprocess
|
||||
import asyncio
|
||||
|
||||
|
||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||
@ -119,9 +122,10 @@ def get_img_leafer(image_id: int, session: Session):
|
||||
return img_leafer_out
|
||||
|
||||
|
||||
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, session: Session):
|
||||
"""
|
||||
yolov5执行训练任务
|
||||
:param train_in: 训练参数
|
||||
:param project_info: 项目信息
|
||||
:param session: 数据库session
|
||||
:return:
|
||||
@ -168,29 +172,44 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
# 打包完成开始训练,训练前,更改项目的训练状态
|
||||
pic.update_project_status(project_info.id, '1', session)
|
||||
|
||||
# 开始训练
|
||||
# 开始执行异步训练
|
||||
data = yaml_file
|
||||
project = os.file_path(runs_url, project_info.project_no, 'train')
|
||||
project = os.file_path(runs_url, project_info.project_no)
|
||||
name = version_path
|
||||
|
||||
# thread_train = threading.Thread(target=ps.run_commend, args=(data, project, name, train_in.epochs,
|
||||
# train_in.patience, train_in.weights_id,
|
||||
# train_in.project_id, session,))
|
||||
# thread_train.start();
|
||||
return data, project, name
|
||||
|
||||
|
||||
def run_commend(data: str, project: str,
|
||||
name: str, epochs: int,
|
||||
async def run_commend(data: str, project: str,
|
||||
name: str, epochs: int, patience: int, weights: str,
|
||||
project_id: int, session: Session):
|
||||
"""
|
||||
执行训练
|
||||
:param data: 训练数据集
|
||||
:param project: 训练结果的项目目录
|
||||
:param name: 实验名称
|
||||
:param epochs: 训练轮数
|
||||
:param patience: 早停耐心值
|
||||
:param weights: 权重文件
|
||||
:param project_id: 项目id
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
|
||||
yield f"stdout: 模型训练开始,请稍等。。。\n"
|
||||
room = 'train_' + str(project_id)
|
||||
await room_manager.send_to_room(room, f"stdout: 模型训练开始,请稍等。。。\n")
|
||||
commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
|
||||
"--epochs=" + str(epochs), "--batch-size=4", "--exist-ok", "--patience=" + str(patience)]
|
||||
if weights != None and weights != '':
|
||||
train_info = ptc.get_train(weights, session)
|
||||
if train_info != None:
|
||||
commend.append("--weights=" + train_info.best_pt)
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
["python", '-u', yolo_path,
|
||||
"--data=" + data,
|
||||
"--project=" + project,
|
||||
"--name=" + name,
|
||||
"--epochs=" + str(epochs),
|
||||
"--batch-size=4",
|
||||
"--exist-ok"],
|
||||
commend,
|
||||
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||
shell=False,
|
||||
stdout=subprocess.PIPE,
|
||||
@ -202,13 +221,14 @@ def run_commend(data: str, project: str,
|
||||
line = process.stdout.readline()
|
||||
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||
if line != '\n':
|
||||
yield line
|
||||
await room_manager.send_to_room(room, line)
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
pic.update_project_status(project_id, '-1', session)
|
||||
else:
|
||||
await room_manager.send_to_room(room, 'success')
|
||||
pic.update_project_status(project_id, '2', session)
|
||||
# 然后保存版本训练信息
|
||||
train = ProjectTrain()
|
||||
@ -218,6 +238,11 @@ def run_commend(data: str, project: str,
|
||||
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
|
||||
train.best_pt = bast_pt_path
|
||||
train.last_pt = last_pt_path
|
||||
if weights != None and weights != '':
|
||||
train.weights_id = weights
|
||||
train.weights_name = train_info.train_version
|
||||
train.patience = patience
|
||||
train.epochs = epochs
|
||||
ptc.add_train(train, session)
|
||||
|
||||
|
||||
|
@ -2,6 +2,7 @@ import psutil
|
||||
import platform
|
||||
import json
|
||||
from datetime import datetime
|
||||
# 获取服务器运行状态
|
||||
|
||||
|
||||
def get_server_info():
|
||||
@ -74,4 +75,4 @@ def get_server_info():
|
||||
|
||||
def get_server_json():
|
||||
server_info = get_server_info()
|
||||
return json.dumps(server_info, indent=2, ensure_ascii=False)
|
||||
return json.dumps(server_info, indent=2, ensure_ascii=False)
|
||||
|
@ -10,6 +10,7 @@ python-multipart==0.0.5
|
||||
redis~=4.1.4
|
||||
SQLAlchemy~=2.0.34
|
||||
uvicorn~=0.17.5
|
||||
uvicorn[standard]
|
||||
loguru~=0.6.0
|
||||
xlrd~=2.0.1
|
||||
bcrypt==3.2.0
|
||||
|
Reference in New Issue
Block a user