200 lines
7.9 KiB
Python
200 lines
7.9 KiB
Python
#!/usr/bin/python
|
|
# -*- coding: utf-8 -*-
|
|
# @version : 1.0
|
|
# @Create Time : 2025/04/03 10:30
|
|
# @File : views.py
|
|
# @IDE : PyCharm
|
|
# @desc : 路由,视图文件
|
|
|
|
from utils import os_utils as osu
|
|
from core.dependencies import IdList
|
|
from core.database import redis_getter
|
|
from . import schemas, crud, params, service, models
|
|
from utils.websocket_server import room_manager
|
|
from apps.business.train.crud import ProjectTrainDal
|
|
from apps.vadmin.auth.utils.current import AllUserAuth
|
|
from apps.vadmin.auth.utils.validation.auth import Auth
|
|
from utils.response import SuccessResponse, ErrorResponse
|
|
|
|
import os
|
|
import shutil
|
|
import zipfile
|
|
import tempfile
|
|
import threading
|
|
from redis.asyncio import Redis
|
|
from fastapi.responses import FileResponse
|
|
from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks
|
|
|
|
|
|
# Router that collects every endpoint defined in this module (detect
# collections, their files, detect runs and run logs). Presumably mounted under
# a URL prefix by the application setup — confirm there.
app = APIRouter()
|
|
|
|
|
|
###########################################################
|
|
# 项目推理集合信息
|
|
###########################################################
|
|
@app.get("/list/{proj_id}", summary="获取推理集合列表")
async def detect_list(
        proj_id: int,
        auth: Auth = Depends(AllUserAuth())):
    """Return the detect collections that belong to one project.

    Rows are filtered on ``ProjectDetect.project_id`` and ordered by ``id``
    ascending; no total count is requested.

    :param proj_id: id of the project whose collections are listed
    :param auth: authenticated-user context; supplies the DB session
    :return: SuccessResponse wrapping the collection rows
    """
    dal = crud.ProjectDetectDal(auth.db)
    same_project = models.ProjectDetect.project_id == proj_id
    # limit=0 presumably means "no paging" for this DAL — confirm in get_datas.
    rows = await dal.get_datas(
        limit=0,
        v_where=[same_project],
        v_order="asc",
        v_order_field="id",
        v_return_count=False,
    )
    return SuccessResponse(rows)
|
|
|
|
|
|
@app.post("/", summary="创建推理集合")
async def add_detect(
        data: schemas.ProjectDetectIn,
        auth: Auth = Depends(AllUserAuth())):
    """Create a detect collection unless its name is already used in the project.

    :param data: payload for the new collection; ``detect_name`` and
                 ``project_id`` are used for the duplicate-name check
    :param auth: authenticated-user context; supplies the DB session and the
                 creating user's id
    :return: ErrorResponse when ``check_name`` reports a clash, otherwise
             SuccessResponse
    """
    dal = crud.ProjectDetectDal(auth.db)
    name_clash = await dal.check_name(data.detect_name, data.project_id)
    if name_clash:
        return ErrorResponse(msg="该项目中存在相同名称的集合")

    await dal.add_detect(data=data, user_id=auth.user.id)
    return SuccessResponse(msg="保存成功")
