完成项目训练模块的接口测试

This commit is contained in:
2025-04-22 10:11:44 +08:00
parent 7a9e571a96
commit 0033746fe1
11 changed files with 139 additions and 91 deletions

View File

@ -3,6 +3,8 @@
# @version : 1.0
# @Create Time : 2025/04/03 10:32
# @File : views.py
from core.database import redis_getter
from . import models, schemas, crud, service
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
@ -10,6 +12,7 @@ from utils.response import SuccessResponse, ErrorResponse
from apps.business.project.crud import ProjectInfoDal, ProjectImageDal
import threading
from redis.asyncio import Redis
from fastapi import APIRouter, Depends
@ -19,10 +22,11 @@ app = APIRouter()
###########################################################
# 项目训练信息
###########################################################
@app.post("/", summary="执行训练")
@app.post("/start", summary="执行训练")
async def run_train(
train_in: schemas.ProjectTrainIn,
auth: Auth = Depends(AllUserAuth())):
auth: Auth = Depends(AllUserAuth()),
rd: Redis = Depends(redis_getter)):
proj_id = train_in.project_id
proj_dal = ProjectInfoDal(auth.db)
proj_img_dal = ProjectImageDal(auth.db)
@ -43,12 +47,17 @@ async def run_train(
return ErrorResponse("训练图片中存在未标注的图片")
if val_label_count > 0:
return ErrorResponse("验证图片中存在未标注的图片")
data, project, name = service.before_train(proj_info, auth.db)
data, project, name = await service.before_train(proj_info, auth.db)
is_gpu = await rd.get('is_gpu')
train_info = None
if train_in.weights_id is not None:
train_info = await crud.ProjectTrainDal(auth.db).get_data(train_in.weights_id)
# 异步执行操作操作过程通过websocket进行同步
thread_train = threading.Thread(
target=service.run_event_loop,
args=(data, project, name, train_in, proj_id, auth.db,))
args=(data, project, name, train_in, proj_id, train_info, is_gpu))
thread_train.start()
await service.add_train(auth.db, proj_id, name, project, data, train_in, auth.user.id)
return SuccessResponse(msg="执行成功")
@ -57,15 +66,16 @@ async def train_list(
proj_id: int,
auth: Auth = Depends(AllUserAuth())):
datas = await crud.ProjectTrainDal(auth.db).get_datas(
limit=0,
v_where=[models.ProjectTrain.project_id == proj_id],
v_schema=schemas.ProjectTrainOut,
v_order="asc",
v_order_field="id",v_return_count=False)
v_order_field="id",
v_return_count=False)
return SuccessResponse(data=datas)
@app.get("/result/{proj_id}", summary="查询训练报告")
async def get_result(train_id:int, auth: Auth = Depends(AllUserAuth())):
@app.get("/result/{train_id}", summary="查询训练报告")
async def get_result(train_id: int, auth: Auth = Depends(AllUserAuth())):
result = await crud.ProjectTrainDal(auth.db).get_result(train_id)
return SuccessResponse(data=result)