查询训练报告接口
This commit is contained in:
@ -1,16 +1,18 @@
|
||||
import threading
|
||||
import asyncio
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.db_session import get_db
|
||||
from app.common import reponse_code as rc
|
||||
from app.model.crud import project_detect_crud as pdc
|
||||
from app.service import project_detect_service as pds
|
||||
from app.model.crud.project_train_crud import get_train
|
||||
from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn,\
|
||||
ProjectDetectImgPager, ProjectDetectLogIn, ProjectDetectLogPager
|
||||
from app.db.db_session import get_db
|
||||
|
||||
detect = APIRouter()
|
||||
|
||||
@ -116,11 +118,22 @@ def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depend
|
||||
if train is None:
|
||||
return rc.response_error("训练权重不存在")
|
||||
detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session)
|
||||
return StreamingResponse(pds.run_commend(detect_log.pt_url,
|
||||
thread_train = threading.Thread(target=run_event_loop, args=(detect_log.pt_url,
|
||||
detect_log.folder_url,
|
||||
detect_log.detect_folder_url,
|
||||
detect_log.detect_version,
|
||||
detect_log.id, detect_log.detect_id, session), media_type="text/plain")
|
||||
detect_log.id, detect_log.detect_id, session,))
|
||||
thread_train.start()
|
||||
return rc.response_success(msg="执行成功")
|
||||
|
||||
|
||||
def run_event_loop(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, session: Session):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(pds.run_commend(weights, source, project, name, log_id, detect_id, session))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
|
||||
@detect.post("/get_log_pager")
|
||||
|
@ -246,19 +246,17 @@ async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)
|
||||
data, project, name = ps.run_train_yolo(project_info, train_in, session)
|
||||
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
|
||||
project_id, session,))
|
||||
thread_train.start();
|
||||
thread_train.start()
|
||||
return rc.response_success(msg="执行成功")
|
||||
|
||||
|
||||
def run_event_loop(data: str, project: str,
|
||||
name: str, train_in: ProjectTrainIn,
|
||||
project_id: int, session: Session):
|
||||
def run_event_loop(data: str, project: str, name: str, train_in: ProjectTrainIn,
|
||||
project_id: int, session: Session):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
|
||||
train_in.patience, train_in.weights_id,
|
||||
project_id, session))
|
||||
train_in.patience, train_in.weights_id, project_id, session))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
@ -276,4 +274,15 @@ def get_train_list(project_id: int, session: Session = Depends(get_db)):
|
||||
return rc.response_success(data=result)
|
||||
|
||||
|
||||
|
||||
@project.get("/get_train_report/{train_id}")
|
||||
def get_train_report(train_id: int, session: Session = Depends(get_db)):
|
||||
"""
|
||||
查询训练报告
|
||||
:param train_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
result_row = ps.get_train_result(train_id, session)
|
||||
if result_row is None:
|
||||
return rc.response_error("查询失败")
|
||||
return rc.response_success(data=result_row)
|
||||
|
@ -20,7 +20,7 @@ async def websocket_room(websocket: WebSocket, room: str):
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
||||
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
||||
except Exception as e:
|
||||
if websocket.client_state != WebSocketState.DISCONNECTED:
|
||||
await websocket.close(code=1000)
|
||||
|
Reference in New Issue
Block a user