#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version        : 1.0
# @Create Time    : 2025/04/03 10:30
# @File           : views.py
# @IDE            : PyCharm
# @desc           : Routes / view functions for project detect (inference) sets

from utils import os_utils as osu
from core.dependencies import IdList
from core.database import redis_getter
from . import schemas, crud, params, service, models
from utils.websocket_server import room_manager
from apps.business.train.crud import ProjectTrainDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
import os
import shutil
import zipfile
import tempfile
import threading
from redis.asyncio import Redis
from fastapi.responses import FileResponse
from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks

app = APIRouter()


###########################################################
# Project detect-set endpoints
###########################################################
@app.get("/list/{proj_id}", summary="获取推理集合列表")
async def detect_list(
        proj_id: int,
        auth: Auth = Depends(AllUserAuth())):
    """Return every detect set of project ``proj_id``, ordered by id ascending (no paging, no count)."""
    dal = crud.ProjectDetectDal(auth.db)
    project_filter = [models.ProjectDetect.project_id == proj_id]
    rows = await dal.get_datas(
        limit=0,
        v_where=project_filter,
        v_order="asc",
        v_order_field="id",
        v_return_count=False,
    )
    return SuccessResponse(rows)


@app.post("/", summary="创建推理集合")
async def add_detect(
        data: schemas.ProjectDetectIn,
        auth: Auth = Depends(AllUserAuth())):
    """Create a detect set owned by the current user.

    Fails when ``check_name`` reports that the name is already used
    within ``data.project_id``.
    """
    dal = crud.ProjectDetectDal(auth.db)
    name_taken = await dal.check_name(data.detect_name, data.project_id)
    if name_taken:
        return ErrorResponse(msg="该项目中存在相同名称的集合")
    await dal.add_detect(data=data, user_id=auth.user.id)
    return SuccessResponse(msg="保存成功")


