提高版本到yolov11

This commit is contained in:
2025-06-09 15:27:45 +08:00
parent 86c6669593
commit 9e99b08d13
223 changed files with 333 additions and 43191 deletions

View File

@ -8,7 +8,6 @@
from utils import os_utils as osu
from core.dependencies import IdList
from core.database import redis_getter
from utils.websocket_server import room_manager
from . import schemas, crud, params, service, models
from apps.business.train.crud import ProjectTrainDal
@ -20,7 +19,6 @@ from utils.response import SuccessResponse, ErrorResponse
import os
import shutil
import threading
from redis.asyncio import Redis
from fastapi.responses import FileResponse
from fastapi import Depends, APIRouter, Form, UploadFile, BackgroundTasks
@ -106,8 +104,7 @@ async def delete_file(
@app.post("/start", summary="开始推理")
async def run_detect_yolo(
detect_log_in: schemas.ProjectDetectLogIn,
auth: Auth = Depends(AllUserAuth()),
rd: Redis = Depends(redis_getter)):
auth: Auth = Depends(AllUserAuth())):
detect_dal = crud.ProjectDetectDal(auth.db)
train_dal = ProjectTrainDal(auth.db)
detect = await detect_dal.get_data(detect_log_in.detect_id)
@ -119,21 +116,29 @@ async def run_detect_yolo(
file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
if file_count == 0 and detect.rtsp_url is None:
return ErrorResponse("推理集合中没有文件,请先到推理集合中上传推理内容")
is_gpu = await rd.get('is_gpu')
# 判断一下是单纯的推理项目还是跟踪项目
project_info = await ProjectInfoDal(auth.db).get_data(data_id=detect.project_id)
if project_info.type_code == 'yolo':
if detect.file_type == 'img' or detect.file_type == 'video':
detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
thread_train = threading.Thread(target=service.run_img_loop,
args=(detect_log.pt_url, detect_log.folder_url,
detect_log.detect_folder_url, detect_log.detect_version,
detect_log.detect_id, is_gpu))
thread_train = threading.Thread(
target=service.run_detect_folder,
args=(
detect_log.pt_url,
detect_log.folder_url,
detect_log.detect_folder_url,
detect_log.detect_version
)
)
thread_train.start()
await service.update_sql(
auth.db, detect_log.detect_id,
detect_log.id, detect_log.detect_folder_url,
detect_log.detect_version)
auth.db,
detect_log.detect_id,
detect_log.id,
detect_log.detect_folder_url,
detect_log.detect_version
)
return SuccessResponse(msg="执行成功", data=file_count)
elif detect.file_type == 'rtsp':
room = 'detect_rtsp_' + str(detect.id)
if not room_manager.rooms.get(room):
@ -141,9 +146,18 @@ async def run_detect_yolo(
weights_pt = train.best_pt
else:
weights_pt = train.last_pt
thread_train = threading.Thread(target=service.run_rtsp_loop,
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
thread_train = threading.Thread(
target=service.run_rtsp_loop,
args=(
weights_pt,
detect.rtsp_url,
train.train_data,
detect.id,
None
)
)
thread_train.start()
return SuccessResponse(msg="执行成功")
elif project_info.type_code == 'deepsort':
room = 'deep_sort_' + str(detect.id)
if not room_manager.rooms.get(room):
@ -166,7 +180,7 @@ async def run_detect_yolo(
args=(detect.id, train.best_pt, train.train_data,
idx_to_class, 'video', detect.folder_url, None))
threading_main.start()
return SuccessResponse(msg="执行成功")
return SuccessResponse(msg="执行成功")
###########################################################