200 lines
7.9 KiB
Python

#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : views.py
# @IDE : PyCharm
# @desc : 路由,视图文件
import asyncio
import os
import shutil
import tempfile
import threading
import zipfile

from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks
from fastapi.responses import FileResponse
from redis.asyncio import Redis

from utils import os_utils as osu
from core.dependencies import IdList
from core.database import redis_getter
from . import schemas, crud, params, service, models
from utils.websocket_server import room_manager
from apps.business.train.crud import ProjectTrainDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
# Router for the project inference ("detect") endpoints defined below: collections,
# collection files, starting an inference run, and run logs / result downloads.
# NOTE(review): presumably mounted under a URL prefix by the app's router registration — confirm.
app = APIRouter()
###########################################################
# 项目推理集合信息
###########################################################
@app.get("/list/{proj_id}", summary="获取推理集合列表")
async def detect_list(
        proj_id: int,
        auth: Auth = Depends(AllUserAuth())):
    """Return the inference collections of one project, ordered by ascending id.

    ``limit=0`` is passed to the DAL (presumably "no limit" — confirm in the DAL),
    and no total count is requested.
    """
    query = dict(
        limit=0,
        v_where=[models.ProjectDetect.project_id == proj_id],
        v_order="asc",
        v_order_field="id",
        v_return_count=False,
    )
    rows = await crud.ProjectDetectDal(auth.db).get_datas(**query)
    return SuccessResponse(rows)
@app.post("/", summary="创建推理集合")
async def add_detect(
        data: schemas.ProjectDetectIn,
        auth: Auth = Depends(AllUserAuth())):
    """Create an inference collection owned by the current user.

    The request is refused when ``check_name`` reports that a collection with the
    same name already exists in the same project.
    """
    dal = crud.ProjectDetectDal(auth.db)
    name_taken = await dal.check_name(data.detect_name, data.project_id)
    if name_taken:
        return ErrorResponse(msg="该项目中存在相同名称的集合")
    await dal.add_detect(data=data, user_id=auth.user.id)
    return SuccessResponse(msg="保存成功")