|
|
|
|
|
|
@app.delete("/", summary="删除项目推理集合信息")
async def delete_detect(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the detect collections whose ids come from the ``IdList`` dependency.

    :param ids: dependency object whose ``ids`` attribute lists the rows to delete
    :param auth: authenticated-user context; supplies the DB session
    :return: SuccessResponse
    """
    dal = crud.ProjectDetectDal(auth.db)
    await dal.delete_detects(ids=ids.ids)
    # NOTE(review): the text is passed positionally, while other handlers in
    # this module use msg=...; whether it lands in the response's data or
    # message depends on SuccessResponse's signature — confirm.
    return SuccessResponse("删除成功")
|
|
|
|
|
|
###########################################################
|
|
# 项目推理集合文件信息
|
|
###########################################################
|
|
@app.get("/files", summary="获取推理集合文件列表")
async def file_list(
        p: params.ProjectDetectFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List the files of a detect collection.

    The fields of ``p`` are forwarded to ``get_datas`` as keyword arguments.
    A positive ``p.limit`` also requests the total count, which is returned
    as ``count``; otherwise only the rows are returned.

    :param p: query parameters (filtering/paging) for the file listing
    :param auth: authenticated-user context; supplies the DB session
    :return: SuccessResponse with the rows (and ``count`` when limit > 0)
    """
    wants_count = p.limit > 0
    dal = crud.ProjectDetectFileDal(auth.db)
    if not wants_count:
        rows = await dal.get_datas(**p.dict(), v_return_count=False)
        return SuccessResponse(rows)

    rows, total = await dal.get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(rows, count=total)
|
|
|
|
|
|
@app.post("/files", summary="上传项目推理集合文件")
async def upload_file(
        detect_id: int = Form(...),
        files: list[UploadFile] = Form(...),
        auth: Auth = Depends(AllUserAuth())):
    """Store uploaded files in a detect collection.

    For collections whose ``file_type`` is not ``'rtsp'``, every uploaded file
    must pass ``osu.is_extensions(file_type, filename)``. A single mismatching
    file rejects the whole upload, and so does an empty upload. Nothing is
    stored in either case.

    :param detect_id: id of the target detect collection (form field)
    :param files: uploaded files (multipart form field)
    :param auth: authenticated-user context; supplies the DB session and the
                 uploading user's id
    :return: ErrorResponse when the collection is missing or a file type is
             rejected, otherwise SuccessResponse
    """
    detect_dal = crud.ProjectDetectDal(auth.db)
    file_dal = crud.ProjectDetectFileDal(auth.db)

    detect_info = await detect_dal.get_data(data_id=detect_id)
    if detect_info is None:
        return ErrorResponse("训练集合查询失败,请刷新后再试")

    if detect_info.file_type != 'rtsp':
        # all(): the error message promises rejection when ANY file has a
        # non-conforming type. `not files` keeps an empty upload rejected,
        # because all() of an empty sequence would be True.
        types_ok = all(
            osu.is_extensions(detect_info.file_type, upload.filename) for upload in files)
        if not files or not types_ok:
            return ErrorResponse("上传的文件中存在不符合的文件类型")

    await file_dal.add_files(detect_info, files, auth.user.id)
    return SuccessResponse(msg="上传成功")
|
|
|
|
|
|
@app.delete("/files", summary="删除推理集合文件")
async def delete_file(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete detect-collection files whose ids come from the ``IdList`` dependency.

    :param ids: dependency object whose ``ids`` attribute lists the files to delete
    :param auth: authenticated-user context; supplies the DB session
    :return: SuccessResponse
    """
    file_dal = crud.ProjectDetectFileDal(auth.db)
    await file_dal.delete_files(ids=ids.ids)
    # Text passed positionally, matching delete_detect in this module.
    return SuccessResponse("删除成功")
|
|
|
|
|
|
@app.post("/start", summary="开始推理")
async def run_detect_yolo(
        detect_log_in: schemas.ProjectDetectLogIn,
        auth: Auth = Depends(AllUserAuth()),
        rd: Redis = Depends(redis_getter)):
    """Start a detection job for a detect collection on a background thread.

    Checks, in order:
    1. the detect collection exists;
    2. the training record holding the weights exists;
    3. the collection has files, unless it has an ``rtsp_url``.

    Dispatch on ``detect.file_type``:
    * ``'img'`` / ``'video'``: ``service.before_detect`` builds a detect-log
      record. A thread then runs ``service.run_img_loop`` with its paths, and
      ``service.update_sql`` is awaited after the thread has been started.
    * ``'rtsp'``: unless ``room_manager.rooms`` already holds a truthy entry
      for ``detect_rtsp_<id>``, a thread runs ``service.run_rtsp_loop`` with
      ``train.best_pt`` (when ``pt_type == 'best'``) or ``train.last_pt``.
    * any other type: nothing is started.

    The raw Redis value of ``is_gpu`` is handed to the worker unchanged.

    :return: ErrorResponse when a check fails, otherwise SuccessResponse
             (also when nothing was started)
    """
    detect_dal = crud.ProjectDetectDal(auth.db)
    train_dal = ProjectTrainDal(auth.db)

    detect = await detect_dal.get_data(detect_log_in.detect_id)
    if detect is None:
        return ErrorResponse(msg="训练集合不存在")

    train = await train_dal.get_data(detect_log_in.train_id)
    if train is None:
        return ErrorResponse("训练权重不存在")

    file_dal = crud.ProjectDetectFileDal(auth.db)
    file_count = await file_dal.file_count(detect_log_in.detect_id)
    if file_count == 0 and detect.rtsp_url is None:
        return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")

    is_gpu = await rd.get('is_gpu')
    file_type = detect.file_type

    if file_type in ('img', 'video'):
        detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
        img_loop_args = (
            detect_log.pt_url,
            detect_log.folder_url,
            detect_log.detect_folder_url,
            detect_log.detect_version,
            detect_log.detect_id,
            is_gpu,
        )
        threading.Thread(target=service.run_img_loop, args=img_loop_args).start()
        await service.update_sql(
            auth.db,
            detect_log.detect_id,
            detect_log.id,
            detect_log.detect_folder_url,
            detect_log.detect_version,
        )
    elif file_type == 'rtsp':
        room = f"detect_rtsp_{detect.id}"
        # An existing room presumably means a stream loop is already running
        # for this collection — confirm in room_manager / run_rtsp_loop.
        if not room_manager.rooms.get(room):
            weights_pt = train.best_pt if detect_log_in.pt_type == 'best' else train.last_pt
            rtsp_loop_args = (weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu)
            threading.Thread(target=service.run_rtsp_loop, args=rtsp_loop_args).start()

    return SuccessResponse(msg="执行成功")
|
|
|
|
|
|
###########################################################
|
|
# 项目推理记录信息
|
|
###########################################################
|
|
@app.get("/logs", summary="获取推理记录列表")
async def logs(
        p: params.ProjectDetectLogParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List detect-run log records together with the total count.

    :param p: query parameters, forwarded to ``get_datas`` as keyword arguments
    :param auth: authenticated-user context; supplies the DB session
    :return: SuccessResponse with the rows and ``count``
    """
    log_dal = crud.ProjectDetectLogDal(auth.db)
    query = p.dict()
    rows, total = await log_dal.get_datas(**query, v_return_count=True)
    return SuccessResponse(rows, count=total)
|
|
|
|
|
|
@app.get("/logs/download/{log_id}", summary="下载推理结果")
async def logs_download(
        log_id: int,
        background_tasks: BackgroundTasks,
        auth: Auth = Depends(AllUserAuth())
):
    """Zip the result folder of one detect run and return it as a file download.

    The folder is ``<detect_folder_url>/<detect_version>`` of the log record.
    The archive is named ``<detect_name>_<detect_version>.zip`` and is handed
    to ``cleanup_temp_dir`` as a background task after the response.

    :param log_id: id of the detect-log record
    :param background_tasks: FastAPI background-task collector
    :param auth: authenticated-user context; supplies the DB session
    :return: ErrorResponse when the record or folder is missing, otherwise a
             FileResponse streaming the zip archive
    """
    log_dal = crud.ProjectDetectLogDal(auth.db)
    detect_log = await log_dal.get_data(data_id=log_id)
    if detect_log is None:
        return ErrorResponse(msg="推理结果查询错误,请刷新页面再试")

    result_dir = os.path.join(detect_log.detect_folder_url, detect_log.detect_version)
    if not os.path.exists(result_dir):
        return ErrorResponse(msg="推理结果文件夹不存在")

    archive_name = f"{detect_log.detect_name}_{detect_log.detect_version}.zip"
    # NOTE(review): zip_folder runs synchronously inside this async handler,
    # so the event loop is blocked while it zips. detect_name is user-supplied;
    # confirm zip_folder cannot be steered outside its output directory by
    # path separators in it.
    archive_path = osu.zip_folder(result_dir, archive_name)

    # FastAPI runs background tasks after the response has been sent, so the
    # archive is removed only once it has been streamed.
    background_tasks.add_task(cleanup_temp_dir, archive_path)

    download_name = os.path.basename(archive_path)
    return FileResponse(path=archive_path, filename=download_name)
|
|
|
|
|
|
def cleanup_temp_dir(temp_dir: str):
    """Best-effort removal of a temporary file or directory (background task).

    Real directories are removed recursively with ``shutil.rmtree``. Plain
    files and symlinks are removed with ``os.remove``. A symlink is deleted
    itself; its target is never followed. A path that does not exist is
    ignored. Any failure is printed rather than raised, so a background
    cleanup never crashes the worker.

    :param temp_dir: path to remove; callers in this module pass a zip file
    """
    try:
        if os.path.isdir(temp_dir) and not os.path.islink(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
        elif os.path.lexists(temp_dir):
            # rmtree refuses non-directories and ignore_errors would hide
            # that, leaving the file on disk, so files are unlinked directly.
            os.remove(temp_dir)
    except Exception as e:
        print(f"清理临时目录失败: {e}")
|
|
|
|
|
|
@app.get("/logs/files", summary="获取项目推理记录文件列表")
async def logs_files(
        p: params.ProjectDetectLogFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List the files produced by detect runs, without a total count.

    :param p: query parameters, forwarded to ``get_datas`` as keyword arguments
    :param auth: authenticated-user context; supplies the DB session
    :return: SuccessResponse wrapping the rows
    """
    log_file_dal = crud.ProjectDetectLogFileDal(auth.db)
    rows = await log_file_dal.get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(rows)
|
|
|