83 lines
3.1 KiB
Python
83 lines
3.1 KiB
Python
import cv2
|
||
import time
|
||
import asyncio
|
||
from ultralytics import YOLO
|
||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||
|
||
from utils.websocket_server import room_manager
|
||
|
||
|
||
class DeepSortModel:
    """Pair a YOLO detector with a DeepSORT tracker and stream the results.

    Annotated frames and per-class counts of tracked objects are pushed to
    websocket rooms through ``room_manager``.
    """

    # Hard upper bound on how long a single ``sort_video`` call may run.
    MAX_STREAM_SECONDS = 600  # 10 minutes
    # Minimum detection confidence passed to YOLO.
    DETECTION_CONFIDENCE = 0.6

    def __init__(self, pt_url):
        """Load YOLO weights from ``pt_url`` and create a DeepSORT tracker."""
        self.model = YOLO(pt_url)
        self.tracker = DeepSort()

    def sort_video(self, url, detect_id: int, idx_to_class: dict):
        """Track objects in a video file or RTSP stream and broadcast results.

        Each frame is run through YOLO, the detections are fed to DeepSORT,
        confirmed tracks are drawn on the frame, and the JPEG-encoded frame
        plus the running per-class count of distinct track ids are sent to
        the rooms ``deep_sort_<detect_id>`` and ``deep_sort_count_<detect_id>``.

        Processing stops when the source ends, a frame cannot be read, or
        ``MAX_STREAM_SECONDS`` have elapsed. The capture is always released.

        Args:
            url: Anything accepted by ``cv2.VideoCapture`` (file path or
                stream URL).
            detect_id: Identifier used to build the websocket room names.
            idx_to_class: Mapping from class id to display label. Keys are
                looked up with ``str(track.det_class)``, so they are expected
                to be strings such as ``"0"`` (assumption -- confirm with the
                caller).
        """
        room_name = 'deep_sort_' + str(detect_id)
        room_count = 'deep_sort_count_' + str(detect_id)

        # Distinct confirmed track ids seen so far, per known class id.
        count_result = {key: set() for key in idx_to_class}

        cap = cv2.VideoCapture(url)
        start_time = time.time()
        try:
            while cap.isOpened():
                # Stop once the time budget for one stream is exhausted.
                if time.time() - start_time > self.MAX_STREAM_SECONDS:
                    break

                ret, frame = cap.read()
                if not ret:
                    break

                # YOLO inference (GPU device 0).
                results = self.model(
                    frame,
                    device=0,
                    conf=self.DETECTION_CONFIDENCE,
                    verbose=False,
                )
                detections = results[0].boxes.data.cpu().numpy()

                # DeepSORT input format: ([left, top, w, h], confidence, class_id)
                tracker_inputs = [
                    ([x1, y1, x2 - x1, y2 - y1], conf, int(cls))
                    for x1, y1, x2, y2, conf, cls in detections
                ]

                tracks = self.tracker.update_tracks(tracker_inputs, frame=frame)

                for track in tracks:
                    # Only confirmed tracks are drawn and counted; tentative
                    # tracks may still be discarded by the tracker.
                    if not track.is_confirmed():
                        continue

                    track_id = track.track_id
                    track_cls = str(track.det_class)

                    # Count only classes the caller asked about; an unknown
                    # class must not abort the whole stream with a KeyError.
                    if track_cls in count_result:
                        count_result[track_cls].add(track_id)

                    x1, y1, x2, y2 = map(int, track.to_ltrb())
                    label = idx_to_class.get(track_cls, track_cls)

                    # Draw the bounding box and the "<label> <id>" tag.
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(
                        frame,
                        f"{label} {track_id}",
                        (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5,
                        (0, 255, 0),
                        2,
                    )

                # Running number of distinct tracked objects per label.
                result = {
                    idx_to_class[key]: len(track_ids)
                    for key, track_ids in count_result.items()
                }

                # Encode the annotated frame as JPEG and broadcast it
                # together with the current counts.
                ok, jpeg = cv2.imencode('.jpg', frame)
                if ok:
                    jpeg_bytes = jpeg.tobytes()
                    asyncio.run(room_manager.send_stream_to_room(room_name, jpeg_bytes))
                    asyncio.run(room_manager.send_to_room(room_count, str(result)))
        finally:
            # Release the capture on every exit path (timeout, end of
            # stream, or an exception during inference/broadcast).
            cap.release()
|
||
|