完成线程启动训练并以websocket的方式进行发送shell

This commit is contained in:
2025-03-10 17:42:56 +08:00
parent b4b1085403
commit 7d736c4ac4
9 changed files with 97 additions and 36 deletions

View File

@ -3,18 +3,20 @@ from app.model.crud import project_label_crud as plc
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_train_crud as ptnc
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
from app.model.bussiness_model import ProjectLabel as pl
from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
from app.model.schemas.project_train_schemas import ProjectTrainIn
from app.common.jwt_check import get_user_id
from app.common import reponse_code as rc
from app.service import project_train_service as ps
from app.db.db_session import get_db
import threading
import asyncio
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
@ -182,7 +184,7 @@ def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db))
:return:
"""
if image.pagerNum is None and image.pagerSize is None:
image_list = pimc.get_image_list(image.project_id, session)
image_list = pimc.get_image_list(image.project_id, image.img_type, session)
result = jsonable_encoder(image_list)
return rc.response_success(data=result)
else:
@ -217,14 +219,15 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
return rc.response_success(data=img_leafer_out['leafer'])
@project.get("/run_train/{project_id}")
async def run_train(project_id: int, session: Session = Depends(get_db)):
@project.post("/run_train")
async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)):
"""
执行项目训练方法
:param project_id:
:param train_in:
:param session:
:return:
"""
project_id = train_in.project_id
project_info = pic.get_project_by_id(project_id, session)
if project_info is None:
return rc.response_error("项目查询错误")
@ -240,10 +243,24 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
return rc.response_error("请先上传验证图片")
if val_img_count < 5:
return rc.response_error("验证图片少于5张请继续上传验证图片")
data, project_name, name = ps.run_train_yolo(project_info, session)
return StreamingResponse(
ps.run_commend(data, project_name, name, 50, project_id, session),
media_type="text/plain")
data, project, name = ps.run_train_yolo(project_info, train_in, session)
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
project_id, session,))
thread_train.start();
return rc.response_success(msg="执行成功")
def run_event_loop(data: str, project: str,
name: str, train_in: ProjectTrainIn,
project_id: int, session: Session):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
train_in.patience, train_in.weights_id,
project_id, session))
# 可选: 关闭循环
loop.close()
@project.get("/get_train_list/{project_id}")

View File

@ -1,4 +1,5 @@
from fastapi import APIRouter, WebSocket
from starlette.websockets import WebSocketState
from app.websocket.web_socket_server import room_manager
@ -19,9 +20,11 @@ async def websocket_room(websocket: WebSocket, room: str):
try:
while True:
data = await websocket.receive_text()
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
except Exception as e:
print(f"连接关闭: {e}")
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close(code=1000)
finally:
await room_manager.remove_from_room(room, websocket)
await websocket.close()
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close(code=1001)