deep_sort 提升版本
This commit is contained in:
@ -1 +1,2 @@
|
||||
from .yolov11 import YoloModel
|
||||
from .yolov11 import YoloModel
|
||||
from .deep_sort import DeepSortModel
|
82
algo/deep_sort.py
Normal file
82
algo/deep_sort.py
Normal file
@ -0,0 +1,82 @@
|
||||
import cv2
|
||||
import time
|
||||
import asyncio
|
||||
from ultralytics import YOLO
|
||||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||||
|
||||
from utils.websocket_server import room_manager
|
||||
|
||||
|
||||
class DeepSortModel:
|
||||
|
||||
def __init__(self, pt_url):
|
||||
self.model = YOLO(pt_url)
|
||||
self.tracker = DeepSort()
|
||||
|
||||
def sort_video(self, url, detect_id: int, idx_to_class: {}):
|
||||
"""
|
||||
对文件夹中的视频或rtsp的视频流,进行目标追踪
|
||||
"""
|
||||
room_name = 'deep_sort_' + str(detect_id)
|
||||
room_count = 'deep_sort_count_' + str(detect_id)
|
||||
count_result = {}
|
||||
for key in idx_to_class.keys():
|
||||
count_result[key] = set()
|
||||
cap = cv2.VideoCapture(url)
|
||||
start_time = time.time()
|
||||
while cap.isOpened():
|
||||
# 检查是否已经超过10分钟(600秒)
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time > 600: # 600 seconds = 10 minutes
|
||||
break
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
# YOLO 推理(GPU 加速)
|
||||
results = self.model(frame, device=0, conf=0.6, verbose=False)
|
||||
# 获取检测框数据
|
||||
detections = results[0].boxes.data.cpu().numpy()
|
||||
# DeepSORT 格式转换:[(bbox_x, y, w, h), confidence, class_id]
|
||||
tracker_inputs = []
|
||||
for det in detections:
|
||||
x1, y1, x2, y2, conf, cls = det
|
||||
bbox = [x1, y1, x2 - x1, y2 - y1] # (x, y, w, h)
|
||||
tracker_inputs.append((bbox, conf, int(cls)))
|
||||
# 更新跟踪器
|
||||
tracks = self.tracker.update_tracks(tracker_inputs, frame=frame)
|
||||
# 获取所有被确认过的追踪目标
|
||||
active_tracks = []
|
||||
# 绘制跟踪结果
|
||||
for track in tracks:
|
||||
if not track.is_confirmed():
|
||||
active_tracks.append(track)
|
||||
continue
|
||||
track_id = track.track_id
|
||||
track_cls = str(track.det_class)
|
||||
ltrb = track.to_ltrb()
|
||||
x1, y1, x2, y2 = map(int, ltrb)
|
||||
# 绘制矩形框和ID标签
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
cv2.putText(
|
||||
frame,
|
||||
f"{idx_to_class[track_cls]} {track_id}",
|
||||
(x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 255, 0),
|
||||
2,
|
||||
)
|
||||
for tark in active_tracks:
|
||||
class_id = str(tark.det_class)
|
||||
count_result[class_id].add(tark.track_id)
|
||||
# 对应每个label进行计数
|
||||
result = {}
|
||||
for key in count_result.keys():
|
||||
result[idx_to_class[key]] = len(count_result[key])
|
||||
# 将帧编码为 JPEG
|
||||
ret, jpeg = cv2.imencode('.jpg', frame)
|
||||
if ret:
|
||||
jpeg_bytes = jpeg.tobytes()
|
||||
asyncio.run(room_manager.send_stream_to_room(room_name, jpeg_bytes))
|
||||
asyncio.run(room_manager.send_to_room(room_count, str(result)))
|
||||
|
@ -1,3 +1,4 @@
|
||||
import cv2
|
||||
import time
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
@ -97,11 +98,10 @@ class YoloModel:
|
||||
conf=0.6
|
||||
)
|
||||
|
||||
def predict_rtsp(self, source, detect_id):
|
||||
def predict_rtsp(self, source, room_name):
|
||||
"""
|
||||
对rtsp视频流进行预测
|
||||
"""
|
||||
room_name = 'detect_rtsp_' + str(detect_id)
|
||||
start_time = time.time()
|
||||
for result in self.model.predict(source=source, stream=True, device='0', conf=0.6):
|
||||
# 检查是否已经超过10分钟(600秒)
|
||||
@ -109,5 +109,7 @@ class YoloModel:
|
||||
if elapsed_time > 600: # 600 seconds = 10 minutes
|
||||
break
|
||||
frame = result.plot()
|
||||
frame_data = frame.tobytes()
|
||||
asyncio.run(room_manager.send_stream_to_room(room_name, frame_data))
|
||||
ret, jpeg = cv2.imencode('.jpg', frame)
|
||||
if ret:
|
||||
frame_data = jpeg.tobytes()
|
||||
asyncio.run(room_manager.send_stream_to_room(room_name, frame_data))
|
Reference in New Issue
Block a user