deep_sort 提升版本

This commit is contained in:
2025-06-11 16:19:38 +08:00
parent 9e99b08d13
commit ca245e4cec
47 changed files with 114 additions and 3644 deletions

View File

@ -1,131 +1,15 @@
from utils.websocket_server import room_manager
import time
import torch
from deep_sort.deep_sort import DeepSort
from deep_sort.utils.draw import draw_boxes
from algo import DeepSortModel
def yolov5_to_deepsort_format(pred):
    """
    Convert YOLOv5 predictions into the format Deep SORT expects.

    NOTE(review): the conversion code below is commented out, so as written
    this function has no executable body besides this docstring and
    returns None.

    :param pred: YOLOv5 prediction results
    :return: intended to be the converted bbox_xywh, confs, class_ids
    """
    # pred[:, :4] = xyxy2xywh(pred[:, :4])
    # xywh = pred[:, :4].cpu().numpy()
    # conf = pred[:, 4].cpu().numpy()
    # cls = pred[:, 5].cpu().numpy()
    # return xywh, conf, cls
async def run_deepsort(
def run_deepsort(
detect_id: int,
weights_pt: str,
data: str,
idx_to_class: {},
sort_type: str = 'video',
video_path: str = None,
rtsp_url: str = None
url
):
"""
deep_sort追踪先经过yolov5对目标进行识别,
deep_sort追踪先经过yolo对目标进行识别
再调用deepsort对目标进行追踪
"""
# room = 'deep_sort_' + str(detect_id)
#
# room_count = 'deep_sort_count_' + str(detect_id)
#
# # 选择设备CPU 或 GPU
# device = select_device('cuda:0')
#
# model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
#
# deepsort = DeepSort(
# model_path="deep_sort/deep/checkpoint/ckpt.t7", # ReID 模型路径
# max_dist=0.2, # 外观特征匹配阈值(越小越严格)
# max_iou_distance=0.7, # 最大IoU距离阈值
# max_age=70, # 目标最大存活帧数(未匹配时保留的帧数)
# n_init=3 # 初始确认帧数连续匹配到n_init次后确认跟踪
# )
# stride, names, pt = model.stride, model.names, model.pt
# img_sz = check_img_size((640, 640), s=stride) # check image size
# if sort_type == 'video':
# dataset = LoadImages(video_path, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
# else:
# dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
# bs = len(dataset)
#
# count_result = {}
# for key in idx_to_class.keys():
# count_result[key] = set()
#
# model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
#
# time.sleep(3) # 等待3s等待websocket进入
#
# start_time = time.time()
#
# for path, im, im0s, vid_cap, s in dataset:
# # 检查是否已经超过10分钟600秒
# elapsed_time = time.time() - start_time
# if elapsed_time > 600: # 600 seconds = 10 minutes
# print(room, "已达到最大执行时间,结束推理。")
# break
# if room_manager.rooms.get(room):
# im0 = im0s[0]
# im = torch.from_numpy(im).to(model.device)
# im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
# im /= 255 # 0 - 255 to 0.0 - 1.0
# if len(im.shape) == 3:
# im = im[None] # expand for batch dim
# pred = model(im, augment=False, visualize=False)
# # NMS
# pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)[0]
#
# pred[:, :4] = scale_coords(im.shape[2:], pred[:, :4], im0.shape).round()
#
# # 使用YOLOv5进行检测后得到的pred
# bbox_xywh, cls_conf, cls_ids = yolov5_to_deepsort_format(pred)
#
# mask = cls_ids == 0
#
# bbox_xywh = bbox_xywh[mask]
# bbox_xywh[:, 2:] *= 1.2
# cls_conf = cls_conf[mask]
# cls_ids = cls_ids[mask]
#
# # 调用Deep SORT更新方法
# outputs, _ = deepsort.update(bbox_xywh, cls_conf, cls_ids, im0)
#
# if len(outputs) > 0:
# bbox_xyxy = outputs[:, :4]
# identities = outputs[:, -1]
# cls = outputs[:, -2]
# names = [idx_to_class[str(label)] for label in cls]
# # 开始画框
# ori_img = draw_boxes(im0, bbox_xyxy, names, identities, None)
#
# # 获取所有被确认过的追踪目标
# active_tracks = [
# track for track in deepsort.tracker.tracks
# if track.is_confirmed()
# ]
#
# for tark in active_tracks:
# class_id = str(tark.cls)
# count_result[class_id].add(tark.track_id)
# # 对应每个label进行计数
# result = {}
# for key in count_result.keys():
# result[idx_to_class[key]] = len(count_result[key])
# # 将帧编码为 JPEG
# ret, jpeg = cv2.imencode('.jpg', ori_img)
# if ret:
# jpeg_bytes = jpeg.tobytes()
# await room_manager.send_stream_to_room(room, jpeg_bytes)
# await room_manager.send_to_room(room_count, str(result))
# else:
# print(room, '结束追踪')
# break
deep_sort = DeepSortModel(weights_pt)
deep_sort.sort_video(url=url, detect_id=detect_id, idx_to_class=idx_to_class)

View File

@ -51,21 +51,6 @@ async def before_detect(
return detect_log
def run_img_loop(
        weights: str,
        source: str,
        project: str,
        name: str,
        detect_id: int,
        is_gpu: str):
    """
    Run ``run_detect_folder`` to completion on a freshly created event loop.

    A new asyncio event loop is created and installed as the current loop
    for the calling thread, the result of ``run_detect_folder(...)`` is
    driven to completion on it, and the loop is then closed.

    All parameters are forwarded unchanged, in order, to
    ``run_detect_folder``.

    :param weights: forwarded to ``run_detect_folder``
    :param source: forwarded to ``run_detect_folder``
    :param project: forwarded to ``run_detect_folder``
    :param name: forwarded to ``run_detect_folder``
    :param detect_id: forwarded to ``run_detect_folder``
    :param is_gpu: forwarded to ``run_detect_folder``
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Run the async function until it finishes
        loop.run_until_complete(
            run_detect_folder(weights, source, project, name, detect_id, is_gpu)
        )
    finally:
        # Close the loop even when the call above raises, so the loop
        # is not leaked on failure
        loop.close()
def run_detect_folder(
weights: str,
source: str,
@ -111,121 +96,13 @@ async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, nam
await crud.ProjectDetectLogFileDal(db).create_models(detect_log_files)
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
def run_detect_rtsp(weights_pt: str, rtsp_url: str, room_name: str):
"""
rtsp 视频流推理
:param detect_id: 训练集的id
:param room_name: websocket链接名称
:param weights_pt: 权重文件
:param rtsp_url: 视频流地址
:param data: yaml文件
:param is_gpu: 是否启用加速
:return:
"""
# room = 'detect_rtsp_' + str(detect_id)
# # 选择设备CPU 或 GPU
# device = select_device('cpu')
# # 判断是否存在cuda版本
# if is_gpu == 'True':
# device = select_device('cuda:0')
#
# # 加载模型
# model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
#
# stride, names, pt = model.stride, model.names, model.pt
# img_sz = check_img_size((640, 640), s=stride) # check image size
#
# dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
# bs = len(dataset)
#
# model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
#
# time.sleep(3) # 等待3s等待websocket进入
#
# start_time = time.time()
#
# for path, im, im0s, vid_cap, s in dataset:
# # 检查是否已经超过10分钟600秒
# elapsed_time = time.time() - start_time
# if elapsed_time > 600: # 600 seconds = 10 minutes
# print(room, "已达到最大执行时间,结束推理。")
# break
# if room_manager.rooms.get(room):
# im = torch.from_numpy(im).to(model.device)
# im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
# im /= 255 # 0 - 255 to 0.0 - 1.0
# if len(im.shape) == 3:
# im = im[None] # expand for batch dim
#
# # Inference
# pred = model(im, augment=False, visualize=False)
# # NMS
# pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)
#
# # Process predictions
# for i, det in enumerate(pred): # per image
# p, im0, frame = path[i], im0s[i].copy(), dataset.count
# annotator = Annotator(im0, line_width=3, example=str(names))
# if len(det):
# # Rescale boxes from img_size to im0 size
# det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
#
# # Write results
# for *xyxy, conf, cls in reversed(det):
# c = int(cls) # integer class
# label = None if False else (names[c] if False else f"{names[c]} {conf:.2f}")
# annotator.box_label(xyxy, label, color=colors(c, True))
#
# # Stream results
# im0 = annotator.result()
# # 将帧编码为 JPEG
# ret, jpeg = cv2.imencode('.jpg', im0)
# if ret:
# frame_data = jpeg.tobytes()
# await room_manager.send_stream_to_room(room, frame_data)
# else:
# print(room, '结束推理')
# break
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
    """
    Run ``run_detect_rtsp`` to completion on a freshly created event loop.

    A new asyncio event loop is created and installed as the current loop
    for the calling thread, the result of ``run_detect_rtsp(...)`` is
    driven to completion on it, and the loop is then closed.

    All parameters are forwarded unchanged, in order, to ``run_detect_rtsp``.

    :param weights_pt: forwarded to ``run_detect_rtsp``
    :param rtsp_url: forwarded to ``run_detect_rtsp``
    :param data: forwarded to ``run_detect_rtsp``
    :param detect_id: forwarded to ``run_detect_rtsp``
    :param is_gpu: forwarded to ``run_detect_rtsp``
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Run the async function until it finishes
        loop.run_until_complete(
            run_detect_rtsp(
                weights_pt,
                rtsp_url,
                data,
                detect_id,
                is_gpu
            )
        )
    finally:
        # Close the loop even when the call above raises, so the loop
        # is not leaked on failure
        loop.close()
def run_deepsort_loop(
        detect_id: int,
        weights_pt: str,
        data: str,
        idx_to_class: dict,
        sort_type: str = 'video',
        video_path: str = None,
        rtsp_url: str = None
):
    """
    Run ``deepsort_service.run_deepsort`` to completion on a fresh event loop.

    A new asyncio event loop is created and installed as the current loop
    for the calling thread, the result of
    ``deepsort_service.run_deepsort(...)`` is driven to completion on it,
    and the loop is then closed.

    All parameters are forwarded unchanged, in order, to
    ``deepsort_service.run_deepsort``.

    :param detect_id: forwarded to ``run_deepsort``
    :param weights_pt: forwarded to ``run_deepsort``
    :param data: forwarded to ``run_deepsort``
    :param idx_to_class: forwarded to ``run_deepsort``
    :param sort_type: forwarded to ``run_deepsort``; defaults to ``'video'``
    :param video_path: forwarded to ``run_deepsort``; defaults to ``None``
    :param rtsp_url: forwarded to ``run_deepsort``; defaults to ``None``
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Run the async function until it finishes
        loop.run_until_complete(
            deepsort_service.run_deepsort(
                detect_id,
                weights_pt,
                data,
                idx_to_class,
                sort_type,
                video_path,
                rtsp_url
            )
        )
    finally:
        # Close the loop even when the call above raises, so the loop
        # is not leaked on failure
        loop.close()
model = YoloModel(weights_pt)
model.predict_rtsp(rtsp_url, room_name)

View File

@ -11,10 +11,11 @@ from core.dependencies import IdList
from utils.websocket_server import room_manager
from . import schemas, crud, params, service, models
from apps.business.train.crud import ProjectTrainDal
from apps.business.project.crud import ProjectInfoDal, ProjectLabelDal
from apps.vadmin.auth.utils.current import AllUserAuth
from apps.vadmin.auth.utils.validation.auth import Auth
from utils.response import SuccessResponse, ErrorResponse
from apps.business.deepsort import service as deep_sort_service
from apps.business.project.crud import ProjectInfoDal, ProjectLabelDal
import os
import shutil
@ -147,13 +148,11 @@ async def run_detect_yolo(
else:
weights_pt = train.last_pt
thread_train = threading.Thread(
target=service.run_rtsp_loop,
target=service.run_detect_rtsp,
args=(
weights_pt,
detect.rtsp_url,
train.train_data,
detect.id,
None
room
)
)
thread_train.start()
@ -164,22 +163,16 @@ async def run_detect_yolo(
# 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = await ProjectLabelDal(auth.db).get_label_for_train(project_info.id)
idx_to_class = {str(i): name for i, name in enumerate(label_name_list)}
# if detect_log_in.pt_type == 'best':
# weights_pt = train.best_pt
# else:
# weights_pt = train.last_pt
if detect.file_type == 'rtsp':
threading_main = threading.Thread(
target=service.run_deepsort_loop,
args=(detect.id, train.best_pt, train.train_data,
idx_to_class, 'rtsp', None, detect.rtsp_url))
threading_main.start()
elif detect.file_type == 'video':
threading_main = threading.Thread(
target=service.run_deepsort_loop,
args=(detect.id, train.best_pt, train.train_data,
idx_to_class, 'video', detect.folder_url, None))
threading_main.start()
threading_main = threading.Thread(
target=deep_sort_service.run_deepsort,
args=(
detect.id,
train.best_pt,
idx_to_class,
detect.rtsp_url
)
)
threading_main.start()
return SuccessResponse(msg="执行成功")