完成项目推理模块的接口测试
This commit is contained in:
@ -6,18 +6,24 @@
|
||||
# @IDE : PyCharm
|
||||
# @desc : 路由,视图文件
|
||||
|
||||
from utils import os_utils as osu
|
||||
from core.dependencies import IdList
|
||||
from core.database import redis_getter
|
||||
from . import schemas, crud, params, service
|
||||
from . import schemas, crud, params, service, models
|
||||
from utils.websocket_server import room_manager
|
||||
from apps.business.train.crud import ProjectTrainDal
|
||||
from apps.vadmin.auth.utils.current import AllUserAuth
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from utils.response import SuccessResponse, ErrorResponse
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tempfile
|
||||
import threading
|
||||
from redis.asyncio import Redis
|
||||
from fastapi import Depends, APIRouter, Form, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks
|
||||
|
||||
|
||||
app = APIRouter()
|
||||
@ -26,22 +32,27 @@ app = APIRouter()
|
||||
###########################################################
|
||||
# 项目推理集合信息
|
||||
###########################################################
|
||||
@app.get("/list", summary="获取项目推理集合信息列表")
|
||||
@app.get("/list/{proj_id}", summary="获取推理集合列表")
|
||||
async def detect_list(
|
||||
p: params.ProjectDetectParams = Depends(),
|
||||
proj_id: int,
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
datas = await crud.ProjectDetectDal(auth.db).get_datas(
|
||||
limit=0,
|
||||
v_where=[models.ProjectDetect.project_id == proj_id],
|
||||
v_order="asc",
|
||||
v_order_field="id",
|
||||
v_return_count=False)
|
||||
return SuccessResponse(datas)
|
||||
|
||||
|
||||
@app.post("/", summary="创建项目推理集合信息")
|
||||
@app.post("/", summary="创建推理集合")
|
||||
async def add_detect(
|
||||
data: schemas.ProjectDetectIn,
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
if await detect_dal.check_name(data.detect_name, data.project_id):
|
||||
return ErrorResponse(msg="该项目中存在相同名称的集合")
|
||||
await detect_dal.create_data(data=data)
|
||||
await detect_dal.add_detect(data=data, user_id=auth.user.id)
|
||||
return SuccessResponse(msg="保存成功")
|
||||
|
||||
|
||||
@ -49,18 +60,18 @@ async def add_detect(
|
||||
async def delete_detect(
|
||||
ids: IdList = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
|
||||
await crud.ProjectDetectDal(auth.db).delete_detects(ids=ids.ids)
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
###########################################################
|
||||
# 项目推理集合文件信息
|
||||
###########################################################
|
||||
@app.get("/file", summary="获取项目推理集合文件信息列表")
|
||||
@app.get("/files", summary="获取推理集合文件列表")
|
||||
async def file_list(
|
||||
p: params.ProjectDetectFileParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
if p.limit:
|
||||
if p.limit > 0:
|
||||
datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
else:
|
||||
@ -68,20 +79,24 @@ async def file_list(
|
||||
return SuccessResponse(datas)
|
||||
|
||||
|
||||
@app.post("/file", summary="上传项目推理集合文件")
|
||||
@app.post("/files", summary="上传项目推理集合文件")
|
||||
async def upload_file(
|
||||
detect_id: int = Form(...),
|
||||
files: list[UploadFile] = Form(...),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
file_dal = crud.ProjectDetectFileDal(auth.db)
|
||||
detect_out = file_dal.get_data(data_id=detect_id)
|
||||
if detect_out is None:
|
||||
detect_info = await detect_dal.get_data(data_id=detect_id)
|
||||
if detect_info is None:
|
||||
return ErrorResponse("训练集合查询失败,请刷新后再试")
|
||||
await file_dal.add_file(detect_out, files)
|
||||
if detect_info.file_type != 'rtsp':
|
||||
if not any(osu.is_extensions(detect_info.file_type, file.filename) for file in files):
|
||||
return ErrorResponse("上传的文件中存在不符合的文件类型")
|
||||
await file_dal.add_files(detect_info, files, auth.user.id)
|
||||
return SuccessResponse(msg="上传成功")
|
||||
|
||||
|
||||
@app.delete("/file", summary="删除项目推理集合文件信息")
|
||||
@app.delete("/files", summary="删除推理集合文件")
|
||||
async def delete_file(
|
||||
ids: IdList = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
@ -89,29 +104,34 @@ async def delete_file(
|
||||
return SuccessResponse("删除成功")
|
||||
|
||||
|
||||
@app.post("/detect", summary="开始推理")
|
||||
def run_detect_yolo(
|
||||
@app.post("/start", summary="开始推理")
|
||||
async def run_detect_yolo(
|
||||
detect_log_in: schemas.ProjectDetectLogIn,
|
||||
auth: Auth = Depends(AllUserAuth()),
|
||||
rd: Redis = Depends(redis_getter)):
|
||||
detect_dal = crud.ProjectDetectDal(auth.db)
|
||||
train_dal = ProjectTrainDal(auth.db)
|
||||
detect = detect_dal.get_data(detect_log_in.detect_id)
|
||||
detect = await detect_dal.get_data(detect_log_in.detect_id)
|
||||
if detect is None:
|
||||
return ErrorResponse(msg="训练集合不存在")
|
||||
train = train_dal.get_data(detect_log_in.train_id)
|
||||
train = await train_dal.get_data(detect_log_in.train_id)
|
||||
if train is None:
|
||||
return ErrorResponse("训练权重不存在")
|
||||
file_count = crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
|
||||
file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
|
||||
if file_count == 0 and detect.rtsp_url is None:
|
||||
return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
|
||||
is_gpu = await rd.get('is_gpu')
|
||||
if detect.file_type == 'img' or detect.file_type == 'video':
|
||||
detect_log = service.before_detect(detect_log_in, detect, train, auth.db)
|
||||
detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
|
||||
thread_train = threading.Thread(target=service.run_img_loop,
|
||||
args=(detect_log.pt_url, detect_log.folder_url,
|
||||
detect_log.detect_folder_url, detect_log.detect_version,
|
||||
detect_log.id, detect_log.detect_id, auth.db,))
|
||||
detect_log.detect_id, is_gpu))
|
||||
thread_train.start()
|
||||
await service.update_sql(
|
||||
auth.db, detect_log.detect_id,
|
||||
detect_log.id, detect_log.detect_folder_url,
|
||||
detect_log.detect_version)
|
||||
elif detect.file_type == 'rtsp':
|
||||
room = 'detect_rtsp_' + str(detect.id)
|
||||
if not room_manager.rooms.get(room):
|
||||
@ -120,7 +140,7 @@ def run_detect_yolo(
|
||||
else:
|
||||
weights_pt = train.last_pt
|
||||
thread_train = threading.Thread(target=service.run_rtsp_loop,
|
||||
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,))
|
||||
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
|
||||
thread_train.start()
|
||||
return SuccessResponse(msg="执行成功")
|
||||
|
||||
@ -128,16 +148,50 @@ def run_detect_yolo(
|
||||
###########################################################
|
||||
# 项目推理记录信息
|
||||
###########################################################
|
||||
@app.get("/log", summary="获取项目推理记录列表")
|
||||
async def log_pager(
|
||||
@app.get("/logs", summary="获取推理记录列表")
|
||||
async def logs(
|
||||
p: params.ProjectDetectLogParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
|
||||
return SuccessResponse(datas, count=count)
|
||||
|
||||
|
||||
@app.get("/log_files", summary="获取项目推理记录文件列表")
|
||||
async def log_files(
|
||||
@app.get("/logs/download/{log_id}", summary="下载推理结果")
|
||||
async def logs_download(
|
||||
log_id: int,
|
||||
background_tasks: BackgroundTasks,
|
||||
auth: Auth = Depends(AllUserAuth())
|
||||
):
|
||||
# 获取日志记录
|
||||
detect_log = await crud.ProjectDetectLogDal(auth.db).get_data(data_id=log_id)
|
||||
if detect_log is None:
|
||||
return ErrorResponse(msg="推理结果查询错误,请刷新页面再试")
|
||||
|
||||
# 检查源文件夹是否存在
|
||||
folder_path = os.path.join(detect_log.detect_folder_url, detect_log.detect_version)
|
||||
if not os.path.exists(folder_path):
|
||||
return ErrorResponse(msg="推理结果文件夹不存在")
|
||||
|
||||
zip_filename = f"{detect_log.detect_name}_{detect_log.detect_version}.zip"
|
||||
zip_file_path = osu.zip_folder(folder_path, zip_filename)
|
||||
|
||||
# 后台任务删掉这个
|
||||
background_tasks.add_task(cleanup_temp_dir, zip_file_path)
|
||||
|
||||
# 直接返回文件响应
|
||||
return FileResponse(path=zip_file_path, filename=os.path.basename(zip_file_path))
|
||||
|
||||
|
||||
def cleanup_temp_dir(temp_dir: str):
|
||||
"""清理临时目录的后台任务"""
|
||||
try:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
except Exception as e:
|
||||
print(f"清理临时目录失败: {e}")
|
||||
|
||||
|
||||
@app.get("/logs/files", summary="获取项目推理记录文件列表")
|
||||
async def logs_files(
|
||||
p: params.ProjectDetectLogFileParams = Depends(),
|
||||
auth: Auth = Depends(AllUserAuth())):
|
||||
datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
|
||||
|
Reference in New Issue
Block a user