aicheckv2-api/utils/deepsort.py


import os
import torch
from yolov5.models.common import DetectMultiBackend
from yolov5.utils.torch_utils import select_device
from yolov5.utils.dataloaders import LoadImages
from yolov5.utils.general import check_img_size, non_max_suppression, cv2, scale_coords, xyxy2xywh
from deep_sort.deep_sort import DeepSort


class VideoTracker(object):
    def __init__(self, weights_pt, data, video_path, save_path, idx_to_class):
        self.video_path = video_path
        self.save_path = save_path
        self.idx_to_class = idx_to_class
        # select device (CPU or GPU)
        device = select_device('cpu')
        self.vdo = cv2.VideoCapture()
        self.model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
        self.deepsort = DeepSort(
            model_path="deep_sort/deep/checkpoint/ckpt.t7",  # path to the ReID model
            max_dist=0.2,           # appearance-feature matching threshold (smaller is stricter)
            max_iou_distance=0.7,   # maximum IoU distance for matching
            max_age=70,             # frames an unmatched track is kept alive
            n_init=3                # consecutive matches required before a track is confirmed
        )
        self.class_names = self.model.names

    def __enter__(self):
        self.vdo.open(self.video_path)
        self.im_width = int(self.vdo.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.im_height = int(self.vdo.get(cv2.CAP_PROP_FRAME_HEIGHT))
        assert self.vdo.isOpened()
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)
            # paths of the saved video and results
            self.save_video_path = os.path.join(self.save_path, "results.avi")
            self.save_results_path = os.path.join(self.save_path, "results.txt")
            # create video writer
            fourcc = cv2.VideoWriter_fourcc(*'MJPG')
            self.writer = cv2.VideoWriter(self.save_video_path, fourcc, 20, (self.im_width, self.im_height))
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # release video resources regardless of how the block exits
        self.vdo.release()
        if self.save_path:
            self.writer.release()
        if exc_type:
            print(exc_type, exc_value, exc_traceback)

    def run(self):
        stride, names, pt = self.model.stride, self.model.names, self.model.pt
        imgsz = check_img_size((640, 640), s=stride)  # check image size
        dataset = LoadImages(self.video_path, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
        bs = len(dataset)
        self.model.warmup(imgsz=(1 if pt or self.model.triton else bs, 3, *imgsz))
        # accumulate the unique track IDs seen for each class across the whole video
        count_result = {key: set() for key in self.idx_to_class.keys()}
        for path, im, im0s, vid_cap, s in dataset:
            im = torch.from_numpy(im).to(self.model.device)
            im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim
            # Inference (OpenVINO models only accept batch size 1, so run images one by one)
            if self.model.xml and im.shape[0] > 1:
                ims = torch.chunk(im, im.shape[0], 0)
                pred = None
                for image in ims:
                    if pred is None:
                        pred = self.model(image, augment=False, visualize=False).unsqueeze(0)
                    else:
                        pred = torch.cat(
                            (pred, self.model(image, augment=False, visualize=False).unsqueeze(0)),
                            dim=0
                        )
                pred = [pred, None]
            else:
                pred = self.model(im, augment=False, visualize=False)
            # NMS: confidence threshold 0.40, IoU threshold 0.45
            pred = non_max_suppression(pred, 0.40, 0.45, None, False, max_det=1000)[0]
            image = im0s  # LoadImages yields a single frame, not a batch
            # rescale boxes from the inference size back to the original frame size
            pred[:, :4] = scale_coords(im.shape[2:], pred[:, :4], image.shape).round()
            # convert the YOLOv5 detections to the format Deep SORT expects
            bbox_xywh, cls_conf, cls_ids = yolov5_to_deepsort_format(pred)
            # select the person class only
            mask = cls_ids == 0
            bbox_xywh = bbox_xywh[mask]
            # bbox dilation just in case bbox is too small; delete this line if using a better pedestrian detector
            bbox_xywh[:, 2:] *= 1.2
            cls_conf = cls_conf[mask]
            cls_ids = cls_ids[mask]
            # update the Deep SORT tracker with this frame's detections
            outputs, _ = self.deepsort.update(bbox_xywh, cls_conf, cls_ids, image)
            # draw boxes for visualization
            if len(outputs) > 0:
                bbox_xyxy = outputs[:, :4]   # tracked box coordinates (x1, y1, x2, y2)
                identities = outputs[:, -1]  # per-track identity IDs
                cls = outputs[:, -2]         # class label IDs
                names = [self.idx_to_class[str(label)] for label in cls]
                image = draw_boxes(image, bbox_xyxy, names, identities)
                for i in range(len(cls)):
                    count_result[str(cls[i])].add(identities[i])
            if self.save_path:
                self.writer.write(image)
        return count_result


def draw_boxes(img, bbox, names=None, identities=None, offset=(0, 0)):
    for i, box in enumerate(bbox):
        x1, y1, x2, y2 = [int(v) for v in box]
        x1 += offset[0]
        x2 += offset[0]
        y1 += offset[1]
        y2 += offset[1]
        # box text and bar
        track_id = int(identities[i]) if identities is not None else 0
        color = compute_color_for_labels(track_id)
        label = '{:}{:d}'.format(names[i], track_id)
        t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
        cv2.rectangle(img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
        cv2.putText(img, label, (x1, y1 + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
    return img
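

# Minimal sketch (not part of the pipeline): draws one labelled box on a
# blank frame to illustrate draw_boxes' expected inputs. The frame size,
# box coordinates, and output filename are arbitrary placeholders.
def _demo_draw_boxes():
    import numpy as np
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # blank 640x480 BGR frame
    frame = draw_boxes(frame, [(50, 60, 200, 220)], names=["person"], identities=[1])
    cv2.imwrite("demo_box.jpg", frame)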


# fixed seed values used to derive a stable per-class colour
# (same palette as the deep_sort_pytorch demo this code is based on)
palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)


def compute_color_for_labels(label):
    """
    Simple function that computes a fixed color depending on the class
    """
    label = int(label)
    color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
    return tuple(color)


def yolov5_to_deepsort_format(pred):
    """
    Convert YOLOv5 predictions into the format required by Deep SORT.
    :param pred: YOLOv5 prediction tensor, one row per detection: (x1, y1, x2, y2, conf, cls)
    :return: bbox_xywh (center x, center y, width, height), confs, class_ids as numpy arrays
    """
    pred[:, :4] = xyxy2xywh(pred[:, :4])
    xywh = pred[:, :4].cpu().numpy()
    conf = pred[:, 4].cpu().numpy()
    cls = pred[:, 5].cpu().numpy()
    return xywh, conf, cls
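

# Minimal sketch (not part of the pipeline): shows the values
# yolov5_to_deepsort_format produces for one fake detection row
# (x1, y1, x2, y2, conf, cls).
def _demo_format_conversion():
    dummy_pred = torch.tensor([[100.0, 50.0, 200.0, 250.0, 0.9, 0.0]])
    xywh, conf, cls = yolov5_to_deepsort_format(dummy_pred)
    # xywh == [[150., 150., 100., 200.]] -> (center x, center y, width, height)
    print(xywh, conf, cls)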


if __name__ == "__main__":
    # Example invocation matching the VideoTracker signature above; the
    # paths and class map below are placeholders, replace them with values
    # that match your deployment.
    idx_to_class = {"0": "person"}
    with VideoTracker(
        weights_pt="weights/yolov5s.pt",
        data="data/coco128.yaml",
        video_path="demo.mp4",
        save_path="output",
        idx_to_class=idx_to_class,
    ) as vdo_trk:
        vdo_trk.run()
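
# Usage note: after run() completes, the annotated video is written to
# <save_path>/results.avi, and the returned dict maps each class id to the
# set of unique track IDs seen, e.g. {"0": {1, 2, 5}} for three distinct
# people. The Deep SORT ReID checkpoint is expected at
# deep_sort/deep/checkpoint/ckpt.t7 (see __init__ above).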