完成目标追踪的开发和测试

This commit is contained in:
2025-04-28 09:46:47 +08:00
parent 5b38e91f61
commit da1f5a874e
5 changed files with 410 additions and 49 deletions

View File

@ -1,19 +1,19 @@
from application.settings import yolo_url, detect_url
from utils.websocket_server import room_manager
from utils import os_utils as os
from . import models, crud, schemas
from utils.websocket_server import room_manager
from application.settings import yolo_url, detect_url
from apps.business.train import models as train_models
from utils.yolov5.models.common import DetectMultiBackend
from utils.yolov5.utils.torch_utils import select_device
from utils.yolov5.utils.dataloaders import LoadStreams
from utils.yolov5.utils.general import check_img_size, Profile, non_max_suppression, cv2, scale_boxes
from utils.yolov5.utils.torch_utils import select_device
from ultralytics.utils.plotting import Annotator, colors
from utils.yolov5.models.common import DetectMultiBackend
from apps.business.deepsort import service as deepsort_service
from utils.yolov5.utils.general import check_img_size, Profile, non_max_suppression, cv2, scale_boxes
import time
import torch
import asyncio
import subprocess
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
@ -169,16 +169,16 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
stride, names, pt = model.stride, model.names, model.pt
imgsz = check_img_size((640, 640), s=stride) # check image size
img_sz = check_img_size((640, 640), s=stride) # check image size
dataset = LoadStreams(rtsp_url, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
bs = len(dataset)
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
dt = (Profile(device=device), Profile(device=device), Profile(device=device))
time.sleep(3) # 等待3s等待websocket进入
time.sleep(3) # 等待2s等待websocket进入
for path, im, im0s, vid_cap, s in dataset:
if room_manager.rooms.get(room):
@ -188,25 +188,13 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
if model.xml and im.shape[0] > 1:
ims = torch.chunk(im, im.shape[0], 0)
# Inference
with dt[1]:
if model.xml and im.shape[0] > 1:
pred = None
for image in ims:
if pred is None:
pred = model(image, augment=False, visualize=False).unsqueeze(0)
else:
pred = torch.cat((pred, model(image, augment=False, visualize=False).unsqueeze(0)),
dim=0)
pred = [pred, None]
else:
pred = model(im, augment=False, visualize=False)
pred = model(im, augment=False, visualize=False)
# NMS
with dt[2]:
pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)
pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)
# Process predictions
for i, det in enumerate(pred): # per image
@ -238,6 +226,41 @@ def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, is_gpu))
loop.run_until_complete(
run_detect_rtsp(
weights_pt,
rtsp_url,
data,
detect_id,
is_gpu
)
)
# 可选: 关闭循环
loop.close()
def run_deepsort_loop(
detect_id: int,
weights_pt: str,
data: str,
idx_to_class: {},
sort_type: str = 'video',
video_path: str = None,
rtsp_url: str = None
):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 运行异步函数
loop.run_until_complete(
deepsort_service.run_deepsort(
detect_id,
weights_pt,
data,
idx_to_class,
sort_type,
video_path,
rtsp_url
)
)
# 可选: 关闭循环
loop.close()

View File

@ -9,17 +9,16 @@
from utils import os_utils as osu
from core.dependencies import IdList
from core.database import redis_getter
from . import schemas, crud, params, service, models
from utils.websocket_server import room_manager
from . import schemas, crud, params, service, models
from apps.business.train.crud import ProjectTrainDal
from apps.business.project.crud import ProjectInfoDal, ProjectLabelDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
import os
import shutil
import zipfile
import tempfile
import threading
from redis.asyncio import Redis
from fastapi.responses import FileResponse
@ -121,27 +120,52 @@ async def run_detect_yolo(
if file_count == 0 and detect.rtsp_url is None:
return ErrorResponse("推理集合中没有内容,请先到推理集合中上传图片")
is_gpu = await rd.get('is_gpu')
if detect.file_type == 'img' or detect.file_type == 'video':
detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
thread_train = threading.Thread(target=service.run_img_loop,
args=(detect_log.pt_url, detect_log.folder_url,
detect_log.detect_folder_url, detect_log.detect_version,
detect_log.detect_id, is_gpu))
thread_train.start()
await service.update_sql(
auth.db, detect_log.detect_id,
detect_log.id, detect_log.detect_folder_url,
detect_log.detect_version)
elif detect.file_type == 'rtsp':
room = 'detect_rtsp_' + str(detect.id)
if not room_manager.rooms.get(room):
if detect_log_in.pt_type == 'best':
weights_pt = train.best_pt
else:
weights_pt = train.last_pt
thread_train = threading.Thread(target=service.run_rtsp_loop,
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
# 判断一下是单纯的推理项目还是跟踪项目
project_info = await ProjectInfoDal(auth.db).get_data(data_id=detect.project_id)
if project_info.type_code == 'yolo':
if detect.file_type == 'img' or detect.file_type == 'video':
detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
thread_train = threading.Thread(target=service.run_img_loop,
args=(detect_log.pt_url, detect_log.folder_url,
detect_log.detect_folder_url, detect_log.detect_version,
detect_log.detect_id, is_gpu))
thread_train.start()
await service.update_sql(
auth.db, detect_log.detect_id,
detect_log.id, detect_log.detect_folder_url,
detect_log.detect_version)
elif detect.file_type == 'rtsp':
room = 'detect_rtsp_' + str(detect.id)
if not room_manager.rooms.get(room):
if detect_log_in.pt_type == 'best':
weights_pt = train.best_pt
else:
weights_pt = train.last_pt
thread_train = threading.Thread(target=service.run_rtsp_loop,
args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
thread_train.start()
elif project_info.type_code == 'deepsort':
room = 'deep_sort_' + str(detect.id)
if not room_manager.rooms.get(room):
# 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = await ProjectLabelDal(auth.db).get_label_for_train(project_info.id)
idx_to_class = {str(i): name for i, name in enumerate(label_name_list)}
# if detect_log_in.pt_type == 'best':
# weights_pt = train.best_pt
# else:
# weights_pt = train.last_pt
if detect.file_type == 'rtsp':
threading_main = threading.Thread(
target=service.run_deepsort_loop,
args=(detect.id, train.best_pt, train.train_data,
idx_to_class, 'rtsp', None, detect.rtsp_url))
threading_main.start()
elif detect.file_type == 'video':
threading_main = threading.Thread(
target=service.run_deepsort_loop,
args=(detect.id, train.best_pt, train.train_data,
idx_to_class, 'video', detect.folder_url, None))
threading_main.start()
return SuccessResponse(msg="执行成功")