Files
aicheckv2-api/apps/business/deepsort/service.py
2025-06-09 15:27:45 +08:00

132 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from utils.websocket_server import room_manager
import time
import torch
from deep_sort.deep_sort import DeepSort
from deep_sort.utils.draw import draw_boxes
def yolov5_to_deepsort_format(pred):
    """
    Convert YOLOv5 detections into the inputs expected by Deep SORT.

    The previous body was only commented-out code, so the function returned
    ``None`` even though callers unpack three values from it.

    :param pred: YOLOv5 detections after NMS, one row per box laid out as
        ``[x1, y1, x2, y2, conf, cls]``. Accepted as a ``torch.Tensor`` on
        any device, or as anything ``torch.as_tensor`` accepts. ``pred``
        itself is not modified.
    :return: ``(bbox_xywh, confs, class_ids)`` as NumPy arrays.
        ``bbox_xywh`` has shape ``(N, 4)`` holding centre-x, centre-y,
        width and height. ``confs`` and ``class_ids`` have shape ``(N,)``.
    """
    # detach/cpu so .numpy() works for GPU tensors and tensors requiring grad.
    det = torch.as_tensor(pred).detach().cpu()
    x1, y1, x2, y2 = det[:, 0], det[:, 1], det[:, 2], det[:, 3]
    # Build a new tensor instead of writing back into pred[:, :4]; the old
    # commented-out version mutated the caller's detections in place.
    xywh = torch.stack(((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1), dim=1)
    return xywh.numpy(), det[:, 4].numpy(), det[:, 5].numpy()
async def run_deepsort(
    detect_id: int,
    weights_pt: str,
    data: str,
    idx_to_class: dict,
    sort_type: str = 'video',
    video_path: str = None,
    rtsp_url: str = None
):
    """
    Detect targets with YOLOv5, then track them with Deep SORT.

    NOTE(review): the whole implementation below is commented out, so this
    coroutine is currently a no-op that returns ``None`` as soon as it is
    awaited. The commented code is kept as the reference implementation.
    Before it can be re-enabled, these names must be brought into scope:
    ``cv2``, ``select_device``, ``DetectMultiBackend``, ``check_img_size``,
    ``LoadImages``, ``LoadStreams``, ``non_max_suppression`` and
    ``scale_coords``. None of them is imported by this module.

    Parameters as used by the reference implementation:

    :param detect_id: suffix of the websocket rooms ``deep_sort_<id>``
        (JPEG frames) and ``deep_sort_count_<id>`` (per-class counts).
    :param weights_pt: YOLOv5 weights path passed to ``DetectMultiBackend``.
    :param data: dataset config passed to ``DetectMultiBackend``.
    :param idx_to_class: maps class ids, as strings, to class names.
    :param sort_type: ``'video'`` reads ``video_path`` with ``LoadImages``;
        any other value reads ``rtsp_url`` with ``LoadStreams``.
    :param video_path: video file used when ``sort_type == 'video'``.
    :param rtsp_url: stream URL used for any other ``sort_type``.
    """
    # room = 'deep_sort_' + str(detect_id)
    #
    # room_count = 'deep_sort_count_' + str(detect_id)
    #
    # # Select the device (CPU or GPU)
    # device = select_device('cuda:0')
    #
    # model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
    #
    # deepsort = DeepSort(
    #     model_path="deep_sort/deep/checkpoint/ckpt.t7",  # ReID model path
    #     max_dist=0.2,  # appearance-feature matching threshold (smaller = stricter)
    #     max_iou_distance=0.7,  # maximum IoU distance threshold
    #     max_age=70,  # max frames a track survives without being matched
    #     n_init=3  # a track is confirmed after n_init consecutive matches
    # )
    # stride, names, pt = model.stride, model.names, model.pt
    # img_sz = check_img_size((640, 640), s=stride)  # check image size
    # if sort_type == 'video':
    #     dataset = LoadImages(video_path, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
    # else:
    #     dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
    # bs = len(dataset)
    #
    # count_result = {}
    # for key in idx_to_class.keys():
    #     count_result[key] = set()
    #
    # model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
    #
    # # NOTE(review): time.sleep blocks the event loop inside a coroutine;
    # # use `await asyncio.sleep(3)` when re-enabling.
    # time.sleep(3)  # wait 3 s for the websocket client to join
    #
    # start_time = time.time()
    #
    # for path, im, im0s, vid_cap, s in dataset:
    #     # Stop once more than 10 minutes (600 s) have elapsed
    #     elapsed_time = time.time() - start_time
    #     if elapsed_time > 600:  # 600 seconds = 10 minutes
    #         print(room, "已达到最大执行时间,结束推理。")
    #         break
    #     if room_manager.rooms.get(room):
    #         im0 = im0s[0]
    #         im = torch.from_numpy(im).to(model.device)
    #         im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
    #         im /= 255  # 0 - 255 to 0.0 - 1.0
    #         if len(im.shape) == 3:
    #             im = im[None]  # expand for batch dim
    #         pred = model(im, augment=False, visualize=False)
    #         # NMS
    #         pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)[0]
    #
    #         pred[:, :4] = scale_coords(im.shape[2:], pred[:, :4], im0.shape).round()
    #
    #         # pred holds the YOLOv5 detections; convert them for Deep SORT
    #         bbox_xywh, cls_conf, cls_ids = yolov5_to_deepsort_format(pred)
    #
    #         # NOTE(review): only class id 0 reaches the tracker, although a
    #         # counter set is created for every key of idx_to_class.
    #         mask = cls_ids == 0
    #
    #         bbox_xywh = bbox_xywh[mask]
    #         bbox_xywh[:, 2:] *= 1.2
    #         cls_conf = cls_conf[mask]
    #         cls_ids = cls_ids[mask]
    #
    #         # Update Deep SORT with this frame's detections
    #         outputs, _ = deepsort.update(bbox_xywh, cls_conf, cls_ids, im0)
    #
    #         if len(outputs) > 0:
    #             bbox_xyxy = outputs[:, :4]
    #             identities = outputs[:, -1]
    #             cls = outputs[:, -2]
    #             names = [idx_to_class[str(label)] for label in cls]
    #             # Draw the boxes
    #             ori_img = draw_boxes(im0, bbox_xyxy, names, identities, None)
    #
    #             # NOTE(review): keep the counting/streaming steps inside this
    #             # branch; ori_img is only assigned here.
    #             # Collect every confirmed track
    #             active_tracks = [
    #                 track for track in deepsort.tracker.tracks
    #                 if track.is_confirmed()
    #             ]
    #
    #             for tark in active_tracks:
    #                 class_id = str(tark.cls)
    #                 # NOTE(review): KeyError if class_id is not a key of idx_to_class.
    #                 count_result[class_id].add(tark.track_id)
    #             # Count unique track ids per label
    #             result = {}
    #             for key in count_result.keys():
    #                 result[idx_to_class[key]] = len(count_result[key])
    #             # Encode the frame as JPEG
    #             ret, jpeg = cv2.imencode('.jpg', ori_img)
    #             if ret:
    #                 jpeg_bytes = jpeg.tobytes()
    #                 await room_manager.send_stream_to_room(room, jpeg_bytes)
    #                 # NOTE(review): str(dict) sends a Python repr, not JSON;
    #                 # consider json.dumps(result, ensure_ascii=False).
    #                 await room_manager.send_to_room(room_count, str(result))
    #     else:
    #         print(room, '结束追踪')
    #         break