完成训练模块的转移

This commit is contained in:
2025-04-17 11:03:05 +08:00
parent 4439687870
commit 74e8f0d415
188 changed files with 32931 additions and 70 deletions

View File

@@ -3,49 +3,70 @@
# @version : 1.0
# @Create Time : 2025/04/03 10:32
# @File : views.py
# @IDE : PyCharm
# @desc : 路由,视图文件
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import APIRouter, Depends
from . import models, schemas, crud, params
from core.dependencies import IdList
from . import models, schemas, crud
from apps.business.project.crud import ProjectInfoDal, ProjectImageDal
from utils.response import SuccessResponse, ErrorResponse
from apps.vadmin.auth.utils.current import AllUserAuth
from utils.response import SuccessResponse
from apps.vadmin.auth.utils.validation.auth import Auth
from core.database import db_getter
import service
import threading
from fastapi import APIRouter, Depends
app = APIRouter()
###########################################################
# 项目巡逻片信息
# 项目训练信息
###########################################################
@app.get("/project/train", summary="获取项目巡逻片信息列表", tags=["项目巡逻片信息"])
async def get_project_train_list(p: params.ProjectTrainParams = Depends(), auth: Auth = Depends(AllUserAuth())):
datas, count = await crud.ProjectTrainDal(auth.db).get_datas(**p.dict(), v_return_count=True)
return SuccessResponse(datas, count=count)
@app.post("/", summary="执行训练")
async def run_train(
train_in: schemas.ProjectTrainIn,
auth: Auth = Depends(AllUserAuth())):
proj_id = train_in.project_id
proj_dal = ProjectInfoDal(auth.db)
proj_img_dal = ProjectImageDal(auth.db)
proj_info = await proj_dal.get_data(proj_id)
if proj_info is None:
return ErrorResponse(msg="项目信息查询错误")
train_count, val_count = await proj_img_dal.get_img_count(proj_id)
if train_count == 0:
return ErrorResponse("请先上传训练图片")
if train_count < 10:
return ErrorResponse("训练图片少于10张请继续上传训练图片")
if val_count == 0:
return ErrorResponse("请先上传验证图片")
if val_count < 5:
return ErrorResponse("验证图片少于5张请继续上传验证图片")
train_label_count, val_label_count = await proj_img_dal.check_image_label(proj_id)
if train_label_count > 0:
return ErrorResponse("训练图片中存在未标注的图片")
if val_label_count > 0:
return ErrorResponse("验证图片中存在未标注的图片")
data, project, name = service.before_train(proj_info, auth.db)
# 异步执行操作操作过程通过websocket进行同步
thread_train = threading.Thread(
target=service.run_event_loop,
args=(data, project, name, train_in, proj_id, auth.db,))
thread_train.start()
return SuccessResponse(msg="执行成功")
@app.post("/project/train", summary="创建项目巡逻片信息", tags=["项目巡逻片信息"])
async def create_project_train(data: schemas.ProjectTrain, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectTrainDal(auth.db).create_data(data=data))
@app.get("/{proj_id}", summary="查询训练列表")
async def train_list(
proj_id: int,
auth: Auth = Depends(AllUserAuth())):
datas = await crud.ProjectTrainDal(auth.db).get_datas(
v_where=[models.ProjectTrain.project_id == proj_id],
v_schema=schemas.ProjectTrainOut,
v_order="asc",
v_order_field="id",v_return_count=False)
return SuccessResponse(data=datas)
@app.delete("/project/train", summary="删除项目巡逻片信息", description="硬删除", tags=["项目巡逻片信息"])
async def delete_project_train_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
await crud.ProjectTrainDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
return SuccessResponse("删除成功")
@app.put("/project/train/{data_id}", summary="更新项目巡逻片信息", tags=["项目巡逻片信息"])
async def put_project_train(data_id: int, data: schemas.ProjectTrain, auth: Auth = Depends(AllUserAuth())):
return SuccessResponse(await crud.ProjectTrainDal(auth.db).put_data(data_id, data))
@app.get("/project/train/{data_id}", summary="获取项目巡逻片信息信息", tags=["项目巡逻片信息"])
async def get_project_train(data_id: int, db: AsyncSession = Depends(db_getter)):
schema = schemas.ProjectTrainSimpleOut
return SuccessResponse(await crud.ProjectTrainDal(db).get_data(data_id, v_schema=schema))
@app.get("/result/{proj_id}", summary="查询训练报告")
async def get_result(train_id:int, auth: Auth = Depends(AllUserAuth())):
result = await crud.ProjectTrainDal(auth.db).get_result(train_id)
return SuccessResponse(data=result)