#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version        : 1.0
# @Create Time    : 2025/04/03 10:30
# @File           : views.py
# @IDE            : PyCharm
# @desc           : Routes / view functions for project detection (inference)

from utils import os_utils as osu
from core.dependencies import IdList
from utils.websocket_server import room_manager
from . import schemas, crud, params, service, models
from apps.business.train.crud import ProjectTrainDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
from apps.business.deepsort import service as deep_sort_service
from apps.business.project.crud import ProjectInfoDal, ProjectLabelDal

import os
import shutil
import threading

from fastapi.responses import FileResponse
from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks

app = APIRouter()


###########################################################
#    Detection sets
###########################################################
@app.get("/list/{proj_id}", summary="获取推理集合列表")
async def detect_list(
        proj_id: int,
        auth: Auth = Depends(AllUserAuth())):
    """Return all detection sets of project ``proj_id``, ordered by ascending id.

    ``limit=0`` is passed to the DAL (presumably "no paging" — confirm in
    ``get_datas``) and no total count is requested.
    """
    belongs_to_project = models.ProjectDetect.project_id == proj_id
    rows = await crud.ProjectDetectDal(auth.db).get_datas(
        limit=0,
        v_where=[belongs_to_project],
        v_order="asc",
        v_order_field="id",
        v_return_count=False,
    )
    return SuccessResponse(rows)


@app.post("/", summary="创建推理集合")
async def add_detect(
        data: schemas.ProjectDetectIn,
        auth: Auth = Depends(AllUserAuth())):
    """Create a detection set owned by the current user.

    Returns an error when ``check_name`` reports that the project already has
    a set named ``data.detect_name``.
    """
    dal = crud.ProjectDetectDal(auth.db)
    name_taken = await dal.check_name(data.detect_name, data.project_id)
    if name_taken:
        return ErrorResponse(msg="该项目中存在相同名称的集合")
    await dal.add_detect(data=data, user_id=auth.user.id)
    return SuccessResponse(msg="保存成功")


