完成线程启动训练并以websocket的方式进行发送shell
This commit is contained in:
@ -3,18 +3,20 @@ from app.model.crud import project_label_crud as plc
|
||||
from app.model.crud import project_info_crud as pic
|
||||
from app.model.crud import project_image_crud as pimc
|
||||
from app.model.crud import project_train_crud as ptnc
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
|
||||
from app.model.schemas.project_label_schemas import ProjectLabel
|
||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
|
||||
from app.model.bussiness_model import ProjectLabel as pl
|
||||
from app.model.schemas.project_label_schemas import ProjectLabel
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
|
||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
|
||||
from app.model.schemas.project_train_schemas import ProjectTrainIn
|
||||
from app.common.jwt_check import get_user_id
|
||||
from app.common import reponse_code as rc
|
||||
from app.service import project_train_service as ps
|
||||
from app.db.db_session import get_db
|
||||
|
||||
import threading
|
||||
import asyncio
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -182,7 +184,7 @@ def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db))
|
||||
:return:
|
||||
"""
|
||||
if image.pagerNum is None and image.pagerSize is None:
|
||||
image_list = pimc.get_image_list(image.project_id, session)
|
||||
image_list = pimc.get_image_list(image.project_id, image.img_type, session)
|
||||
result = jsonable_encoder(image_list)
|
||||
return rc.response_success(data=result)
|
||||
else:
|
||||
@ -217,14 +219,15 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
|
||||
return rc.response_success(data=img_leafer_out['leafer'])
|
||||
|
||||
|
||||
@project.get("/run_train/{project_id}")
|
||||
async def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
@project.post("/run_train")
|
||||
async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)):
|
||||
"""
|
||||
执行项目训练方法
|
||||
:param project_id:
|
||||
:param train_in:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
project_id = train_in.project_id
|
||||
project_info = pic.get_project_by_id(project_id, session)
|
||||
if project_info is None:
|
||||
return rc.response_error("项目查询错误")
|
||||
@ -240,10 +243,24 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
return rc.response_error("请先上传验证图片")
|
||||
if val_img_count < 5:
|
||||
return rc.response_error("验证图片少于5张,请继续上传验证图片")
|
||||
data, project_name, name = ps.run_train_yolo(project_info, session)
|
||||
return StreamingResponse(
|
||||
ps.run_commend(data, project_name, name, 50, project_id, session),
|
||||
media_type="text/plain")
|
||||
data, project, name = ps.run_train_yolo(project_info, train_in, session)
|
||||
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
|
||||
project_id, session,))
|
||||
thread_train.start();
|
||||
return rc.response_success(msg="执行成功")
|
||||
|
||||
|
||||
def run_event_loop(data: str, project: str,
|
||||
name: str, train_in: ProjectTrainIn,
|
||||
project_id: int, session: Session):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
|
||||
train_in.patience, train_in.weights_id,
|
||||
project_id, session))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
|
||||
@project.get("/get_train_list/{project_id}")
|
||||
|
@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
|
||||
@ -19,9 +20,11 @@ async def websocket_room(websocket: WebSocket, room: str):
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
||||
room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
||||
except Exception as e:
|
||||
print(f"连接关闭: {e}")
|
||||
if websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await websocket.close(code=1000)
|
||||
finally:
|
||||
await room_manager.remove_from_room(room, websocket)
|
||||
await websocket.close()
|
||||
if websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await websocket.close(code=1001)
|
||||
|
Reference in New Issue
Block a user