Complete development and testing of object tracking

sunyugang 2025-04-28 09:46:47 +08:00
parent 5b38e91f61
commit da1f5a874e
5 changed files with 410 additions and 49 deletions


@@ -0,0 +1,133 @@
from utils.websocket_server import room_manager
import time
import torch
from utils.yolov5.models.common import DetectMultiBackend
from utils.yolov5.utils.torch_utils import select_device
from utils.yolov5.utils.dataloaders import LoadImages, LoadStreams
from utils.yolov5.utils.general import check_img_size, non_max_suppression, cv2, scale_coords, xyxy2xywh, Profile
from deep_sort.deep_sort import DeepSort
from deep_sort.utils.draw import draw_boxes


def yolov5_to_deepsort_format(pred):
    """
    Convert YOLOv5 predictions into the format Deep SORT expects.
    :param pred: YOLOv5 prediction tensor (xyxy, conf, cls)
    :return: bbox_xywh, confs, class_ids
    """
    pred[:, :4] = xyxy2xywh(pred[:, :4])
    xywh = pred[:, :4].cpu().numpy()
    conf = pred[:, 4].cpu().numpy()
    cls = pred[:, 5].cpu().numpy()
    return xywh, conf, cls
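
# A quick check of the conversion above on a fake 1x6 prediction row
# (x1, y1, x2, y2, conf, cls); the numbers are illustrative only:
#   pred = torch.tensor([[100., 100., 200., 300., 0.9, 0.]])
#   xywh, conf, cls = yolov5_to_deepsort_format(pred)
#   xywh -> [[150., 200., 100., 200.]]  (center-x, center-y, width, height)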


async def run_deepsort(
        detect_id: int,
        weights_pt: str,
        data: str,
        idx_to_class: dict,
        sort_type: str = 'video',
        video_path: str = None,
        rtsp_url: str = None
):
    """
    Deep SORT tracking: YOLOv5 first detects the targets,
    then Deep SORT is called to track them.
    """
    room = 'deep_sort_' + str(detect_id)
    room_count = 'deep_sort_count_' + str(detect_id)
    # Select the device (CUDA GPU 0)
    device = select_device('cuda:0')
    model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
    deepsort = DeepSort(
        model_path="deep_sort/deep/checkpoint/ckpt.t7",  # ReID model path
        max_dist=0.2,          # appearance matching threshold (smaller = stricter)
        max_iou_distance=0.7,  # maximum IoU distance threshold
        max_age=70,            # frames a track survives without a match
        n_init=3               # consecutive matches required to confirm a track
    )
    stride, names, pt = model.stride, model.names, model.pt
    img_sz = check_img_size((640, 640), s=stride)  # check image size
    if sort_type == 'video':
        dataset = LoadImages(video_path, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
    else:
        dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
    bs = len(dataset)
    count_result = {}
    for key in idx_to_class.keys():
        count_result[key] = set()
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
    time.sleep(2)  # wait 2s for the websocket client to join
    dt = (Profile(device=device), Profile(device=device), Profile(device=device))
    for path, im, im0s, vid_cap, s in dataset:
        if room_manager.rooms.get(room):
            with dt[0]:
                # LoadStreams yields a list of frames, LoadImages a single frame
                im0 = im0s[0] if isinstance(im0s, list) else im0s
                im = torch.from_numpy(im).to(model.device)
                im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                if len(im.shape) == 3:
                    im = im[None]  # expand for batch dim
            with dt[1]:
                pred = model(im, augment=False, visualize=False)
            with dt[2]:
                # NMS
                pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)[0]
            pred[:, :4] = scale_coords(im.shape[2:], pred[:, :4], im0.shape).round()
            # Convert the YOLOv5 detections for Deep SORT
            bbox_xywh, cls_conf, cls_ids = yolov5_to_deepsort_format(pred)
            # keep only class 0 (the person class, as in utils/deepsort.py)
            mask = cls_ids == 0
            bbox_xywh = bbox_xywh[mask]
            # dilate the boxes slightly in case they are too tight
            bbox_xywh[:, 2:] *= 1.2
            cls_conf = cls_conf[mask]
            cls_ids = cls_ids[mask]
            # Run the Deep SORT update step
            outputs, _ = deepsort.update(bbox_xywh, cls_conf, cls_ids, im0)
            ori_img = im0  # fall back to the raw frame when nothing is tracked
            if len(outputs) > 0:
                bbox_xyxy = outputs[:, :4]
                identities = outputs[:, -1]
                cls = outputs[:, -2]
                names = [idx_to_class[str(label)] for label in cls]
                # Draw the tracked boxes
                ori_img = draw_boxes(im0, bbox_xyxy, names, identities, None)
                # Collect every track that has been confirmed
                active_tracks = [
                    track for track in deepsort.tracker.tracks
                    if track.is_confirmed()
                ]
                for track in active_tracks:
                    class_id = str(track.cls)
                    count_result[class_id].add(track.track_id)
            # Count unique track ids per label
            result = {}
            for key in count_result.keys():
                result[idx_to_class[key]] = len(count_result[key])
            # Encode the frame as JPEG
            ret, jpeg = cv2.imencode('.jpg', ori_img)
            if ret:
                jpeg_bytes = jpeg.tobytes()
                await room_manager.send_stream_to_room(room, jpeg_bytes)
                await room_manager.send_to_room(room_count, str(result))
        else:
            print(room, 'tracking stopped')
            break
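
The per-label counts above rely on one set of track ids per label, so a target re-detected in later frames is only counted once. A minimal standalone sketch of that bookkeeping (labels and ids are made up):

count_result = {'0': set(), '1': set()}              # one id set per label key
for class_id, track_id in [('0', 1), ('0', 1), ('1', 2)]:
    count_result[class_id].add(track_id)             # duplicate ids collapse in the set
result = {label: len(ids) for label, ids in count_result.items()}
print(result)                                        # {'0': 1, '1': 1}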


@@ -1,19 +1,19 @@
-from application.settings import yolo_url, detect_url
-from utils.websocket_server import room_manager
 from utils import os_utils as os
 from . import models, crud, schemas
+from utils.websocket_server import room_manager
+from application.settings import yolo_url, detect_url
 from apps.business.train import models as train_models
-from utils.yolov5.models.common import DetectMultiBackend
-from utils.yolov5.utils.torch_utils import select_device
 from utils.yolov5.utils.dataloaders import LoadStreams
-from utils.yolov5.utils.general import check_img_size, Profile, non_max_suppression, cv2, scale_boxes
+from utils.yolov5.utils.torch_utils import select_device
 from ultralytics.utils.plotting import Annotator, colors
+from utils.yolov5.models.common import DetectMultiBackend
+from apps.business.deepsort import service as deepsort_service
+from utils.yolov5.utils.general import check_img_size, Profile, non_max_suppression, cv2, scale_boxes
 import time
 import torch
 import asyncio
 import subprocess
 from redis.asyncio import Redis
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -169,16 +169,16 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
     model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
     stride, names, pt = model.stride, model.names, model.pt
-    imgsz = check_img_size((640, 640), s=stride)  # check image size
+    img_sz = check_img_size((640, 640), s=stride)  # check image size
-    dataset = LoadStreams(rtsp_url, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
+    dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
     bs = len(dataset)
-    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
+    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
-    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
+    dt = (Profile(device=device), Profile(device=device), Profile(device=device))
     time.sleep(3)  # wait 3s for the websocket client to join
     for path, im, im0s, vid_cap, s in dataset:
         if room_manager.rooms.get(room):
@@ -188,25 +188,13 @@ async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id:
                 im /= 255  # 0 - 255 to 0.0 - 1.0
                 if len(im.shape) == 3:
                     im = im[None]  # expand for batch dim
-                if model.xml and im.shape[0] > 1:
-                    ims = torch.chunk(im, im.shape[0], 0)
             # Inference
             with dt[1]:
-                if model.xml and im.shape[0] > 1:
-                    pred = None
-                    for image in ims:
-                        if pred is None:
-                            pred = model(image, augment=False, visualize=False).unsqueeze(0)
-                        else:
-                            pred = torch.cat((pred, model(image, augment=False, visualize=False).unsqueeze(0)),
-                                             dim=0)
-                    pred = [pred, None]
-                else:
-                    pred = model(im, augment=False, visualize=False)
+                pred = model(im, augment=False, visualize=False)
             # NMS
             with dt[2]:
-                pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)
+                pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)
             # Process predictions
             for i, det in enumerate(pred):  # per image
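
For reference, the second and third positional arguments of YOLOv5's non_max_suppression are the confidence and IoU thresholds, so this hunk lowers the confidence cut-off from 0.45 to 0.25. The same call in keyword form (a sketch, same values):

pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45,
                           classes=None, agnostic=False, max_det=1000)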
@@ -238,6 +226,41 @@ def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     # Run the async function
-    loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, is_gpu))
+    loop.run_until_complete(
+        run_detect_rtsp(
+            weights_pt,
+            rtsp_url,
+            data,
+            detect_id,
+            is_gpu
+        )
+    )
     # Optional: close the loop
     loop.close()
+
+
+def run_deepsort_loop(
+        detect_id: int,
+        weights_pt: str,
+        data: str,
+        idx_to_class: dict,
+        sort_type: str = 'video',
+        video_path: str = None,
+        rtsp_url: str = None
+):
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    # Run the async function
+    loop.run_until_complete(
+        deepsort_service.run_deepsort(
+            detect_id,
+            weights_pt,
+            data,
+            idx_to_class,
+            sort_type,
+            video_path,
+            rtsp_url
+        )
+    )
+    # Optional: close the loop
+    loop.close()
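
Both loop runners follow the same pattern: each worker thread builds its own event loop, runs one coroutine to completion, and closes the loop. A minimal sketch of that pattern (the coroutine body is a stand-in):

import asyncio
import threading

async def job():
    await asyncio.sleep(0.1)  # stand-in for run_detect_rtsp / run_deepsort

def loop_runner():
    loop = asyncio.new_event_loop()  # event loops are not shared across threads
    asyncio.set_event_loop(loop)
    loop.run_until_complete(job())
    loop.close()

threading.Thread(target=loop_runner).start()

On Python 3.7+, asyncio.run(job()) inside the thread is an equivalent, shorter form.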


@@ -9,17 +9,16 @@
 from utils import os_utils as osu
 from core.dependencies import IdList
 from core.database import redis_getter
-from . import schemas, crud, params, service, models
 from utils.websocket_server import room_manager
+from . import schemas, crud, params, service, models
 from apps.business.train.crud import ProjectTrainDal
 from apps.business.project.crud import ProjectInfoDal, ProjectLabelDal
 from apps.vadmin.auth.utils.current import AllUserAuth
 from apps.vadmin.auth.utils.validation.auth import Auth
 from utils.response import SuccessResponse, ErrorResponse
 import os
 import shutil
 import zipfile
 import tempfile
 import threading
 from redis.asyncio import Redis
 from fastapi.responses import FileResponse
@@ -121,27 +120,52 @@ async def run_detect_yolo(
     if file_count == 0 and detect.rtsp_url is None:
         return ErrorResponse("The inference set is empty; please upload images to it first")
     is_gpu = await rd.get('is_gpu')
-    if detect.file_type == 'img' or detect.file_type == 'video':
-        detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
-        thread_train = threading.Thread(target=service.run_img_loop,
-                                        args=(detect_log.pt_url, detect_log.folder_url,
-                                              detect_log.detect_folder_url, detect_log.detect_version,
-                                              detect_log.detect_id, is_gpu))
-        thread_train.start()
-        await service.update_sql(
-            auth.db, detect_log.detect_id,
-            detect_log.id, detect_log.detect_folder_url,
-            detect_log.detect_version)
-    elif detect.file_type == 'rtsp':
-        room = 'detect_rtsp_' + str(detect.id)
-        if not room_manager.rooms.get(room):
-            if detect_log_in.pt_type == 'best':
-                weights_pt = train.best_pt
-            else:
-                weights_pt = train.last_pt
-            thread_train = threading.Thread(target=service.run_rtsp_loop,
-                                            args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
-            thread_train.start()
+    # Decide whether this is a plain inference project or a tracking project
+    project_info = await ProjectInfoDal(auth.db).get_data(data_id=detect.project_id)
+    if project_info.type_code == 'yolo':
+        if detect.file_type == 'img' or detect.file_type == 'video':
+            detect_log = await service.before_detect(detect_log_in, detect, train, auth.db, auth.user.id)
+            thread_train = threading.Thread(target=service.run_img_loop,
+                                            args=(detect_log.pt_url, detect_log.folder_url,
+                                                  detect_log.detect_folder_url, detect_log.detect_version,
+                                                  detect_log.detect_id, is_gpu))
+            thread_train.start()
+            await service.update_sql(
+                auth.db, detect_log.detect_id,
+                detect_log.id, detect_log.detect_folder_url,
+                detect_log.detect_version)
+        elif detect.file_type == 'rtsp':
+            room = 'detect_rtsp_' + str(detect.id)
+            if not room_manager.rooms.get(room):
+                if detect_log_in.pt_type == 'best':
+                    weights_pt = train.best_pt
+                else:
+                    weights_pt = train.last_pt
+                thread_train = threading.Thread(target=service.run_rtsp_loop,
+                                                args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, is_gpu,))
+                thread_train.start()
+    elif project_info.type_code == 'deepsort':
+        room = 'deep_sort_' + str(detect.id)
+        if not room_manager.rooms.get(room):
+            # Fetch the project's labels: two arrays whose ids and names correspond one-to-one
+            label_id_list, label_name_list = await ProjectLabelDal(auth.db).get_label_for_train(project_info.id)
+            idx_to_class = {str(i): name for i, name in enumerate(label_name_list)}
+            # if detect_log_in.pt_type == 'best':
+            #     weights_pt = train.best_pt
+            # else:
+            #     weights_pt = train.last_pt
+            if detect.file_type == 'rtsp':
+                threading_main = threading.Thread(
+                    target=service.run_deepsort_loop,
+                    args=(detect.id, train.best_pt, train.train_data,
+                          idx_to_class, 'rtsp', None, detect.rtsp_url))
+                threading_main.start()
+            elif detect.file_type == 'video':
+                threading_main = threading.Thread(
+                    target=service.run_deepsort_loop,
+                    args=(detect.id, train.best_pt, train.train_data,
+                          idx_to_class, 'video', detect.folder_url, None))
+                threading_main.start()
     return SuccessResponse(msg="executed successfully")
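
For reference, idx_to_class maps the YOLOv5 class index (as a string) to the project's label name; a sketch with made-up labels:

label_name_list = ['person', 'car']  # hypothetical output of get_label_for_train
idx_to_class = {str(i): name for i, name in enumerate(label_name_list)}
# -> {'0': 'person', '1': 'car'}, matching the keys used by the tracking service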

Binary file not shown.

utils/deepsort.py Normal file (181 lines)

@@ -0,0 +1,181 @@
import os
import torch
from yolov5.models.common import DetectMultiBackend
from yolov5.utils.torch_utils import select_device
from yolov5.utils.dataloaders import LoadImages
from yolov5.utils.general import check_img_size, non_max_suppression, cv2, scale_coords, xyxy2xywh
from deep_sort.deep_sort import DeepSort


class VideoTracker(object):
    def __init__(self, weights_pt, data, video_path, save_path, idx_to_class):
        self.video_path = video_path
        self.save_path = save_path
        self.idx_to_class = idx_to_class
        # Select the device (CPU here)
        device = select_device('cpu')
        self.vdo = cv2.VideoCapture()
        self.model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
        self.deepsort = DeepSort(
            model_path="deep_sort/deep/checkpoint/ckpt.t7",  # ReID model path
            max_dist=0.2,          # appearance matching threshold (smaller = stricter)
            max_iou_distance=0.7,  # maximum IoU distance threshold
            max_age=70,            # frames a track survives without a match
            n_init=3               # consecutive matches required to confirm a track
        )
        self.class_names = self.model.names

    def __enter__(self):
        self.vdo.open(self.video_path)
        self.im_width = int(self.vdo.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.im_height = int(self.vdo.get(cv2.CAP_PROP_FRAME_HEIGHT))
        assert self.vdo.isOpened()
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)
            # paths of the saved video and results
            self.save_video_path = os.path.join(self.save_path, "results.avi")
            self.save_results_path = os.path.join(self.save_path, "results.txt")
            # create the video writer
            fourcc = cv2.VideoWriter_fourcc(*'MJPG')
            self.writer = cv2.VideoWriter(self.save_video_path, fourcc, 20, (self.im_width, self.im_height))
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.vdo.release()
        if self.save_path:
            self.writer.release()
        if exc_type:
            print(exc_type, exc_value, exc_traceback)

    def run(self):
        stride, names, pt = self.model.stride, self.model.names, self.model.pt
        imgsz = check_img_size((640, 640), s=stride)  # check image size
        dataset = LoadImages(self.video_path, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
        bs = len(dataset)
        self.model.warmup(imgsz=(1 if pt or self.model.triton else bs, 3, *imgsz))
        # one id set per label, so every target is counted at most once
        count_result = {key: set() for key in self.idx_to_class.keys()}
        for path, im, im0s, vid_cap, s in dataset:
            im = torch.from_numpy(im).to(self.model.device)
            im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim
            if self.model.xml and im.shape[0] > 1:
                ims = torch.chunk(im, im.shape[0], 0)
            # Inference
            if self.model.xml and im.shape[0] > 1:
                pred = None
                for image in ims:
                    if pred is None:
                        pred = self.model(image, augment=False, visualize=False).unsqueeze(0)
                    else:
                        pred = torch.cat(
                            (pred, self.model(image, augment=False, visualize=False).unsqueeze(0)),
                            dim=0
                        )
                pred = [pred, None]
            else:
                pred = self.model(im, augment=False, visualize=False)
            # NMS
            pred = non_max_suppression(pred, 0.40, 0.45, None, False, max_det=1000)[0]
            # LoadImages yields a single frame (LoadStreams would yield a list)
            image = im0s[0] if isinstance(im0s, list) else im0s
            pred[:, :4] = scale_coords(im.shape[2:], pred[:, :4], image.shape).round()
            # Convert the YOLOv5 detections for Deep SORT
            bbox_xywh, cls_conf, cls_ids = yolov5_to_deepsort_format(pred)
            # select the person class
            mask = cls_ids == 0
            bbox_xywh = bbox_xywh[mask]
            # bbox dilation just in case bbox too small, delete this line if using a better pedestrian detector
            bbox_xywh[:, 2:] *= 1.2
            cls_conf = cls_conf[mask]
            cls_ids = cls_ids[mask]
            # Run the Deep SORT update step
            outputs, _ = self.deepsort.update(bbox_xywh, cls_conf, cls_ids, image)
            # draw boxes for visualization
            if len(outputs) > 0:
                bbox_xyxy = outputs[:, :4]   # box coordinates per track
                identities = outputs[:, -1]  # track id per box
                cls = outputs[:, -2]         # label id per box
                names = [self.idx_to_class[str(label)] for label in cls]
                image = draw_boxes(image, bbox_xyxy, names, identities)
                for i in range(len(cls)):
                    count_result[str(cls[i])].add(identities[i])
            # persist the annotated frame when a save path was given
            if self.save_path:
                self.writer.write(image)
        return count_result

palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)  # color seed used below


def draw_boxes(img, bbox, names=None, identities=None, offset=(0, 0)):
    for i, box in enumerate(bbox):
        x1, y1, x2, y2 = [int(i) for i in box]
        x1 += offset[0]
        x2 += offset[0]
        y1 += offset[1]
        y2 += offset[1]
        # box text and bar
        id = int(identities[i]) if identities is not None else 0
        color = compute_color_for_labels(id)
        label = '{:}{:d}'.format(names[i], id)
        t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
        cv2.rectangle(img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
        cv2.putText(img, label, (x1, y1 + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
    return img


def compute_color_for_labels(label):
    """
    Simple function that adds fixed color depending on the class
    """
    color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
    return tuple(color)


def yolov5_to_deepsort_format(pred):
    """
    Convert YOLOv5 predictions into the format Deep SORT expects.
    :param pred: YOLOv5 prediction tensor (xyxy, conf, cls)
    :return: bbox_xywh, confs, class_ids
    """
    pred[:, :4] = xyxy2xywh(pred[:, :4])
    xywh = pred[:, :4].cpu().numpy()
    conf = pred[:, 4].cpu().numpy()
    cls = pred[:, 5].cpu().numpy()
    return xywh, conf, cls


if __name__ == "__main__":
    # Minimal local run; the paths and the label map below are placeholders,
    # not values shipped with this commit.
    idx_to_class = {'0': 'person'}
    with VideoTracker(
            weights_pt='weights/best.pt',  # hypothetical detector weights
            data='data/dataset.yaml',      # hypothetical dataset config
            video_path='input.mp4',        # hypothetical input video
            save_path='output',            # results.avi / results.txt are written here
            idx_to_class=idx_to_class
    ) as vdo_trk:
        vdo_trk.run()
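
For reference, a minimal standalone exercise of the drawing helper above on a dummy frame (a sketch; the output path is arbitrary):

import numpy as np
import cv2
from utils.deepsort import draw_boxes

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # blank 640x480 test frame
img = draw_boxes(frame, [(50, 50, 200, 240)], names=['person'], identities=[7])
cv2.imwrite('out.jpg', img)  # writes the annotated frame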