@app.delete("/", summary="删除项目推理集合信息")
async def delete_detect(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the detection sets whose ids are listed in ``ids.ids``."""
    dal = crud.ProjectDetectDal(auth.db)
    await dal.delete_detects(ids=ids.ids)
    return SuccessResponse("删除成功")


###########################################################
#    Detection-set files
###########################################################
@app.get("/files", summary="获取推理集合文件列表")
async def file_list(
        p: params.ProjectDetectFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """List the files of a detection set.

    When ``p.limit > 0`` the total count is also requested and returned next to
    the rows; otherwise only the rows are returned.
    """
    file_dal = crud.ProjectDetectFileDal(auth.db)
    if p.limit > 0:
        datas, count = await file_dal.get_datas(**p.dict(), v_return_count=True)
        return SuccessResponse(datas, count=count)
    datas = await file_dal.get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(datas)


@app.post("/files", summary="上传项目推理集合文件")
async def upload_file(
        detect_id: int = Form(...),
        files: list[UploadFile] = Form(...),
        auth: Auth = Depends(AllUserAuth())):
    """Upload files into detection set ``detect_id``.

    For sets whose ``file_type`` is not ``rtsp``, EVERY uploaded file must pass
    ``osu.is_extensions(file_type, filename)``; otherwise the whole upload is
    rejected and nothing is stored.
    """
    detect_dal = crud.ProjectDetectDal(auth.db)
    file_dal = crud.ProjectDetectFileDal(auth.db)
    detect_info = await detect_dal.get_data(data_id=detect_id)
    if detect_info is None:
        # NOTE(review): message says "training set" although this is a detection set.
        return ErrorResponse("训练集合查询失败,请刷新后再试")
    if detect_info.file_type != 'rtsp':
        # Reject if ANY file has the wrong type. The former `not any(...)` only
        # rejected when NO file matched, letting mixed uploads through.
        if not all(osu.is_extensions(detect_info.file_type, file.filename) for file in files):
            return ErrorResponse("上传的文件中存在不符合的文件类型")
    await file_dal.add_files(detect_info, files, auth.user.id)
    return SuccessResponse(msg="上传成功")


@app.delete("/files", summary="删除推理集合文件")
async def delete_file(
        ids: IdList = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Delete the detection-set files whose ids are listed in ``ids.ids``."""
    await crud.ProjectDetectFileDal(auth.db).delete_files(ids=ids.ids)
    return SuccessResponse("删除成功")


@app.post("/start", summary="开始推理")
async def run_detect_yolo(
        detect_log_in: schemas.ProjectDetectLogIn,
        auth: Auth = Depends(AllUserAuth())):
    """Start detection on a detection set with the selected training weights.

    Dispatch depends on the owning project's ``type_code`` and the set's
    ``file_type``:

    * ``yolo`` + ``img``/``video``: create a detect-log via
      ``service.before_detect``, run ``service.run_detect_folder`` in a new
      thread, then call ``service.update_sql``; ``file_count`` is returned as data.
    * ``yolo`` + ``rtsp``: unless room ``detect_rtsp_<id>`` already exists in
      ``room_manager.rooms``, run ``service.run_detect_rtsp`` in a new thread
      with ``best_pt`` or ``last_pt`` according to ``detect_log_in.pt_type``.
    * ``deepsort``: unless room ``deep_sort_<id>`` already exists, run
      ``deep_sort_service.run_deepsort`` in a new thread with ``best_pt`` and a
      ``{"0": label_name, ...}`` mapping of the project's labels.

    Any other combination (or a missing project) returns an ErrorResponse
    instead of falling through to an implicit ``None`` response.
    """
    detect_dal = crud.ProjectDetectDal(auth.db)
    train_dal = ProjectTrainDal(auth.db)
    detect = await detect_dal.get_data(detect_log_in.detect_id)
    if detect is None:
        return ErrorResponse(msg="训练集合不存在")
    train = await train_dal.get_data(detect_log_in.train_id)
    if train is None:
        return ErrorResponse("训练权重不存在")
    file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
    if file_count == 0 and detect.rtsp_url is None:
        return ErrorResponse("推理集合中没有文件,请先到推理集合中上传推理内容")

    # Decide whether this is a plain detection project or a tracking project.
    project_info = await ProjectInfoDal(auth.db).get_data(data_id=detect.project_id)
    if project_info is None:
        return ErrorResponse(msg="推理集合所属项目不存在")

    if project_info.type_code == 'yolo':
        if detect.file_type in ('img', 'video'):
            detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
            detect_thread = threading.Thread(
                target=service.run_detect_folder,
                args=(
                    detect_log.pt_url,
                    detect_log.folder_url,
                    detect_log.detect_folder_url,
                    detect_log.detect_version,
                ),
            )
            detect_thread.start()
            await service.update_sql(
                auth.db,
                detect_log.detect_id,
                detect_log.id,
                detect_log.detect_folder_url,
                detect_log.detect_version,
            )
            return SuccessResponse(msg="执行成功", data=file_count)
        if detect.file_type == 'rtsp':
            room = 'detect_rtsp_' + str(detect.id)
            # Presumably the stream worker registers this room while running;
            # if it already exists, no second worker is started.
            if not room_manager.rooms.get(room):
                weights_pt = train.best_pt if detect_log_in.pt_type == 'best' else train.last_pt
                threading.Thread(
                    target=service.run_detect_rtsp,
                    args=(weights_pt, detect.rtsp_url, room),
                ).start()
            return SuccessResponse(msg="执行成功")
        return ErrorResponse(msg="不支持的推理集合文件类型")

    if project_info.type_code == 'deepsort':
        room = 'deep_sort_' + str(detect.id)
        if not room_manager.rooms.get(room):
            # Returns two parallel lists (label ids, label names); only names are used.
            _label_ids, label_names = await ProjectLabelDal(auth.db).get_label_for_train(project_info.id)
            idx_to_class = {str(i): name for i, name in enumerate(label_names)}
            # NOTE(review): always uses best_pt and ignores detect_log_in.pt_type,
            # unlike the yolo rtsp branch — confirm this is intended.
            threading.Thread(
                target=deep_sort_service.run_deepsort,
                args=(detect.id, train.best_pt, idx_to_class, detect.rtsp_url),
            ).start()
        return SuccessResponse(msg="执行成功")

    return ErrorResponse(msg="不支持的项目类型")


###########################################################
#    Detection logs
###########################################################
@app.get("/logs", summary="获取推理记录列表")
async def logs(
        p: params.ProjectDetectLogParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return detection-log records matching ``p`` together with the total count."""
    datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(datas, count=count)
@app.get("/logs/download/{log_id}", summary="下载推理结果")
async def logs_download(
        log_id: int,
        background_tasks: BackgroundTasks,
        auth: Auth = Depends(AllUserAuth())
):
    """Zip the result folder of detection log ``log_id`` and return it as a file.

    The folder is ``<detect_folder_url>/<detect_version>``. The zip produced by
    ``osu.zip_folder`` is handed to ``cleanup_temp_dir`` as a background task so
    it is deleted after the download.
    """
    from fastapi.concurrency import run_in_threadpool

    detect_log = await crud.ProjectDetectLogDal(auth.db).get_data(data_id=log_id)
    if detect_log is None:
        return ErrorResponse(msg="推理结果查询错误,请刷新页面再试")

    # The result folder is only present once detection has written its output.
    folder_path = os.path.join(detect_log.detect_folder_url, detect_log.detect_version)
    if not os.path.exists(folder_path):
        return ErrorResponse(msg="推理结果尚未生成,请稍后再下载")

    zip_filename = f"{detect_log.detect_name}_{detect_log.detect_version}.zip"
    # Zipping is blocking disk I/O; run it in the threadpool so the event loop
    # (and other requests/websockets) are not stalled meanwhile.
    zip_file_path = await run_in_threadpool(osu.zip_folder, folder_path, zip_filename)

    # FastAPI runs background tasks after the response has been sent.
    background_tasks.add_task(cleanup_temp_dir, zip_file_path)
    return FileResponse(path=zip_file_path, filename=os.path.basename(zip_file_path))


def cleanup_temp_dir(temp_dir: str):
    """Background task: best-effort removal of a temporary file or directory.

    Despite the name, the caller above passes the path of a zip FILE.
    ``shutil.rmtree`` cannot delete a regular file (and with ``ignore_errors=True``
    it failed silently, leaking every zip), so files are removed with
    ``os.remove`` and real directories with ``shutil.rmtree``. Failures are
    printed, never raised.

    :param temp_dir: path of the file or directory to delete; a missing path is a no-op.
    """
    try:
        if os.path.isdir(temp_dir) and not os.path.islink(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
        elif os.path.lexists(temp_dir):
            os.remove(temp_dir)
    except OSError as e:
        print(f"清理临时目录失败: {e}")


@app.get("/logs/files", summary="获取项目推理记录文件列表")
async def logs_files(
        p: params.ProjectDetectLogFileParams = Depends(),
        auth: Auth = Depends(AllUserAuth())):
    """Return all detection-log file records matching ``p`` (no total count)."""
    datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
    return SuccessResponse(datas)