@app.delete("/", summary="删除项目推理集合信息")
async def delete_detect(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the inference collections listed in ``ids.ids``."""
    dal = crud.ProjectDetectDal(auth.db)
    await dal.delete_detects(ids=ids.ids)
    return SuccessResponse("删除成功")
###########################################################
# 项目推理集合文件信息
###########################################################
@app.get("/files", summary="获取推理集合文件列表")
async def file_list(
        p: params.ProjectDetectFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List the files of an inference collection.

    When ``p.limit`` is positive, the DAL also returns a total count, which is sent
    back as ``count``. Otherwise only the rows are returned.
    """
    dal = crud.ProjectDetectFileDal(auth.db)
    paged = p.limit > 0
    result = await dal.get_datas(**p.dict(), v_return_count=paged)
    if not paged:
        return SuccessResponse(result)
    datas, count = result
    return SuccessResponse(datas, count=count)
@app.post("/files", summary="上传项目推理集合文件")
async def upload_file(
        detect_id: int = Form(...),
        files: list[UploadFile] = Form(...),
        auth: Auth = Depends(AllUserAuth())):
    """Upload one or more files into an inference collection.

    For collections whose ``file_type`` is not ``'rtsp'``, every uploaded file name
    must pass ``osu.is_extensions(detect_info.file_type, filename)``. If any file
    fails, the whole batch is rejected. An empty batch is rejected as well.

    Returns an ``ErrorResponse`` if the collection does not exist or the file-type
    check fails. Otherwise the files are stored through ``add_files`` and a
    ``SuccessResponse`` is returned.
    """
    detect_dal = crud.ProjectDetectDal(auth.db)
    file_dal = crud.ProjectDetectFileDal(auth.db)
    detect_info = await detect_dal.get_data(data_id=detect_id)
    if detect_info is None:
        return ErrorResponse("训练集合查询失败,请刷新后再试")
    if detect_info.file_type != 'rtsp':
        # The error message promises that a batch containing ANY mismatched file is
        # rejected, so every file must pass. The old `not any(...)` only rejected
        # a batch when ALL files were wrong. `not files` keeps the previous
        # rejection of an empty batch (any([]) was False).
        all_valid = all(osu.is_extensions(detect_info.file_type, file.filename) for file in files)
        if not files or not all_valid:
            return ErrorResponse("上传的文件中存在不符合的文件类型")
    await file_dal.add_files(detect_info, files, auth.user.id)
    return SuccessResponse(msg="上传成功")
@app.delete("/files", summary="删除推理集合文件")
async def delete_file(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the inference-collection files listed in ``ids.ids``."""
    file_dal = crud.ProjectDetectFileDal(auth.db)
    await file_dal.delete_files(ids=ids.ids)
    return SuccessResponse("删除成功")
@app.post("/start", summary="开始推理")
async def run_detect_yolo(
        detect_log_in: schemas.ProjectDetectLogIn,
        auth: Auth = Depends(AllUserAuth()),
        rd: Redis = Depends(redis_getter)):
    """Start an inference run on a collection using weights from a training run.

    Behaviour depends on ``detect.file_type``:

    * ``'img'`` / ``'video'``: a detect log is created with ``service.before_detect``.
      ``service.run_img_loop`` is then started on a background thread, and finally
      ``service.update_sql`` is awaited.
    * ``'rtsp'``: unless ``room_manager.rooms`` already holds a truthy entry for
      ``detect_rtsp_<id>``, ``service.run_rtsp_loop`` is started on a background
      thread. It uses ``train.best_pt`` when ``pt_type == 'best'`` and
      ``train.last_pt`` otherwise.

    The raw ``is_gpu`` value read from Redis is forwarded to the worker unchanged.
    Any other file type falls through to the success response without starting work.
    """
    detect_store = crud.ProjectDetectDal(auth.db)
    train_store = ProjectTrainDal(auth.db)

    detect = await detect_store.get_data(detect_log_in.detect_id)
    if detect is None:
        return ErrorResponse(msg="训练集合不存在")
    train = await train_store.get_data(detect_log_in.train_id)
    if train is None:
        return ErrorResponse("训练权重不存在")

    # Nothing to run on: no uploaded files and no stream URL configured.
    total_files = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
    if total_files == 0 and detect.rtsp_url is None:
        return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")

    is_gpu = await rd.get('is_gpu')
    kind = detect.file_type
    if kind in ('img', 'video'):
        detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
        worker = threading.Thread(
            target=service.run_img_loop,
            args=(
                detect_log.pt_url,
                detect_log.folder_url,
                detect_log.detect_folder_url,
                detect_log.detect_version,
                detect_log.detect_id,
                is_gpu,
            ),
        )
        worker.start()
        await service.update_sql(auth.db, detect_log.detect_id, detect_log.id,
                                 detect_log.detect_folder_url, detect_log.detect_version)
    elif kind == 'rtsp':
        room = f"detect_rtsp_{detect.id}"
        # Only start a stream worker when the room is not already present/active.
        if not room_manager.rooms.get(room):
            weights_pt = train.best_pt if detect_log_in.pt_type == 'best' else train.last_pt
            worker = threading.Thread(
                target=service.run_rtsp_loop,
                args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu),
            )
            worker.start()
    return SuccessResponse(msg="执行成功")
###########################################################
# 项目推理记录信息
###########################################################
@app.get("/logs", summary="获取推理记录列表")
async def logs(
        p: params.ProjectDetectLogParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return a page of inference-run logs together with the total count."""
    log_dal = crud.ProjectDetectLogDal(auth.db)
    rows, total = await log_dal.get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(rows, count=total)
@app.get("/logs/download/{log_id}", summary="下载推理结果")
async def logs_download(
        log_id: int,
        background_tasks: BackgroundTasks,
        auth: Auth = Depends(AllUserAuth())
):
    """Zip the output folder of one inference run and return it as a file download.

    The folder is ``<detect_folder_url>/<detect_version>`` taken from the log record.
    The archive is built by ``osu.zip_folder`` and is scheduled for deletion by the
    ``cleanup_temp_dir`` background task once the response has been sent.

    Returns an ``ErrorResponse`` when the log record is missing or the result folder
    does not exist.
    """
    # Look up the log record.
    detect_log = await crud.ProjectDetectLogDal(auth.db).get_data(data_id=log_id)
    if detect_log is None:
        return ErrorResponse(msg="推理结果查询错误,请刷新页面再试")
    # The source must be an existing directory (not just any path).
    folder_path = os.path.join(detect_log.detect_folder_url, detect_log.detect_version)
    if not os.path.isdir(folder_path):
        return ErrorResponse(msg="推理结果文件夹不存在")
    zip_filename = f"{detect_log.detect_name}_{detect_log.detect_version}.zip"
    # Zipping a whole result folder is blocking disk I/O. Running it in a worker
    # thread keeps the event loop (and every other request) responsive meanwhile.
    zip_file_path = await asyncio.to_thread(osu.zip_folder, folder_path, zip_filename)
    # Remove the generated archive after the response has been streamed.
    background_tasks.add_task(cleanup_temp_dir, zip_file_path)
    return FileResponse(path=zip_file_path, filename=os.path.basename(zip_file_path))
def cleanup_temp_dir(temp_dir: str) -> None:
    """Best-effort background cleanup of a temporary file or directory.

    The download endpoint schedules this with the path of a generated zip *file*.
    ``shutil.rmtree`` alone cannot delete a file, and with ``ignore_errors=True``
    that failure was silent, so the archive was leaked. Therefore:

    * real directories (not symlinks) are removed recursively, best-effort;
    * anything else (regular files, symlinks) is unlinked;
    * a path that no longer exists is ignored.

    Any other failure is printed and swallowed, so the background task never raises.

    Args:
        temp_dir: Path of the file or directory to delete. The name is kept for
            backward compatibility with existing callers.
    """
    try:
        if os.path.isdir(temp_dir) and not os.path.islink(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
        else:
            os.remove(temp_dir)
    except FileNotFoundError:
        # Already gone — nothing left to clean up.
        pass
    except Exception as e:
        print(f"清理临时目录失败: {e}")
@app.get("/logs/files", summary="获取项目推理记录文件列表")
async def logs_files(
        p: params.ProjectDetectLogFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return the files recorded for inference runs matching ``p``, without a total count."""
    file_dal = crud.ProjectDetectLogFileDal(auth.db)
    rows = await file_dal.get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(rows)