Files
aicheckv2-api/algo/deep_sort.py
2025-06-11 16:19:38 +08:00

83 lines
3.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import time
import asyncio
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSort
from utils.websocket_server import room_manager
class DeepSortModel:
    """Track objects in a video using YOLO detections fed into DeepSORT.

    Annotated frames and per-class counts of tracked objects are pushed to
    websocket rooms through ``room_manager``.
    """

    # Upper bound, in seconds, on how long one ``sort_video`` call may run.
    MAX_DURATION_SECONDS = 600
    # Minimum detection confidence passed to the YOLO model.
    CONFIDENCE_THRESHOLD = 0.6

    def __init__(self, pt_url):
        """Load the YOLO weights at ``pt_url`` and create a DeepSORT tracker."""
        self.model = YOLO(pt_url)
        self.tracker = DeepSort()

    @staticmethod
    def _to_tracker_inputs(detections):
        """Convert ``(x1, y1, x2, y2, conf, cls)`` rows to DeepSORT inputs.

        Each returned item is ``([x, y, w, h], confidence, class_id)`` where
        ``x, y`` is the top-left corner of the box.
        """
        return [
            ([x1, y1, x2 - x1, y2 - y1], conf, int(cls))
            for x1, y1, x2, y2, conf, cls in detections
        ]

    @staticmethod
    def _register_tracks(tracks, count_result):
        """Record confirmed tracks in ``count_result`` and return them.

        ``count_result`` maps ``str(class_id)`` to the set of track ids seen
        for that class; it is updated in place. Unconfirmed (tentative)
        tracks are ignored so that spurious detections are not counted.
        """
        confirmed = []
        for track in tracks:
            if not track.is_confirmed():
                continue
            confirmed.append(track)
            # setdefault tolerates a class id that is missing from the mapping
            # supplied by the caller instead of raising KeyError mid-stream.
            count_result.setdefault(str(track.det_class), set()).add(track.track_id)
        return confirmed

    @staticmethod
    def _summarize(count_result, idx_to_class):
        """Return ``{label: number_of_distinct_track_ids}``.

        A class key without an entry in ``idx_to_class`` is reported under
        the key itself.
        """
        return {
            idx_to_class.get(key, key): len(track_ids)
            for key, track_ids in count_result.items()
        }

    @staticmethod
    async def _publish(room_name, jpeg_bytes, room_count, payload):
        """Send the frame (when one was encoded) and then the count payload."""
        if jpeg_bytes is not None:
            await room_manager.send_stream_to_room(room_name, jpeg_bytes)
        await room_manager.send_to_room(room_count, payload)

    def sort_video(self, url, detect_id: int, idx_to_class: dict):
        """Run object tracking on a video file or RTSP stream.

        For every frame: run YOLO, update the DeepSORT tracker, draw the
        confirmed tracks on the frame, and publish the JPEG-encoded frame to
        room ``deep_sort_<detect_id>`` and the running per-label counts to
        room ``deep_sort_count_<detect_id>``.

        Processing stops when the capture ends, a frame cannot be read, or
        ``MAX_DURATION_SECONDS`` have elapsed.

        Args:
            url: Source handed to ``cv2.VideoCapture``.
            detect_id: Identifier used to build the two room names.
            idx_to_class: Maps ``str(class_id)`` to a display label.
        """
        room_name = 'deep_sort_' + str(detect_id)
        room_count = 'deep_sort_count_' + str(detect_id)
        count_result = {key: set() for key in idx_to_class}
        cap = cv2.VideoCapture(url)
        # monotonic() is immune to wall-clock adjustments during the run.
        start_time = time.monotonic()
        try:
            while cap.isOpened():
                if time.monotonic() - start_time > self.MAX_DURATION_SECONDS:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                # YOLO inference on device 0 (GPU).
                results = self.model(
                    frame,
                    device=0,
                    conf=self.CONFIDENCE_THRESHOLD,
                    verbose=False,
                )
                detections = results[0].boxes.data.cpu().numpy()
                tracker_inputs = self._to_tracker_inputs(detections)
                tracks = self.tracker.update_tracks(tracker_inputs, frame=frame)
                # Only confirmed tracks are counted and drawn.
                for track in self._register_tracks(tracks, count_result):
                    class_key = str(track.det_class)
                    label = idx_to_class.get(class_key, class_key)
                    x1, y1, x2, y2 = map(int, track.to_ltrb())
                    # Bounding box plus "<label> <track id>" caption.
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(
                        frame,
                        f"{label} {track.track_id}",
                        (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5,
                        (0, 255, 0),
                        2,
                    )
                result = self._summarize(count_result, idx_to_class)
                encoded, jpeg = cv2.imencode('.jpg', frame)
                jpeg_bytes = jpeg.tobytes() if encoded else None
                # One event loop per frame instead of one per message.
                asyncio.run(
                    self._publish(room_name, jpeg_bytes, room_count, str(result))
                )
        finally:
            # Always free the capture handle, even when inference or a send fails.
            cap.release()