@app.delete("/", summary="删除项目推理集合信息")
async def delete_detect(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the detect sets whose ids are listed in ``ids.ids``."""
    dal = crud.ProjectDetectDal(auth.db)
    await dal.delete_detects(ids=ids.ids)
    return SuccessResponse("删除成功")


###########################################################
# Project detect-set file endpoints
###########################################################
@app.get("/files", summary="获取推理集合文件列表")
async def file_list(
        p: params.ProjectDetectFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List the files of a detect set.

    When ``p.limit > 0`` the total row count is returned alongside the data;
    otherwise only the data is returned.
    """
    file_dal = crud.ProjectDetectFileDal(auth.db)
    if p.limit > 0:
        datas, count = await file_dal.get_datas(**p.dict(), v_return_count=True)
        return SuccessResponse(datas, count=count)
    datas = await file_dal.get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(datas)


@app.post("/files", summary="上传项目推理集合文件")
async def upload_file(
        detect_id: int = Form(...),
        files: list[UploadFile] = Form(...),
        auth: Auth = Depends(AllUserAuth())):
    """Upload files into detect set ``detect_id``.

    For sets whose ``file_type`` is not ``'rtsp'``, every uploaded file must
    pass ``osu.is_extensions(file_type, filename)``; if any file fails (or no
    file was sent), the whole upload is rejected and nothing is stored.
    """
    detect_dal = crud.ProjectDetectDal(auth.db)
    file_dal = crud.ProjectDetectFileDal(auth.db)
    detect_info = await detect_dal.get_data(data_id=detect_id)
    if detect_info is None:
        return ErrorResponse("训练集合查询失败,请刷新后再试")
    if detect_info.file_type != 'rtsp':
        # Reject if ANY file is of the wrong type. The previous `not any(...)`
        # only rejected when *all* files were wrong, letting mixed batches through.
        # `not files` keeps the old rejection of an empty upload.
        if not files or not all(osu.is_extensions(detect_info.file_type, file.filename) for file in files):
            return ErrorResponse("上传的文件中存在不符合的文件类型")
    await file_dal.add_files(detect_info, files, auth.user.id)
    return SuccessResponse(msg="上传成功")


@app.delete("/files", summary="删除推理集合文件")
async def delete_file(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the detect-set files whose ids are listed in ``ids.ids``."""
    await crud.ProjectDetectFileDal(auth.db).delete_files(ids=ids.ids)
    return SuccessResponse("删除成功")


@app.post("/start", summary="开始推理")
async def run_detect_yolo(
        detect_log_in: schemas.ProjectDetectLogIn,
        auth: Auth = Depends(AllUserAuth()),
        rd: Redis = Depends(redis_getter)):
    """Start an inference run for a detect set using a trained weight.

    * ``img`` / ``video`` sets: ``service.before_detect`` creates the detect log,
      ``service.run_img_loop`` is started in a separate thread, then
      ``service.update_sql`` is awaited.
    * ``rtsp`` sets: ``service.run_rtsp_loop`` is started in a separate thread
      with the best or last weight (per ``pt_type``), unless websocket room
      ``detect_rtsp_<id>`` already exists, in which case nothing new is started.

    Any other ``file_type`` starts nothing; success is still returned.
    """
    detect_dal = crud.ProjectDetectDal(auth.db)
    train_dal = ProjectTrainDal(auth.db)
    detect = await detect_dal.get_data(detect_log_in.detect_id)
    if detect is None:
        return ErrorResponse(msg="训练集合不存在")
    train = await train_dal.get_data(detect_log_in.train_id)
    if train is None:
        return ErrorResponse("训练权重不存在")
    file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
    if file_count == 0 and detect.rtsp_url is None:
        return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
    # Raw value of the Redis key 'is_gpu', forwarded unchanged to the worker thread
    is_gpu = await rd.get('is_gpu')
    if detect.file_type in ('img', 'video'):
        detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
        # Run the inference loop in its own thread so this request does not wait for it
        worker = threading.Thread(
            target=service.run_img_loop,
            args=(detect_log.pt_url, detect_log.folder_url, detect_log.detect_folder_url,
                  detect_log.detect_version, detect_log.detect_id, is_gpu))
        worker.start()
        await service.update_sql(
            auth.db, detect_log.detect_id, detect_log.id,
            detect_log.detect_folder_url, detect_log.detect_version)
    elif detect.file_type == 'rtsp':
        room = 'detect_rtsp_' + str(detect.id)
        # NOTE(review): an existing room presumably means a stream for this set is already running
        if not room_manager.rooms.get(room):
            weights_pt = train.best_pt if detect_log_in.pt_type == 'best' else train.last_pt
            worker = threading.Thread(
                target=service.run_rtsp_loop,
                args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
            worker.start()
    return SuccessResponse(msg="执行成功")


###########################################################
# Project detect log endpoints
###########################################################
@app.get("/logs", summary="获取推理记录列表")
async def logs(
        p: params.ProjectDetectLogParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List detect logs matching ``p``, together with the total count."""
    datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(datas, count=count)


@app.get("/logs/download/{log_id}", summary="下载推理结果")
async def logs_download(
        log_id: int,
        background_tasks: BackgroundTasks,
        auth: Auth = Depends(AllUserAuth())
):
    """Zip the result folder of detect log ``log_id`` and return it as a file download.

    The zipped folder is ``<detect_folder_url>/<detect_version>``. The zip
    produced by ``osu.zip_folder`` is handed to ``cleanup_temp_dir`` as a
    background task, which FastAPI runs after the response has been sent.
    """
    detect_log = await crud.ProjectDetectLogDal(auth.db).get_data(data_id=log_id)
    if detect_log is None:
        return ErrorResponse(msg="推理结果查询错误,请刷新页面再试")
    folder_path = os.path.join(detect_log.detect_folder_url, detect_log.detect_version)
    if not os.path.exists(folder_path):
        return ErrorResponse(msg="推理结果文件夹不存在")
    zip_filename = f"{detect_log.detect_name}_{detect_log.detect_version}.zip"
    zip_file_path = osu.zip_folder(folder_path, zip_filename)
    # Delete the generated zip once it has been streamed to the client
    background_tasks.add_task(cleanup_temp_dir, zip_file_path)
    return FileResponse(path=zip_file_path, filename=os.path.basename(zip_file_path))


def cleanup_temp_dir(temp_dir: str):
    """Best-effort removal of a temporary file or directory (background task).

    :param temp_dir: path to delete. It may be a regular file (e.g. the zip
        produced by ``logs_download``) or a directory tree.
    Errors are printed and never raised.
    """
    try:
        if os.path.isdir(temp_dir) and not os.path.islink(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
        elif os.path.lexists(temp_dir):
            # shutil.rmtree refuses non-directories. With ignore_errors=True that
            # failure used to be swallowed, so zip files were never deleted.
            os.remove(temp_dir)
    except OSError as e:
        print(f"清理临时目录失败: {e}")


@app.get("/logs/files", summary="获取项目推理记录文件列表")
async def logs_files(
        p: params.ProjectDetectLogFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List the files of detect logs matching ``p`` (no total count)."""
    datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(datas)