deep_sort 提升版本
This commit is contained in:
@ -1 +1,2 @@
|
||||
from .yolov11 import YoloModel
|
||||
from .deep_sort import DeepSortModel
|
82
algo/deep_sort.py
Normal file
82
algo/deep_sort.py
Normal file
@ -0,0 +1,82 @@
|
||||
import cv2
|
||||
import time
|
||||
import asyncio
|
||||
from ultralytics import YOLO
|
||||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||||
|
||||
from utils.websocket_server import room_manager
|
||||
|
||||
|
||||
class DeepSortModel:
|
||||
|
||||
def __init__(self, pt_url):
|
||||
self.model = YOLO(pt_url)
|
||||
self.tracker = DeepSort()
|
||||
|
||||
def sort_video(self, url, detect_id: int, idx_to_class: {}):
|
||||
"""
|
||||
对文件夹中的视频或rtsp的视频流,进行目标追踪
|
||||
"""
|
||||
room_name = 'deep_sort_' + str(detect_id)
|
||||
room_count = 'deep_sort_count_' + str(detect_id)
|
||||
count_result = {}
|
||||
for key in idx_to_class.keys():
|
||||
count_result[key] = set()
|
||||
cap = cv2.VideoCapture(url)
|
||||
start_time = time.time()
|
||||
while cap.isOpened():
|
||||
# 检查是否已经超过10分钟(600秒)
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time > 600: # 600 seconds = 10 minutes
|
||||
break
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
# YOLO 推理(GPU 加速)
|
||||
results = self.model(frame, device=0, conf=0.6, verbose=False)
|
||||
# 获取检测框数据
|
||||
detections = results[0].boxes.data.cpu().numpy()
|
||||
# DeepSORT 格式转换:[(bbox_x, y, w, h), confidence, class_id]
|
||||
tracker_inputs = []
|
||||
for det in detections:
|
||||
x1, y1, x2, y2, conf, cls = det
|
||||
bbox = [x1, y1, x2 - x1, y2 - y1] # (x, y, w, h)
|
||||
tracker_inputs.append((bbox, conf, int(cls)))
|
||||
# 更新跟踪器
|
||||
tracks = self.tracker.update_tracks(tracker_inputs, frame=frame)
|
||||
# 获取所有被确认过的追踪目标
|
||||
active_tracks = []
|
||||
# 绘制跟踪结果
|
||||
for track in tracks:
|
||||
if not track.is_confirmed():
|
||||
active_tracks.append(track)
|
||||
continue
|
||||
track_id = track.track_id
|
||||
track_cls = str(track.det_class)
|
||||
ltrb = track.to_ltrb()
|
||||
x1, y1, x2, y2 = map(int, ltrb)
|
||||
# 绘制矩形框和ID标签
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
cv2.putText(
|
||||
frame,
|
||||
f"{idx_to_class[track_cls]} {track_id}",
|
||||
(x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 255, 0),
|
||||
2,
|
||||
)
|
||||
for tark in active_tracks:
|
||||
class_id = str(tark.det_class)
|
||||
count_result[class_id].add(tark.track_id)
|
||||
# 对应每个label进行计数
|
||||
result = {}
|
||||
for key in count_result.keys():
|
||||
result[idx_to_class[key]] = len(count_result[key])
|
||||
# 将帧编码为 JPEG
|
||||
ret, jpeg = cv2.imencode('.jpg', frame)
|
||||
if ret:
|
||||
jpeg_bytes = jpeg.tobytes()
|
||||
asyncio.run(room_manager.send_stream_to_room(room_name, jpeg_bytes))
|
||||
asyncio.run(room_manager.send_to_room(room_count, str(result)))
|
||||
|
@ -1,3 +1,4 @@
|
||||
import cv2
|
||||
import time
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
@ -97,11 +98,10 @@ class YoloModel:
|
||||
conf=0.6
|
||||
)
|
||||
|
||||
def predict_rtsp(self, source, detect_id):
|
||||
def predict_rtsp(self, source, room_name):
|
||||
"""
|
||||
对rtsp视频流进行预测
|
||||
"""
|
||||
room_name = 'detect_rtsp_' + str(detect_id)
|
||||
start_time = time.time()
|
||||
for result in self.model.predict(source=source, stream=True, device='0', conf=0.6):
|
||||
# 检查是否已经超过10分钟(600秒)
|
||||
@ -109,5 +109,7 @@ class YoloModel:
|
||||
if elapsed_time > 600: # 600 seconds = 10 minutes
|
||||
break
|
||||
frame = result.plot()
|
||||
frame_data = frame.tobytes()
|
||||
asyncio.run(room_manager.send_stream_to_room(room_name, frame_data))
|
||||
ret, jpeg = cv2.imencode('.jpg', frame)
|
||||
if ret:
|
||||
frame_data = jpeg.tobytes()
|
||||
asyncio.run(room_manager.send_stream_to_room(room_name, frame_data))
|
@ -1,131 +1,15 @@
|
||||
|
||||
from utils.websocket_server import room_manager
|
||||
|
||||
import time
|
||||
import torch
|
||||
from deep_sort.deep_sort import DeepSort
|
||||
from deep_sort.utils.draw import draw_boxes
|
||||
from algo import DeepSortModel
|
||||
|
||||
|
||||
def yolov5_to_deepsort_format(pred):
|
||||
"""
|
||||
将YOLOv5的预测结果转换为Deep SORT所需的格式
|
||||
:param pred: YOLOv5的预测结果
|
||||
:return: 转换后的bbox_xywh, confs, class_ids
|
||||
"""
|
||||
# pred[:, :4] = xyxy2xywh(pred[:, :4])
|
||||
# xywh = pred[:, :4].cpu().numpy()
|
||||
# conf = pred[:, 4].cpu().numpy()
|
||||
# cls = pred[:, 5].cpu().numpy()
|
||||
# return xywh, conf, cls
|
||||
|
||||
|
||||
async def run_deepsort(
|
||||
def run_deepsort(
|
||||
detect_id: int,
|
||||
weights_pt: str,
|
||||
data: str,
|
||||
idx_to_class: {},
|
||||
sort_type: str = 'video',
|
||||
video_path: str = None,
|
||||
rtsp_url: str = None
|
||||
url
|
||||
):
|
||||
"""
|
||||
deep_sort追踪,先经过yolov5对目标进行识别,
|
||||
deep_sort追踪,先经过yolo对目标进行识别,
|
||||
再调用deepsort对目标进行追踪
|
||||
"""
|
||||
# room = 'deep_sort_' + str(detect_id)
|
||||
#
|
||||
# room_count = 'deep_sort_count_' + str(detect_id)
|
||||
#
|
||||
# # 选择设备(CPU 或 GPU)
|
||||
# device = select_device('cuda:0')
|
||||
#
|
||||
# model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
|
||||
#
|
||||
# deepsort = DeepSort(
|
||||
# model_path="deep_sort/deep/checkpoint/ckpt.t7", # ReID 模型路径
|
||||
# max_dist=0.2, # 外观特征匹配阈值(越小越严格)
|
||||
# max_iou_distance=0.7, # 最大IoU距离阈值
|
||||
# max_age=70, # 目标最大存活帧数(未匹配时保留的帧数)
|
||||
# n_init=3 # 初始确认帧数(连续匹配到n_init次后确认跟踪)
|
||||
# )
|
||||
# stride, names, pt = model.stride, model.names, model.pt
|
||||
# img_sz = check_img_size((640, 640), s=stride) # check image size
|
||||
# if sort_type == 'video':
|
||||
# dataset = LoadImages(video_path, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
|
||||
# else:
|
||||
# dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
|
||||
# bs = len(dataset)
|
||||
#
|
||||
# count_result = {}
|
||||
# for key in idx_to_class.keys():
|
||||
# count_result[key] = set()
|
||||
#
|
||||
# model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
|
||||
#
|
||||
# time.sleep(3) # 等待3s,等待websocket进入
|
||||
#
|
||||
# start_time = time.time()
|
||||
#
|
||||
# for path, im, im0s, vid_cap, s in dataset:
|
||||
# # 检查是否已经超过10分钟(600秒)
|
||||
# elapsed_time = time.time() - start_time
|
||||
# if elapsed_time > 600: # 600 seconds = 10 minutes
|
||||
# print(room, "已达到最大执行时间,结束推理。")
|
||||
# break
|
||||
# if room_manager.rooms.get(room):
|
||||
# im0 = im0s[0]
|
||||
# im = torch.from_numpy(im).to(model.device)
|
||||
# im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
||||
# im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
# if len(im.shape) == 3:
|
||||
# im = im[None] # expand for batch dim
|
||||
# pred = model(im, augment=False, visualize=False)
|
||||
# # NMS
|
||||
# pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)[0]
|
||||
#
|
||||
# pred[:, :4] = scale_coords(im.shape[2:], pred[:, :4], im0.shape).round()
|
||||
#
|
||||
# # 使用YOLOv5进行检测后得到的pred
|
||||
# bbox_xywh, cls_conf, cls_ids = yolov5_to_deepsort_format(pred)
|
||||
#
|
||||
# mask = cls_ids == 0
|
||||
#
|
||||
# bbox_xywh = bbox_xywh[mask]
|
||||
# bbox_xywh[:, 2:] *= 1.2
|
||||
# cls_conf = cls_conf[mask]
|
||||
# cls_ids = cls_ids[mask]
|
||||
#
|
||||
# # 调用Deep SORT更新方法
|
||||
# outputs, _ = deepsort.update(bbox_xywh, cls_conf, cls_ids, im0)
|
||||
#
|
||||
# if len(outputs) > 0:
|
||||
# bbox_xyxy = outputs[:, :4]
|
||||
# identities = outputs[:, -1]
|
||||
# cls = outputs[:, -2]
|
||||
# names = [idx_to_class[str(label)] for label in cls]
|
||||
# # 开始画框
|
||||
# ori_img = draw_boxes(im0, bbox_xyxy, names, identities, None)
|
||||
#
|
||||
# # 获取所有被确认过的追踪目标
|
||||
# active_tracks = [
|
||||
# track for track in deepsort.tracker.tracks
|
||||
# if track.is_confirmed()
|
||||
# ]
|
||||
#
|
||||
# for tark in active_tracks:
|
||||
# class_id = str(tark.cls)
|
||||
# count_result[class_id].add(tark.track_id)
|
||||
# # 对应每个label进行计数
|
||||
# result = {}
|
||||
# for key in count_result.keys():
|
||||
# result[idx_to_class[key]] = len(count_result[key])
|
||||
# # 将帧编码为 JPEG
|
||||
# ret, jpeg = cv2.imencode('.jpg', ori_img)
|
||||
# if ret:
|
||||
# jpeg_bytes = jpeg.tobytes()
|
||||
# await room_manager.send_stream_to_room(room, jpeg_bytes)
|
||||
# await room_manager.send_to_room(room_count, str(result))
|
||||
# else:
|
||||
# print(room, '结束追踪')
|
||||
# break
|
||||
deep_sort = DeepSortModel(weights_pt)
|
||||
deep_sort.sort_video(url=url, detect_id=detect_id, idx_to_class=idx_to_class)
|
||||
|
@ -51,21 +51,6 @@ async def before_detect(
|
||||
return detect_log
|
||||
|
||||
|
||||
def run_img_loop(
|
||||
weights: str,
|
||||
source: str,
|
||||
project: str,
|
||||
name: str,
|
||||
detect_id: int,
|
||||
is_gpu: str):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(run_detect_folder(weights, source, project, name, detect_id, is_gpu))
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
|
||||
def run_detect_folder(
|
||||
weights: str,
|
||||
source: str,
|
||||
@ -111,121 +96,13 @@ async def update_sql(db: AsyncSession, detect_id: int, log_id: int, project, nam
|
||||
await crud.ProjectDetectLogFileDal(db).create_models(detect_log_files)
|
||||
|
||||
|
||||
async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
|
||||
def run_detect_rtsp(weights_pt: str, rtsp_url: str, room_name: str):
|
||||
"""
|
||||
rtsp 视频流推理
|
||||
:param detect_id: 训练集的id
|
||||
:param room_name: websocket链接名称
|
||||
:param weights_pt: 权重文件
|
||||
:param rtsp_url: 视频流地址
|
||||
:param data: yaml文件
|
||||
:param is_gpu: 是否启用加速
|
||||
:return:
|
||||
"""
|
||||
# room = 'detect_rtsp_' + str(detect_id)
|
||||
# # 选择设备(CPU 或 GPU)
|
||||
# device = select_device('cpu')
|
||||
# # 判断是否存在cuda版本
|
||||
# if is_gpu == 'True':
|
||||
# device = select_device('cuda:0')
|
||||
#
|
||||
# # 加载模型
|
||||
# model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
|
||||
#
|
||||
# stride, names, pt = model.stride, model.names, model.pt
|
||||
# img_sz = check_img_size((640, 640), s=stride) # check image size
|
||||
#
|
||||
# dataset = LoadStreams(rtsp_url, img_size=img_sz, stride=stride, auto=pt, vid_stride=1)
|
||||
# bs = len(dataset)
|
||||
#
|
||||
# model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *img_sz))
|
||||
#
|
||||
# time.sleep(3) # 等待3s,等待websocket进入
|
||||
#
|
||||
# start_time = time.time()
|
||||
#
|
||||
# for path, im, im0s, vid_cap, s in dataset:
|
||||
# # 检查是否已经超过10分钟(600秒)
|
||||
# elapsed_time = time.time() - start_time
|
||||
# if elapsed_time > 600: # 600 seconds = 10 minutes
|
||||
# print(room, "已达到最大执行时间,结束推理。")
|
||||
# break
|
||||
# if room_manager.rooms.get(room):
|
||||
# im = torch.from_numpy(im).to(model.device)
|
||||
# im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
||||
# im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
# if len(im.shape) == 3:
|
||||
# im = im[None] # expand for batch dim
|
||||
#
|
||||
# # Inference
|
||||
# pred = model(im, augment=False, visualize=False)
|
||||
# # NMS
|
||||
# pred = non_max_suppression(pred, 0.25, 0.45, None, False, max_det=1000)
|
||||
#
|
||||
# # Process predictions
|
||||
# for i, det in enumerate(pred): # per image
|
||||
# p, im0, frame = path[i], im0s[i].copy(), dataset.count
|
||||
# annotator = Annotator(im0, line_width=3, example=str(names))
|
||||
# if len(det):
|
||||
# # Rescale boxes from img_size to im0 size
|
||||
# det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
|
||||
#
|
||||
# # Write results
|
||||
# for *xyxy, conf, cls in reversed(det):
|
||||
# c = int(cls) # integer class
|
||||
# label = None if False else (names[c] if False else f"{names[c]} {conf:.2f}")
|
||||
# annotator.box_label(xyxy, label, color=colors(c, True))
|
||||
#
|
||||
# # Stream results
|
||||
# im0 = annotator.result()
|
||||
# # 将帧编码为 JPEG
|
||||
# ret, jpeg = cv2.imencode('.jpg', im0)
|
||||
# if ret:
|
||||
# frame_data = jpeg.tobytes()
|
||||
# await room_manager.send_stream_to_room(room, frame_data)
|
||||
# else:
|
||||
# print(room, '结束推理')
|
||||
# break
|
||||
|
||||
|
||||
def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, is_gpu: str):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(
|
||||
run_detect_rtsp(
|
||||
weights_pt,
|
||||
rtsp_url,
|
||||
data,
|
||||
detect_id,
|
||||
is_gpu
|
||||
)
|
||||
)
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
|
||||
|
||||
def run_deepsort_loop(
|
||||
detect_id: int,
|
||||
weights_pt: str,
|
||||
data: str,
|
||||
idx_to_class: {},
|
||||
sort_type: str = 'video',
|
||||
video_path: str = None,
|
||||
rtsp_url: str = None
|
||||
):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# 运行异步函数
|
||||
loop.run_until_complete(
|
||||
deepsort_service.run_deepsort(
|
||||
detect_id,
|
||||
weights_pt,
|
||||
data,
|
||||
idx_to_class,
|
||||
sort_type,
|
||||
video_path,
|
||||
rtsp_url
|
||||
)
|
||||
)
|
||||
# 可选: 关闭循环
|
||||
loop.close()
|
||||
model = YoloModel(weights_pt)
|
||||
model.predict_rtsp(rtsp_url, room_name)
|
@ -11,10 +11,11 @@ from core.dependencies import IdList
|
||||
from utils.websocket_server import room_manager
|
||||
from . import schemas, crud, params, service, models
|
||||
from apps.business.train.crud import ProjectTrainDal
|
||||
from apps.business.project.crud import ProjectInfoDal, ProjectLabelDal
|
||||
from apps.vadmin.auth.utils.current import AllUserAuth
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from utils.response import SuccessResponse, ErrorResponse
|
||||
from apps.business.deepsort import service as deep_sort_service
|
||||
from apps.business.project.crud import ProjectInfoDal, ProjectLabelDal
|
||||
|
||||
import os
|
||||
import shutil
|
||||
@ -147,13 +148,11 @@ async def run_detect_yolo(
|
||||
else:
|
||||
weights_pt = train.last_pt
|
||||
thread_train = threading.Thread(
|
||||
target=service.run_rtsp_loop,
|
||||
target=service.run_detect_rtsp,
|
||||
args=(
|
||||
weights_pt,
|
||||
detect.rtsp_url,
|
||||
train.train_data,
|
||||
detect.id,
|
||||
None
|
||||
room
|
||||
)
|
||||
)
|
||||
thread_train.start()
|
||||
@ -164,22 +163,16 @@ async def run_detect_yolo(
|
||||
# 查询项目所属标签,返回两个 id,name一一对应的数组
|
||||
label_id_list, label_name_list = await ProjectLabelDal(auth.db).get_label_for_train(project_info.id)
|
||||
idx_to_class = {str(i): name for i, name in enumerate(label_name_list)}
|
||||
# if detect_log_in.pt_type == 'best':
|
||||
# weights_pt = train.best_pt
|
||||
# else:
|
||||
# weights_pt = train.last_pt
|
||||
if detect.file_type == 'rtsp':
|
||||
threading_main = threading.Thread(
|
||||
target=service.run_deepsort_loop,
|
||||
args=(detect.id, train.best_pt, train.train_data,
|
||||
idx_to_class, 'rtsp', None, detect.rtsp_url))
|
||||
threading_main.start()
|
||||
elif detect.file_type == 'video':
|
||||
threading_main = threading.Thread(
|
||||
target=service.run_deepsort_loop,
|
||||
args=(detect.id, train.best_pt, train.train_data,
|
||||
idx_to_class, 'video', detect.folder_url, None))
|
||||
threading_main.start()
|
||||
threading_main = threading.Thread(
|
||||
target=deep_sort_service.run_deepsort,
|
||||
args=(
|
||||
detect.id,
|
||||
train.best_pt,
|
||||
idx_to_class,
|
||||
detect.rtsp_url
|
||||
)
|
||||
)
|
||||
threading_main.start()
|
||||
return SuccessResponse(msg="执行成功")
|
||||
|
||||
|
||||
|
@ -1,19 +0,0 @@
|
||||
from .deep_sort import DeepSort
|
||||
|
||||
__all__ = ['DeepSort', 'build_tracker']
|
||||
|
||||
|
||||
def build_tracker(cfg, use_cuda):
|
||||
if cfg.USE_FASTREID:
|
||||
return DeepSort(model_path=cfg.FASTREID.CHECKPOINT, model_config=cfg.FASTREID.CFG,
|
||||
max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
|
||||
nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
|
||||
max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET,
|
||||
use_cuda=use_cuda)
|
||||
|
||||
else:
|
||||
return DeepSort(model_path=cfg.DEEPSORT.REID_CKPT,
|
||||
max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
|
||||
nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
|
||||
max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET,
|
||||
use_cuda=use_cuda)
|
@ -1,10 +0,0 @@
|
||||
DEEPSORT:
|
||||
REID_CKPT: "./deep_sort/deep/checkpoint/ckpt.t7"
|
||||
MAX_DIST: 0.2
|
||||
MIN_CONFIDENCE: 0.5
|
||||
NMS_MAX_OVERLAP: 0.5
|
||||
MAX_IOU_DISTANCE: 0.7
|
||||
MAX_AGE: 70
|
||||
N_INIT: 3
|
||||
NN_BUDGET: 100
|
||||
|
@ -1,3 +0,0 @@
|
||||
FASTREID:
|
||||
CFG: "thirdparty/fast-reid/configs/Market1501/bagtricks_R50.yml"
|
||||
CHECKPOINT: "deep_sort/deep/checkpoint/market_bot_R50.pth"
|
@ -1,6 +0,0 @@
|
||||
MASKRCNN:
|
||||
LABEL: "./coco_classes.json"
|
||||
WEIGHT: "./detector/Mask_RCNN/save_weights/maskrcnn_resnet50_fpn_coco.pth"
|
||||
|
||||
NUM_CLASSES: 90
|
||||
BOX_THRESH: 0.5
|
@ -1,5 +0,0 @@
|
||||
MMDET:
|
||||
CFG: "thirdparty/mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py"
|
||||
CHECKPOINT: "detector/MMDet/weight/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
|
||||
|
||||
SCORE_THRESH: 0.5
|
@ -1,82 +0,0 @@
|
||||
In deepsort algorithm, appearance feature extraction network used to extract features from **image_crops** for matching purpose.The original model used in paper is in `model.py`, and its parameter here [ckpt.t7](https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6). This repository also provides a `resnet.py` script and its pre-training weights on Imagenet here.
|
||||
|
||||
```
|
||||
# resnet18
|
||||
https://download.pytorch.org/models/resnet18-5c106cde.pth
|
||||
# resnet34
|
||||
https://download.pytorch.org/models/resnet34-333f7ec4.pth
|
||||
# resnet50
|
||||
https://download.pytorch.org/models/resnet50-19c8e357.pth
|
||||
# resnext50_32x4d
|
||||
https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
|
||||
```
|
||||
|
||||
## Dataset PrePare
|
||||
|
||||
To train the model, first you need download [Market1501](http://www.liangzheng.com.cn/Project/project_reid.html) dataset or [Mars](http://www.liangzheng.com.cn/Project/project_mars.html) dataset.
|
||||
|
||||
If you want to train on your **own dataset**, assuming you have already downloaded the dataset.The dataset should be arranged in the following way.
|
||||
|
||||
```
|
||||
├── dataset_root: The root dir of the dataset.
|
||||
├── class1: Category 1 is located in the folder dir.
|
||||
├── xxx1.jpg: Image belonging to category 1.
|
||||
├── xxx2.jpg: Image belonging to category 1.
|
||||
├── class2: Category 2 is located in the folder dir.
|
||||
├── xxx3.jpg: Image belonging to category 2.
|
||||
├── xxx4.jpg: Image belonging to category 2.
|
||||
├── class3: Category 3 is located in the folder dir.
|
||||
...
|
||||
...
|
||||
```
|
||||
|
||||
## Training the RE-ID model
|
||||
|
||||
Assuming you have already prepare the dataset. Then you can use the following command to start your training progress.
|
||||
|
||||
#### training on a single GPU
|
||||
|
||||
```python
|
||||
usage: train.py [--data-dir]
|
||||
[--epochs]
|
||||
[--batch_size]
|
||||
[--lr]
|
||||
[--lrf]
|
||||
[--weights]
|
||||
[--freeze-layers]
|
||||
[--gpu_id]
|
||||
|
||||
# default use cuda:0, use Net in `model.py`
|
||||
python train.py --data-dir [dataset/root/path] --weights [(optional)pre-train/weight/path]
|
||||
# you can use `--freeze-layers` option to freeze full convolutional layer parameters except fc layers parameters
|
||||
python train.py --data-dir [dataset/root/path] --weights [(optional)pre-train/weight/path] --freeze-layers
|
||||
```
|
||||
|
||||
#### training on multiple GPU
|
||||
|
||||
```python
|
||||
usage: train_multiGPU.py [--data-dir]
|
||||
[--epochs]
|
||||
[--batch_size]
|
||||
[--lr]
|
||||
[--lrf]
|
||||
[--syncBN]
|
||||
[--weights]
|
||||
[--freeze-layers]
|
||||
# not change the following parameters, the system will automatically assignment
|
||||
[--device]
|
||||
[--world_size]
|
||||
[--dist_url]
|
||||
|
||||
# default use cuda:0, cuda:1, cuda:2, cuda:3, use resnet18 in `resnet.py`
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train_multiGPU.py --data-dir [dataset/root/path] --weights [(optional)pre-train/weight/path]
|
||||
# you can use `--freeze-layers` option to freeze full convolutional layer parameters except fc layers parameters
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train_multiGPU.py --data-dir [dataset/root/path] --weights [(optional)pre-train/weight/path] --freeze-layers
|
||||
```
|
||||
|
||||
An example of training progress is as follows:
|
||||
|
||||

|
||||
|
||||
The last, you can evaluate it using [test.py](deep_sort/deep/test.py) and [evaluate.py](deep_sort/deep/evalute.py).
|
||||
|
Binary file not shown.
@ -1,92 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class ClsDataset(Dataset):
|
||||
def __init__(self, images_path, images_labels, transform=None):
|
||||
self.images_path = images_path
|
||||
self.images_labels = images_labels
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images_path)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img = cv2.imread(self.images_path[idx])
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = Image.fromarray(img)
|
||||
label = self.images_labels[idx]
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
return img, label
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
images, labels = tuple(zip(*batch))
|
||||
images = torch.stack(images, dim=0)
|
||||
labels = torch.as_tensor(labels)
|
||||
return images, labels
|
||||
|
||||
|
||||
def read_split_data(root, valid_rate=0.2):
|
||||
assert os.path.exists(root), 'dataset root: {} does not exist.'.format(root)
|
||||
|
||||
class_names = [cls for cls in os.listdir(root) if os.path.isdir(os.path.join(root, cls))]
|
||||
class_names.sort()
|
||||
|
||||
class_indices = {name: i for i, name in enumerate(class_names)}
|
||||
json_str = json.dumps({v: k for k, v in class_indices.items()}, indent=4)
|
||||
with open('class_indices.json', 'w') as f:
|
||||
f.write(json_str)
|
||||
|
||||
train_images_path = []
|
||||
train_labels = []
|
||||
val_images_path = []
|
||||
val_labels = []
|
||||
per_class_num = []
|
||||
|
||||
supported = ['.jpg', '.JPG', '.png', '.PNG']
|
||||
for cls in class_names:
|
||||
cls_path = os.path.join(root, cls)
|
||||
images_path = [os.path.join(cls_path, i) for i in os.listdir(cls_path)
|
||||
if os.path.splitext(i)[-1] in supported]
|
||||
images_label = class_indices[cls]
|
||||
per_class_num.append(len(images_path))
|
||||
|
||||
val_path = random.sample(images_path, int(len(images_path) * valid_rate))
|
||||
for img_path in images_path:
|
||||
if img_path in val_path:
|
||||
val_images_path.append(img_path)
|
||||
val_labels.append(images_label)
|
||||
else:
|
||||
train_images_path.append(img_path)
|
||||
train_labels.append(images_label)
|
||||
|
||||
print("{} images were found in the dataset.".format(sum(per_class_num)))
|
||||
print("{} images for training.".format(len(train_images_path)))
|
||||
print("{} images for validation.".format(len(val_images_path)))
|
||||
|
||||
assert len(train_images_path) > 0, "number of training images must greater than zero"
|
||||
assert len(val_images_path) > 0, "number of validation images must greater than zero"
|
||||
|
||||
plot_distribution = False
|
||||
if plot_distribution:
|
||||
plt.bar(range(len(class_names)), per_class_num, align='center')
|
||||
plt.xticks(range(len(class_names)), class_names)
|
||||
|
||||
for i, v in enumerate(per_class_num):
|
||||
plt.text(x=i, y=v + 5, s=str(v), ha='center')
|
||||
|
||||
plt.xlabel('classes')
|
||||
plt.ylabel('numbers')
|
||||
plt.title('the distribution of dataset')
|
||||
plt.show()
|
||||
return [train_images_path, train_labels], [val_images_path, val_labels], len(class_names)
|
@ -1,15 +0,0 @@
|
||||
import torch
|
||||
|
||||
features = torch.load("features.pth")
|
||||
qf = features["qf"]
|
||||
ql = features["ql"]
|
||||
gf = features["gf"]
|
||||
gl = features["gl"]
|
||||
|
||||
scores = qf.mm(gf.t())
|
||||
res = scores.topk(5, dim=1)[1][:,0]
|
||||
top1correct = gl[res].eq(ql).sum().item()
|
||||
|
||||
print("Acc top1:{:.3f}".format(top1correct/ql.size(0)))
|
||||
|
||||
|
@ -1,93 +0,0 @@
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
import cv2
|
||||
import logging
|
||||
|
||||
from .model import Net
|
||||
from .resnet import resnet18
|
||||
# from fastreid.config import get_cfg
|
||||
# from fastreid.engine import DefaultTrainer
|
||||
# from fastreid.utils.checkpoint import Checkpointer
|
||||
|
||||
|
||||
class Extractor(object):
|
||||
def __init__(self, model_path, use_cuda=True):
|
||||
self.net = Net(reid=True)
|
||||
# self.net = resnet18(reid=True)
|
||||
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
|
||||
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
|
||||
self.net.load_state_dict(state_dict if 'net_dict' not in state_dict else state_dict['net_dict'], strict=False)
|
||||
logger = logging.getLogger("root.tracker")
|
||||
logger.info("Loading weights from {}... Done!".format(model_path))
|
||||
self.net.to(self.device)
|
||||
self.size = (64, 128)
|
||||
self.norm = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def _preprocess(self, im_crops):
|
||||
"""
|
||||
TODO:
|
||||
1. to float with scale from 0 to 1
|
||||
2. resize to (64, 128) as Market1501 dataset did
|
||||
3. concatenate to a numpy array
|
||||
3. to torch Tensor
|
||||
4. normalize
|
||||
"""
|
||||
|
||||
def _resize(im, size):
|
||||
return cv2.resize(im.astype(np.float32) / 255., size)
|
||||
|
||||
im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
|
||||
return im_batch
|
||||
|
||||
def __call__(self, im_crops):
|
||||
im_batch = self._preprocess(im_crops)
|
||||
with torch.no_grad():
|
||||
im_batch = im_batch.to(self.device)
|
||||
features = self.net(im_batch)
|
||||
return features.cpu().numpy()
|
||||
|
||||
|
||||
class FastReIDExtractor(object):
|
||||
def __init__(self, model_config, model_path, use_cuda=True):
|
||||
cfg = get_cfg()
|
||||
cfg.merge_from_file(model_config)
|
||||
cfg.MODEL.BACKBONE.PRETRAIN = False
|
||||
self.net = DefaultTrainer.build_model(cfg)
|
||||
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
|
||||
|
||||
Checkpointer(self.net).load(model_path)
|
||||
logger = logging.getLogger("root.tracker")
|
||||
logger.info("Loading weights from {}... Done!".format(model_path))
|
||||
self.net.to(self.device)
|
||||
self.net.eval()
|
||||
height, width = cfg.INPUT.SIZE_TEST
|
||||
self.size = (width, height)
|
||||
self.norm = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def _preprocess(self, im_crops):
|
||||
def _resize(im, size):
|
||||
return cv2.resize(im.astype(np.float32) / 255., size)
|
||||
|
||||
im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
|
||||
return im_batch
|
||||
|
||||
def __call__(self, im_crops):
|
||||
im_batch = self._preprocess(im_crops)
|
||||
with torch.no_grad():
|
||||
im_batch = im_batch.to(self.device)
|
||||
features = self.net(im_batch)
|
||||
return features.cpu().numpy()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]
|
||||
extr = Extractor("checkpoint/ckpt.t7")
|
||||
feature = extr(img)
|
||||
print(feature.shape)
|
@ -1,105 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, c_in, c_out, is_downsample=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.is_downsample = is_downsample
|
||||
if is_downsample:
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(c_out)
|
||||
self.relu = nn.ReLU(True)
|
||||
self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(c_out)
|
||||
if is_downsample:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
|
||||
nn.BatchNorm2d(c_out)
|
||||
)
|
||||
elif c_in != c_out:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
|
||||
nn.BatchNorm2d(c_out)
|
||||
)
|
||||
self.is_downsample = True
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv1(x)
|
||||
y = self.bn1(y)
|
||||
y = self.relu(y)
|
||||
y = self.conv2(y)
|
||||
y = self.bn2(y)
|
||||
if self.is_downsample:
|
||||
x = self.downsample(x)
|
||||
return F.relu(x.add(y), True)
|
||||
|
||||
|
||||
def make_layers(c_in, c_out, repeat_times, is_downsample=False):
|
||||
blocks = []
|
||||
for i in range(repeat_times):
|
||||
if i == 0:
|
||||
blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
|
||||
else:
|
||||
blocks += [BasicBlock(c_out, c_out), ]
|
||||
return nn.Sequential(*blocks)
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, num_classes=751, reid=False):
|
||||
super(Net, self).__init__()
|
||||
# 3 128 64
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(3, 64, 3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
# nn.Conv2d(32,32,3,stride=1,padding=1),
|
||||
# nn.BatchNorm2d(32),
|
||||
# nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(3, 2, padding=1),
|
||||
)
|
||||
# 32 64 32
|
||||
self.layer1 = make_layers(64, 64, 2, False)
|
||||
# 32 64 32
|
||||
self.layer2 = make_layers(64, 128, 2, True)
|
||||
# 64 32 16
|
||||
self.layer3 = make_layers(128, 256, 2, True)
|
||||
# 128 16 8
|
||||
self.layer4 = make_layers(256, 512, 2, True)
|
||||
# 256 8 4
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
# 256 1 1
|
||||
self.reid = reid
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(512, 256),
|
||||
nn.BatchNorm1d(256),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(),
|
||||
nn.Linear(256, num_classes),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
# B x 128
|
||||
if self.reid:
|
||||
x = x.div(x.norm(p=2, dim=1, keepdim=True))
|
||||
return x
|
||||
# classifier
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = Net()
|
||||
x = torch.randn(4, 3, 128, 64)
|
||||
y = net(x)
|
||||
|
@ -1,67 +0,0 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ['RANK'])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
|
||||
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def cleanup():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def reduce_value(value, average=True):
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return value
|
||||
with torch.no_grad():
|
||||
dist.all_reduce(value)
|
||||
if average:
|
||||
value /= world_size
|
||||
|
||||
return value
|
@ -1,90 +0,0 @@
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from .distributed_utils import reduce_value, is_main_process
|
||||
|
||||
|
||||
def load_model(state_dict, model_state_dict, model):
|
||||
for k in state_dict:
|
||||
if k in model_state_dict:
|
||||
if state_dict[k].shape != model_state_dict[k].shape:
|
||||
print('Skip loading parameter {}, required shape {}, ' \
|
||||
'loaded shape {}.'.format(
|
||||
k, model_state_dict[k].shape, state_dict[k].shape))
|
||||
state_dict[k] = model_state_dict[k]
|
||||
else:
|
||||
print('Drop parameter {}.'.format(k))
|
||||
for k in model_state_dict:
|
||||
if not (k in state_dict):
|
||||
print('No param {}.'.format(k))
|
||||
state_dict[k] = model_state_dict[k]
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
def train_one_epoch(model, optimizer, data_loader, device, epoch):
|
||||
model.train()
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
mean_loss = torch.zeros(1).to(device)
|
||||
sum_num = torch.zeros(1).to(device)
|
||||
optimizer.zero_grad()
|
||||
|
||||
if is_main_process():
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
|
||||
for idx, (images, labels) in enumerate(data_loader):
|
||||
# forward
|
||||
images, labels = images.to(device), labels.to(device)
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
# backward
|
||||
loss.backward()
|
||||
loss = reduce_value(loss, average=True)
|
||||
mean_loss = (mean_loss * idx + loss.detach()) / (idx + 1)
|
||||
pred = torch.max(outputs, dim=1)[1]
|
||||
sum_num += torch.eq(pred, labels).sum()
|
||||
|
||||
if is_main_process():
|
||||
data_loader.desc = '[epoch {}] mean loss {}'.format(epoch, mean_loss.item())
|
||||
|
||||
if not torch.isfinite(loss):
|
||||
print('loss is infinite, ending training')
|
||||
sys.exit(1)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
if device != torch.device('cpu'):
|
||||
torch.cuda.synchronize(device)
|
||||
sum_num = reduce_value(sum_num, average=False)
|
||||
|
||||
return sum_num.item(), mean_loss.item()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(model, data_loader, device):
|
||||
model.eval()
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
test_loss = torch.zeros(1).to(device)
|
||||
sum_num = torch.zeros(1).to(device)
|
||||
if is_main_process():
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
|
||||
for idx, (inputs, labels) in enumerate(data_loader):
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss = reduce_value(loss, average=True)
|
||||
|
||||
test_loss = (test_loss * idx + loss.detach()) / (idx + 1)
|
||||
pred = torch.max(outputs, dim=1)[1]
|
||||
sum_num += torch.eq(pred, labels).sum()
|
||||
|
||||
if device != torch.device('cpu'):
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
sum_num = reduce_value(sum_num, average=False)
|
||||
|
||||
return sum_num.item(), test_loss.item()
|
@ -1,173 +0,0 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3,
|
||||
stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channel)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3,
|
||||
stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channel)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_channel, out_channel, stride=1, downsample=None,
|
||||
groups=1, width_per_group=64):
|
||||
super(Bottleneck, self).__init__()
|
||||
width = int(out_channel * (width_per_group / 64.)) * groups
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width, kernel_size=1,
|
||||
stride=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width)
|
||||
self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3,
|
||||
stride=stride, padding=1, bias=False, groups=groups)
|
||||
self.bn2 = nn.BatchNorm2d(width)
|
||||
self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,
|
||||
kernel_size=1, stride=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, blocks_num, reid=False, num_classes=1000, groups=1, width_per_group=64):
|
||||
super(ResNet, self).__init__()
|
||||
self.reid = reid
|
||||
self.in_channel = 64
|
||||
|
||||
self.groups = groups
|
||||
self.width_per_group = width_per_group
|
||||
|
||||
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
|
||||
padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(self.in_channel)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layers(block, 64, blocks_num[0])
|
||||
self.layer2 = self._make_layers(block, 128, blocks_num[1], stride=2)
|
||||
self.layer3 = self._make_layers(block, 256, blocks_num[2], stride=2)
|
||||
# self.layer4 = self._make_layers(block, 512, blocks_num[3], stride=1)
|
||||
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(256 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layers(self, block, channel, block_num, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.in_channel != channel * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(channel * block.expansion)
|
||||
)
|
||||
layers = []
|
||||
layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride,
|
||||
groups=self.groups, width_per_group=self.width_per_group))
|
||||
self.in_channel = channel * block.expansion
|
||||
|
||||
for _ in range(1, block_num):
|
||||
layers.append(block(self.in_channel, channel, groups=self.groups, width_per_group=self.width_per_group))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
# x = self.layer4(x)
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
|
||||
# B x 512
|
||||
if self.reid:
|
||||
x = x.div(x.norm(p=2, dim=1, keepdim=True))
|
||||
return x
|
||||
# classifier
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def resnet18(num_classes=1000, reid=False):
|
||||
# https://download.pytorch.org/models/resnet18-5c106cde.pth
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, reid=reid)
|
||||
|
||||
|
||||
def resnet34(num_classes=1000, reid=False):
|
||||
# https://download.pytorch.org/models/resnet34-333f7ec4.pth
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, reid=reid)
|
||||
|
||||
|
||||
def resnet50(num_classes=1000, reid=False):
|
||||
# https://download.pytorch.org/models/resnet50-19c8e357.pth
|
||||
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, reid=reid)
|
||||
|
||||
|
||||
def resnext50_32x4d(num_classes=1000, reid=False):
|
||||
# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
|
||||
groups = 32
|
||||
width_per_group = 4
|
||||
return ResNet(Bottleneck, [3, 4, 6, 3], reid=reid,
|
||||
num_classes=num_classes, groups=groups, width_per_group=width_per_group)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = resnet18(reid=True)
|
||||
x = torch.randn(4, 3, 128, 64)
|
||||
y = net(x)
|
@ -1,77 +0,0 @@
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from model import Net
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train on market1501")
|
||||
parser.add_argument("--data-dir", default='data', type=str)
|
||||
parser.add_argument("--no-cuda", action="store_true")
|
||||
parser.add_argument("--gpu-id", default=0, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
# device
|
||||
device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
cudnn.benchmark = True
|
||||
|
||||
# data loader
|
||||
root = args.data_dir
|
||||
query_dir = os.path.join(root, "query")
|
||||
gallery_dir = os.path.join(root, "gallery")
|
||||
transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize((128, 64)),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
queryloader = torch.utils.data.DataLoader(
|
||||
torchvision.datasets.ImageFolder(query_dir, transform=transform),
|
||||
batch_size=64, shuffle=False
|
||||
)
|
||||
galleryloader = torch.utils.data.DataLoader(
|
||||
torchvision.datas0ets.ImageFolder(gallery_dir, transform=transform),
|
||||
batch_size=64, shuffle=False
|
||||
)
|
||||
|
||||
# net definition
|
||||
net = Net(reid=True)
|
||||
assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
|
||||
print('Loading from checkpoint/ckpt.t7')
|
||||
checkpoint = torch.load("./checkpoint/ckpt.t7")
|
||||
net_dict = checkpoint['net_dict']
|
||||
net.load_state_dict(net_dict, strict=False)
|
||||
net.eval()
|
||||
net.to(device)
|
||||
|
||||
# compute features
|
||||
query_features = torch.tensor([]).float()
|
||||
query_labels = torch.tensor([]).long()
|
||||
gallery_features = torch.tensor([]).float()
|
||||
gallery_labels = torch.tensor([]).long()
|
||||
|
||||
with torch.no_grad():
|
||||
for idx, (inputs, labels) in enumerate(queryloader):
|
||||
inputs = inputs.to(device)
|
||||
features = net(inputs).cpu()
|
||||
query_features = torch.cat((query_features, features), dim=0)
|
||||
query_labels = torch.cat((query_labels, labels))
|
||||
|
||||
for idx, (inputs, labels) in enumerate(galleryloader):
|
||||
inputs = inputs.to(device)
|
||||
features = net(inputs).cpu()
|
||||
gallery_features = torch.cat((gallery_features, features), dim=0)
|
||||
gallery_labels = torch.cat((gallery_labels, labels))
|
||||
|
||||
gallery_labels -= 2
|
||||
|
||||
# save features
|
||||
features = {
|
||||
"qf": query_features,
|
||||
"ql": query_labels,
|
||||
"gf": gallery_features,
|
||||
"gl": gallery_labels
|
||||
}
|
||||
torch.save(features, "features.pth")
|
Binary file not shown.
Before Width: | Height: | Size: 59 KiB |
@ -1,151 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
from multi_train_utils.distributed_utils import init_distributed_mode, cleanup
|
||||
from multi_train_utils.train_eval_utils import train_one_epoch, evaluate, load_model
|
||||
import torch.distributed as dist
|
||||
from datasets import ClsDataset, read_split_data
|
||||
|
||||
from model import Net
|
||||
from resnet import resnet18
|
||||
|
||||
# plot figure
|
||||
x_epoch = []
|
||||
record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
|
||||
fig = plt.figure()
|
||||
ax0 = fig.add_subplot(121, title="loss")
|
||||
ax1 = fig.add_subplot(122, title="top1_err")
|
||||
|
||||
|
||||
def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
|
||||
global record
|
||||
record['train_loss'].append(train_loss)
|
||||
record['train_err'].append(train_err)
|
||||
record['test_loss'].append(test_loss)
|
||||
record['test_err'].append(test_err)
|
||||
|
||||
x_epoch.append(epoch)
|
||||
ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
|
||||
ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
|
||||
ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
|
||||
ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
|
||||
if epoch == 0:
|
||||
ax0.legend()
|
||||
ax1.legend()
|
||||
fig.savefig("train.jpg")
|
||||
|
||||
|
||||
def main(args):
|
||||
batch_size = args.batch_size
|
||||
device = 'cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
train_info, val_info, num_classes = read_split_data(args.data_dir, valid_rate=0.2)
|
||||
train_images_path, train_labels = train_info
|
||||
val_images_path, val_labels = val_info
|
||||
|
||||
transform_train = torchvision.transforms.Compose([
|
||||
torchvision.transforms.RandomCrop((128, 64), padding=4),
|
||||
torchvision.transforms.RandomHorizontalFlip(),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
transform_val = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize((128, 64)),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
train_dataset = ClsDataset(
|
||||
images_path=train_images_path,
|
||||
images_labels=train_labels,
|
||||
transform=transform_train
|
||||
)
|
||||
val_dataset = ClsDataset(
|
||||
images_path=val_images_path,
|
||||
images_labels=val_labels,
|
||||
transform=transform_val
|
||||
)
|
||||
|
||||
number_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
|
||||
print('Using {} dataloader workers every process'.format(number_workers))
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
num_workers=number_workers
|
||||
)
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
num_workers=number_workers,
|
||||
)
|
||||
|
||||
# net definition
|
||||
start_epoch = 0
|
||||
net = Net(num_classes=num_classes)
|
||||
if args.weights:
|
||||
print('Loading from ', args.weights)
|
||||
checkpoint = torch.load(args.weights, map_location='cpu')
|
||||
net_dict = checkpoint if 'net_dict' not in checkpoint else checkpoint['net_dict']
|
||||
start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else start_epoch
|
||||
net = load_model(net_dict, net.state_dict(), net)
|
||||
|
||||
if args.freeze_layers:
|
||||
for name, param in net.named_parameters():
|
||||
if 'classifier' not in name:
|
||||
param.requires_grad = False
|
||||
|
||||
net.to(device)
|
||||
|
||||
# loss and optimizer
|
||||
pg = [p for p in net.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.SGD(pg, args.lr, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
lr = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr)
|
||||
for epoch in range(start_epoch, start_epoch + args.epochs):
|
||||
train_positive, train_loss = train_one_epoch(net, optimizer, train_loader, device, epoch)
|
||||
train_acc = train_positive / len(train_dataset)
|
||||
scheduler.step()
|
||||
|
||||
test_positive, test_loss = evaluate(net, val_loader, device)
|
||||
test_acc = test_positive / len(val_dataset)
|
||||
|
||||
print('[epoch {}] accuracy: {}'.format(epoch, test_acc))
|
||||
|
||||
state_dict = {
|
||||
'net_dict': net.state_dict(),
|
||||
'acc': test_acc,
|
||||
'epoch': epoch
|
||||
}
|
||||
torch.save(state_dict, './checkpoint/model_{}.pth'.format(epoch))
|
||||
draw_curve(epoch, train_loss, 1 - train_acc, test_loss, 1 - test_acc)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Train on market1501")
|
||||
parser.add_argument("--data-dir", default='data', type=str)
|
||||
parser.add_argument('--epochs', type=int, default=40)
|
||||
parser.add_argument('--batch_size', type=int, default=32)
|
||||
parser.add_argument("--lr", default=0.001, type=float)
|
||||
parser.add_argument('--lrf', default=0.1, type=float)
|
||||
|
||||
parser.add_argument('--weights', type=str, default='./checkpoint/resnet18.pth')
|
||||
parser.add_argument('--freeze-layers', action='store_true')
|
||||
|
||||
parser.add_argument('--gpu_id', default='0', help='gpu id')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
@ -1,189 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
from multi_train_utils.distributed_utils import init_distributed_mode, cleanup
|
||||
from multi_train_utils.train_eval_utils import train_one_epoch, evaluate, load_model
|
||||
import torch.distributed as dist
|
||||
from datasets import ClsDataset, read_split_data
|
||||
|
||||
from resnet import resnet18
|
||||
|
||||
|
||||
# plot figure
|
||||
x_epoch = []
|
||||
record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
|
||||
fig = plt.figure()
|
||||
ax0 = fig.add_subplot(121, title="loss")
|
||||
ax1 = fig.add_subplot(122, title="top1_err")
|
||||
|
||||
|
||||
def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
|
||||
global record
|
||||
record['train_loss'].append(train_loss)
|
||||
record['train_err'].append(train_err)
|
||||
record['test_loss'].append(test_loss)
|
||||
record['test_err'].append(test_err)
|
||||
|
||||
x_epoch.append(epoch)
|
||||
ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
|
||||
ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
|
||||
ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
|
||||
ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
|
||||
if epoch == 0:
|
||||
ax0.legend()
|
||||
ax1.legend()
|
||||
fig.savefig("train.jpg")
|
||||
|
||||
|
||||
def main(args):
|
||||
init_distributed_mode(args)
|
||||
|
||||
rank = args.rank
|
||||
device = torch.device(args.device)
|
||||
batch_size = args.batch_size
|
||||
weights_path = args.weights
|
||||
args.lr *= args.world_size
|
||||
checkpoint_path = ''
|
||||
|
||||
if rank == 0:
|
||||
print(args)
|
||||
if os.path.exists('./checkpoint') is False:
|
||||
os.mkdir('./checkpoint')
|
||||
|
||||
train_info, val_info, num_classes = read_split_data(args.data_dir, valid_rate=0.2)
|
||||
train_images_path, train_labels = train_info
|
||||
val_images_path, val_labels = val_info
|
||||
|
||||
transform_train = torchvision.transforms.Compose([
|
||||
torchvision.transforms.RandomCrop((128, 64), padding=4),
|
||||
torchvision.transforms.RandomHorizontalFlip(),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
transform_val = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize((128, 64)),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
train_dataset = ClsDataset(
|
||||
images_path=train_images_path,
|
||||
images_labels=train_labels,
|
||||
transform=transform_train
|
||||
)
|
||||
val_dataset = ClsDataset(
|
||||
images_path=val_images_path,
|
||||
images_labels=val_labels,
|
||||
transform=transform_val
|
||||
)
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
||||
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
|
||||
|
||||
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
|
||||
|
||||
number_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
|
||||
|
||||
if rank == 0:
|
||||
print('Using {} dataloader workers every process'.format(number_workers))
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=train_batch_sampler,
|
||||
pin_memory=True,
|
||||
num_workers=number_workers
|
||||
)
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
sampler=val_sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=True,
|
||||
num_workers=number_workers,
|
||||
)
|
||||
|
||||
# net definition
|
||||
start_epoch = 0
|
||||
net = resnet18(num_classes=num_classes)
|
||||
if args.weights:
|
||||
print('Loading from ', args.weights)
|
||||
checkpoint = torch.load(args.weights, map_location='cpu')
|
||||
net_dict = checkpoint if 'net_dict' not in checkpoint else checkpoint['net_dict']
|
||||
start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else start_epoch
|
||||
net = load_model(net_dict, net.state_dict(), net)
|
||||
else:
|
||||
warnings.warn("better providing pretraining weights")
|
||||
checkpoint_path = os.path.join(tempfile.gettempdir(), 'initial_weights.pth')
|
||||
if rank == 0:
|
||||
torch.save(net.state_dict(), checkpoint_path)
|
||||
|
||||
dist.barrier()
|
||||
net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
|
||||
|
||||
if args.freeze_layers:
|
||||
for name, param in net.named_parameters():
|
||||
if 'fc' not in name:
|
||||
param.requires_grad = False
|
||||
else:
|
||||
if args.syncBN:
|
||||
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
|
||||
net.to(device)
|
||||
|
||||
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
|
||||
|
||||
# loss and optimizer
|
||||
pg = [p for p in net.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.SGD(pg, args.lr, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
lr = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr)
|
||||
for epoch in range(start_epoch, start_epoch + args.epochs):
|
||||
train_positive, train_loss = train_one_epoch(net, optimizer, train_loader, device, epoch)
|
||||
train_acc = train_positive / len(train_dataset)
|
||||
scheduler.step()
|
||||
|
||||
test_positive, test_loss = evaluate(net, val_loader, device)
|
||||
test_acc = test_positive / len(val_dataset)
|
||||
|
||||
if rank == 0:
|
||||
print('[epoch {}] accuracy: {}'.format(epoch, test_acc))
|
||||
|
||||
state_dict = {
|
||||
'net_dict': net.module.state_dict(),
|
||||
'acc': test_acc,
|
||||
'epoch': epoch
|
||||
}
|
||||
torch.save(state_dict, './checkpoint/model_{}.pth'.format(epoch))
|
||||
draw_curve(epoch, train_loss, 1 - train_acc, test_loss, 1 - test_acc)
|
||||
|
||||
if rank == 0:
|
||||
if os.path.exists(checkpoint_path) is True:
|
||||
os.remove(checkpoint_path)
|
||||
cleanup()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Train on market1501")
|
||||
parser.add_argument("--data-dir", default='data', type=str)
|
||||
parser.add_argument('--epochs', type=int, default=40)
|
||||
parser.add_argument('--batch_size', type=int, default=32)
|
||||
parser.add_argument("--lr", default=0.001, type=float)
|
||||
parser.add_argument('--lrf', default=0.1, type=float)
|
||||
parser.add_argument('--syncBN', type=bool, default=True)
|
||||
|
||||
parser.add_argument('--weights', type=str, default='./checkpoint/resnet18.pth')
|
||||
parser.add_argument('--freeze-layers', action='store_true')
|
||||
|
||||
# not change the following parameters, the system will automatically assignment
|
||||
parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0, 1 or cpu)')
|
||||
parser.add_argument('--world_size', default=4, type=int, help='number of distributed processes')
|
||||
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
@ -1,121 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .deep.feature_extractor import Extractor, FastReIDExtractor
|
||||
from .sort.nn_matching import NearestNeighborDistanceMetric
|
||||
from .sort.preprocessing import non_max_suppression
|
||||
from .sort.detection import Detection
|
||||
from .sort.tracker import Tracker
|
||||
|
||||
__all__ = ['DeepSort']
|
||||
|
||||
|
||||
class DeepSort(object):
|
||||
def __init__(self, model_path, model_config=None, max_dist=0.2, min_confidence=0.3, nms_max_overlap=1.0,
|
||||
max_iou_distance=0.7, max_age=70, n_init=3, nn_budget=100, use_cuda=True):
|
||||
self.min_confidence = min_confidence
|
||||
self.nms_max_overlap = nms_max_overlap
|
||||
|
||||
if model_config is None:
|
||||
self.extractor = Extractor(model_path, use_cuda=use_cuda)
|
||||
else:
|
||||
self.extractor = FastReIDExtractor(model_config, model_path, use_cuda=use_cuda)
|
||||
|
||||
max_cosine_distance = max_dist
|
||||
metric = NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
|
||||
self.tracker = Tracker(metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init)
|
||||
|
||||
def update(self, bbox_xywh, confidences, classes, ori_img, masks=None):
|
||||
self.height, self.width = ori_img.shape[:2]
|
||||
# generate detections
|
||||
features = self._get_features(bbox_xywh, ori_img)
|
||||
bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
|
||||
detections = [Detection(bbox_tlwh[i], conf, label, features[i], None if masks is None else masks[i])
|
||||
for i, (conf, label) in enumerate(zip(confidences, classes))
|
||||
if conf > self.min_confidence]
|
||||
|
||||
# run on non-maximum supression
|
||||
boxes = np.array([d.tlwh for d in detections])
|
||||
scores = np.array([d.confidence for d in detections])
|
||||
indices = non_max_suppression(boxes, self.nms_max_overlap, scores)
|
||||
detections = [detections[i] for i in indices]
|
||||
|
||||
# update tracker
|
||||
self.tracker.predict()
|
||||
self.tracker.update(detections)
|
||||
|
||||
# output bbox identities
|
||||
outputs = []
|
||||
mask_outputs = []
|
||||
for track in self.tracker.tracks:
|
||||
if not track.is_confirmed() or track.time_since_update > 1:
|
||||
continue
|
||||
box = track.to_tlwh()
|
||||
x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
|
||||
track_id = track.track_id
|
||||
track_cls = track.cls
|
||||
outputs.append(np.array([x1, y1, x2, y2, track_cls, track_id], dtype=np.int32))
|
||||
if track.mask is not None:
|
||||
mask_outputs.append(track.mask)
|
||||
if len(outputs) > 0:
|
||||
outputs = np.stack(outputs, axis=0)
|
||||
return outputs, mask_outputs
|
||||
|
||||
"""
|
||||
TODO:
|
||||
Convert bbox from xc_yc_w_h to xtl_ytl_w_h
|
||||
Thanks JieChen91@github.com for reporting this bug!
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _xywh_to_tlwh(bbox_xywh):
|
||||
if isinstance(bbox_xywh, np.ndarray):
|
||||
bbox_tlwh = bbox_xywh.copy()
|
||||
elif isinstance(bbox_xywh, torch.Tensor):
|
||||
bbox_tlwh = bbox_xywh.clone()
|
||||
bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2.
|
||||
bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2.
|
||||
return bbox_tlwh
|
||||
|
||||
def _xywh_to_xyxy(self, bbox_xywh):
|
||||
x, y, w, h = bbox_xywh
|
||||
x1 = max(int(x - w / 2), 0)
|
||||
x2 = min(int(x + w / 2), self.width - 1)
|
||||
y1 = max(int(y - h / 2), 0)
|
||||
y2 = min(int(y + h / 2), self.height - 1)
|
||||
return x1, y1, x2, y2
|
||||
|
||||
def _tlwh_to_xyxy(self, bbox_tlwh):
|
||||
"""
|
||||
TODO:
|
||||
Convert bbox from xtl_ytl_w_h to xc_yc_w_h
|
||||
Thanks JieChen91@github.com for reporting this bug!
|
||||
"""
|
||||
x, y, w, h = bbox_tlwh
|
||||
x1 = max(int(x), 0)
|
||||
x2 = min(int(x + w), self.width - 1)
|
||||
y1 = max(int(y), 0)
|
||||
y2 = min(int(y + h), self.height - 1)
|
||||
return x1, y1, x2, y2
|
||||
|
||||
@staticmethod
|
||||
def _xyxy_to_tlwh(bbox_xyxy):
|
||||
x1, y1, x2, y2 = bbox_xyxy
|
||||
|
||||
t = x1
|
||||
l = y1
|
||||
w = int(x2 - x1)
|
||||
h = int(y2 - y1)
|
||||
return t, l, w, h
|
||||
|
||||
def _get_features(self, bbox_xywh, ori_img):
|
||||
im_crops = []
|
||||
for box in bbox_xywh:
|
||||
x1, y1, x2, y2 = self._xywh_to_xyxy(box)
|
||||
im = ori_img[y1:y2, x1:x2]
|
||||
im_crops.append(im)
|
||||
if im_crops:
|
||||
features = self.extractor(im_crops)
|
||||
else:
|
||||
features = np.array([])
|
||||
return features
|
@ -1,51 +0,0 @@
|
||||
# vim: expandtab:ts=4:sw=4
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Detection(object):
|
||||
"""
|
||||
This class represents a bounding box detection in a single image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tlwh : array_like
|
||||
Bounding box in format `(x, y, w, h)`.
|
||||
confidence : float
|
||||
Detector confidence score.
|
||||
feature : array_like
|
||||
A feature vector that describes the object contained in this image.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
tlwh : ndarray
|
||||
Bounding box in format `(top left x, top left y, width, height)`.
|
||||
confidence : ndarray
|
||||
Detector confidence score.
|
||||
feature : ndarray | NoneType
|
||||
A feature vector that describes the object contained in this image.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, tlwh, confidence, label, feature, mask=None):
|
||||
self.tlwh = np.asarray(tlwh, dtype=np.float32)
|
||||
self.confidence = float(confidence)
|
||||
self.cls = int(label)
|
||||
self.feature = np.asarray(feature, dtype=np.float32)
|
||||
self.mask = mask
|
||||
|
||||
def to_tlbr(self):
|
||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
||||
`(top left, bottom right)`.
|
||||
"""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
def to_xyah(self):
|
||||
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
||||
height)`, where the aspect ratio is `width / height`.
|
||||
"""
|
||||
ret = self.tlwh.copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
ret[2] /= ret[3]
|
||||
return ret
|
@ -1,81 +0,0 @@
|
||||
# vim: expandtab:ts=4:sw=4
|
||||
from __future__ import absolute_import
|
||||
import numpy as np
|
||||
from . import linear_assignment
|
||||
|
||||
|
||||
def iou(bbox, candidates):
|
||||
"""Computer intersection over union.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox : ndarray
|
||||
A bounding box in format `(top left x, top left y, width, height)`.
|
||||
candidates : ndarray
|
||||
A matrix of candidate bounding boxes (one per row) in the same format
|
||||
as `bbox`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
The intersection over union in [0, 1] between the `bbox` and each
|
||||
candidate. A higher score means a larger fraction of the `bbox` is
|
||||
occluded by the candidate.
|
||||
|
||||
"""
|
||||
bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
|
||||
candidates_tl = candidates[:, :2]
|
||||
candidates_br = candidates[:, :2] + candidates[:, 2:]
|
||||
|
||||
tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
|
||||
np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
|
||||
br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
|
||||
np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
|
||||
wh = np.maximum(0., br - tl)
|
||||
|
||||
area_intersection = wh.prod(axis=1)
|
||||
area_bbox = bbox[2:].prod()
|
||||
area_candidates = candidates[:, 2:].prod(axis=1)
|
||||
return area_intersection / (area_bbox + area_candidates - area_intersection)
|
||||
|
||||
|
||||
def iou_cost(tracks, detections, track_indices=None,
|
||||
detection_indices=None):
|
||||
"""An intersection over union distance metric.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tracks : List[deep_sort.track.Track]
|
||||
A list of tracks.
|
||||
detections : List[deep_sort.detection.Detection]
|
||||
A list of detections.
|
||||
track_indices : Optional[List[int]]
|
||||
A list of indices to tracks that should be matched. Defaults to
|
||||
all `tracks`.
|
||||
detection_indices : Optional[List[int]]
|
||||
A list of indices to detections that should be matched. Defaults
|
||||
to all `detections`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns a cost matrix of shape
|
||||
len(track_indices), len(detection_indices) where entry (i, j) is
|
||||
`1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
|
||||
|
||||
"""
|
||||
if track_indices is None:
|
||||
track_indices = np.arange(len(tracks))
|
||||
if detection_indices is None:
|
||||
detection_indices = np.arange(len(detections))
|
||||
|
||||
cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
|
||||
for row, track_idx in enumerate(track_indices):
|
||||
if tracks[track_idx].time_since_update > 1:
|
||||
cost_matrix[row, :] = linear_assignment.INFTY_COST
|
||||
continue
|
||||
|
||||
bbox = tracks[track_idx].to_tlwh()
|
||||
candidates = np.asarray([detections[i].tlwh for i in detection_indices])
|
||||
cost_matrix[row, :] = 1. - iou(bbox, candidates)
|
||||
return cost_matrix
|
@ -1,231 +0,0 @@
|
||||
# vim: expandtab:ts=4:sw=4
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
|
||||
|
||||
"""
|
||||
Table for the 0.95 quantile of the chi-square distribution with N degrees of
|
||||
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
|
||||
function and used as Mahalanobis gating threshold.
|
||||
"""
|
||||
chi2inv95 = {
|
||||
1: 3.8415,
|
||||
2: 5.9915,
|
||||
3: 7.8147,
|
||||
4: 9.4877,
|
||||
5: 11.070,
|
||||
6: 12.592,
|
||||
7: 14.067,
|
||||
8: 15.507,
|
||||
9: 16.919}
|
||||
|
||||
|
||||
class KalmanFilter(object):
|
||||
"""
|
||||
A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space
|
||||
|
||||
x, y, a, h, vx, vy, va, vh
|
||||
|
||||
contains the bounding box center position (x, y), aspect ratio a, height h,
|
||||
and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location
|
||||
(x, y, a, h) is taken as direct observation of the state space (linear
|
||||
observation model).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
ndim, dt = 4, 1.
|
||||
|
||||
# Create Kalman filter model matrices.
|
||||
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
||||
for i in range(ndim):
|
||||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current
|
||||
# state estimate. These weights control the amount of uncertainty in
|
||||
# the model. This is a bit hacky.
|
||||
self._std_weight_position = 1. / 20
|
||||
self._std_weight_velocity = 1. / 160
|
||||
|
||||
def initiate(self, measurement):
|
||||
"""Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, a, h) with center position (x, y),
|
||||
aspect ratio a, and height h.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
1e-2,
|
||||
2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3],
|
||||
1e-5,
|
||||
10 * self._std_weight_velocity * measurement[3]]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the
|
||||
previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-2,
|
||||
self._std_weight_position * mean[3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[3],
|
||||
1e-5,
|
||||
self._std_weight_velocity * mean[3]]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
mean = np.dot(self._motion_mat, mean)
|
||||
covariance = np.linalg.multi_dot((
|
||||
self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
"""Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The state's mean vector (8 dimensional array).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state
|
||||
estimate.
|
||||
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[3],
|
||||
1e-1,
|
||||
self._std_weight_position * mean[3]]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((
|
||||
self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
"""Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The predicted state's mean vector (8 dimensional).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
|
||||
is the center position, a the aspect ratio, and h the height of the
|
||||
bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
chol_factor, lower = scipy.linalg.cho_factor(
|
||||
projected_cov, lower=True, check_finite=False)
|
||||
kalman_gain = scipy.linalg.cho_solve(
|
||||
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
|
||||
check_finite=False).T
|
||||
innovation = measurement - projected_mean
|
||||
|
||||
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
||||
# new_covariance = covariance - np.linalg.multi_dot((
|
||||
# kalman_gain, projected_cov, kalman_gain.T))
|
||||
new_covariance = covariance - np.linalg.multi_dot((
|
||||
kalman_gain, self._update_mat, covariance))
|
||||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(self, mean, covariance, measurements,
|
||||
only_position=False):
|
||||
"""Compute gating distance between state distribution and measurements.
|
||||
|
||||
A suitable distance threshold can be obtained from `chi2inv95`. If
|
||||
`only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector over the state distribution (8 dimensional).
|
||||
covariance : ndarray
|
||||
Covariance of the state distribution (8x8 dimensional).
|
||||
measurements : ndarray
|
||||
An Nx4 dimensional matrix of N measurements, each in
|
||||
format (x, y, a, h) where (x, y) is the bounding box center
|
||||
position, a the aspect ratio, and h the height.
|
||||
only_position : Optional[bool]
|
||||
If True, distance computation is done with respect to the bounding
|
||||
box center position only.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns an array of length N, where the i-th element contains the
|
||||
squared Mahalanobis distance between (mean, covariance) and
|
||||
`measurements[i]`.
|
||||
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
mean, covariance = mean[:2], covariance[:2, :2]
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
d = measurements - mean
|
||||
z = scipy.linalg.solve_triangular(
|
||||
cholesky_factor, d.T, lower=True, check_finite=False,
|
||||
overwrite_b=True)
|
||||
squared_maha = np.sum(z * z, axis=0)
|
||||
return squared_maha
|
@ -1,192 +0,0 @@
|
||||
# vim: expandtab:ts=4:sw=4
|
||||
from __future__ import absolute_import
|
||||
import numpy as np
|
||||
# from sklearn.utils.linear_assignment_ import linear_assignment
|
||||
from scipy.optimize import linear_sum_assignment as linear_assignment
|
||||
from . import kalman_filter
|
||||
|
||||
|
||||
INFTY_COST = 1e+5
|
||||
|
||||
|
||||
def min_cost_matching(
|
||||
distance_metric, max_distance, tracks, detections, track_indices=None,
|
||||
detection_indices=None):
|
||||
"""Solve linear assignment problem.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
|
||||
The distance metric is given a list of tracks and detections as well as
|
||||
a list of N track indices and M detection indices. The metric should
|
||||
return the NxM dimensional cost matrix, where element (i, j) is the
|
||||
association cost between the i-th track in the given track indices and
|
||||
the j-th detection in the given detection_indices.
|
||||
max_distance : float
|
||||
Gating threshold. Associations with cost larger than this value are
|
||||
disregarded.
|
||||
tracks : List[track.Track]
|
||||
A list of predicted tracks at the current time step.
|
||||
detections : List[detection.Detection]
|
||||
A list of detections at the current time step.
|
||||
track_indices : List[int]
|
||||
List of track indices that maps rows in `cost_matrix` to tracks in
|
||||
`tracks` (see description above).
|
||||
detection_indices : List[int]
|
||||
List of detection indices that maps columns in `cost_matrix` to
|
||||
detections in `detections` (see description above).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(List[(int, int)], List[int], List[int])
|
||||
Returns a tuple with the following three entries:
|
||||
* A list of matched track and detection indices.
|
||||
* A list of unmatched track indices.
|
||||
* A list of unmatched detection indices.
|
||||
|
||||
"""
|
||||
if track_indices is None:
|
||||
track_indices = np.arange(len(tracks))
|
||||
if detection_indices is None:
|
||||
detection_indices = np.arange(len(detections))
|
||||
|
||||
if len(detection_indices) == 0 or len(track_indices) == 0:
|
||||
return [], track_indices, detection_indices # Nothing to match.
|
||||
|
||||
cost_matrix = distance_metric(
|
||||
tracks, detections, track_indices, detection_indices)
|
||||
cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
|
||||
|
||||
row_indices, col_indices = linear_assignment(cost_matrix)
|
||||
|
||||
matches, unmatched_tracks, unmatched_detections = [], [], []
|
||||
for col, detection_idx in enumerate(detection_indices):
|
||||
if col not in col_indices:
|
||||
unmatched_detections.append(detection_idx)
|
||||
for row, track_idx in enumerate(track_indices):
|
||||
if row not in row_indices:
|
||||
unmatched_tracks.append(track_idx)
|
||||
for row, col in zip(row_indices, col_indices):
|
||||
track_idx = track_indices[row]
|
||||
detection_idx = detection_indices[col]
|
||||
if cost_matrix[row, col] > max_distance:
|
||||
unmatched_tracks.append(track_idx)
|
||||
unmatched_detections.append(detection_idx)
|
||||
else:
|
||||
matches.append((track_idx, detection_idx))
|
||||
return matches, unmatched_tracks, unmatched_detections
|
||||
|
||||
|
||||
def matching_cascade(
|
||||
distance_metric, max_distance, cascade_depth, tracks, detections,
|
||||
track_indices=None, detection_indices=None):
|
||||
"""Run matching cascade.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
|
||||
The distance metric is given a list of tracks and detections as well as
|
||||
a list of N track indices and M detection indices. The metric should
|
||||
return the NxM dimensional cost matrix, where element (i, j) is the
|
||||
association cost between the i-th track in the given track indices and
|
||||
the j-th detection in the given detection indices.
|
||||
max_distance : float
|
||||
Gating threshold. Associations with cost larger than this value are
|
||||
disregarded.
|
||||
cascade_depth: int
|
||||
The cascade depth, should be se to the maximum track age.
|
||||
tracks : List[track.Track]
|
||||
A list of predicted tracks at the current time step.
|
||||
detections : List[detection.Detection]
|
||||
A list of detections at the current time step.
|
||||
track_indices : Optional[List[int]]
|
||||
List of track indices that maps rows in `cost_matrix` to tracks in
|
||||
`tracks` (see description above). Defaults to all tracks.
|
||||
detection_indices : Optional[List[int]]
|
||||
List of detection indices that maps columns in `cost_matrix` to
|
||||
detections in `detections` (see description above). Defaults to all
|
||||
detections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(List[(int, int)], List[int], List[int])
|
||||
Returns a tuple with the following three entries:
|
||||
* A list of matched track and detection indices.
|
||||
* A list of unmatched track indices.
|
||||
* A list of unmatched detection indices.
|
||||
|
||||
"""
|
||||
if track_indices is None:
|
||||
track_indices = list(range(len(tracks)))
|
||||
if detection_indices is None:
|
||||
detection_indices = list(range(len(detections)))
|
||||
|
||||
unmatched_detections = detection_indices
|
||||
matches = []
|
||||
for level in range(cascade_depth):
|
||||
if len(unmatched_detections) == 0: # No detections left
|
||||
break
|
||||
|
||||
track_indices_l = [
|
||||
k for k in track_indices
|
||||
if tracks[k].time_since_update == 1 + level
|
||||
]
|
||||
if len(track_indices_l) == 0: # Nothing to match at this level
|
||||
continue
|
||||
|
||||
matches_l, _, unmatched_detections = \
|
||||
min_cost_matching(
|
||||
distance_metric, max_distance, tracks, detections,
|
||||
track_indices_l, unmatched_detections)
|
||||
matches += matches_l
|
||||
unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
|
||||
return matches, unmatched_tracks, unmatched_detections
|
||||
|
||||
|
||||
def gate_cost_matrix(
|
||||
kf, cost_matrix, tracks, detections, track_indices, detection_indices,
|
||||
gated_cost=INFTY_COST, only_position=False):
|
||||
"""Invalidate infeasible entries in cost matrix based on the state
|
||||
distributions obtained by Kalman filtering.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kf : The Kalman filter.
|
||||
cost_matrix : ndarray
|
||||
The NxM dimensional cost matrix, where N is the number of track indices
|
||||
and M is the number of detection indices, such that entry (i, j) is the
|
||||
association cost between `tracks[track_indices[i]]` and
|
||||
`detections[detection_indices[j]]`.
|
||||
tracks : List[track.Track]
|
||||
A list of predicted tracks at the current time step.
|
||||
detections : List[detection.Detection]
|
||||
A list of detections at the current time step.
|
||||
track_indices : List[int]
|
||||
List of track indices that maps rows in `cost_matrix` to tracks in
|
||||
`tracks` (see description above).
|
||||
detection_indices : List[int]
|
||||
List of detection indices that maps columns in `cost_matrix` to
|
||||
detections in `detections` (see description above).
|
||||
gated_cost : Optional[float]
|
||||
Entries in the cost matrix corresponding to infeasible associations are
|
||||
set this value. Defaults to a very large value.
|
||||
only_position : Optional[bool]
|
||||
If True, only the x, y position of the state distribution is considered
|
||||
during gating. Defaults to False.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns the modified cost matrix.
|
||||
|
||||
"""
|
||||
gating_dim = 2 if only_position else 4
|
||||
gating_threshold = kalman_filter.chi2inv95[gating_dim]
|
||||
measurements = np.asarray(
|
||||
[detections[i].to_xyah() for i in detection_indices])
|
||||
for row, track_idx in enumerate(track_indices):
|
||||
track = tracks[track_idx]
|
||||
gating_distance = kf.gating_distance(
|
||||
track.mean, track.covariance, measurements, only_position)
|
||||
cost_matrix[row, gating_distance > gating_threshold] = gated_cost
|
||||
return cost_matrix
|
@ -1,176 +0,0 @@
|
||||
# vim: expandtab:ts=4:sw=4
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _pdist(a, b):
|
||||
"""Compute pair-wise squared distance between points in `a` and `b`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : array_like
|
||||
An NxM matrix of N samples of dimensionality M.
|
||||
b : array_like
|
||||
An LxM matrix of L samples of dimensionality M.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns a matrix of size len(a), len(b) such that eleement (i, j)
|
||||
contains the squared distance between `a[i]` and `b[j]`.
|
||||
|
||||
"""
|
||||
a, b = np.asarray(a), np.asarray(b)
|
||||
if len(a) == 0 or len(b) == 0:
|
||||
return np.zeros((len(a), len(b)))
|
||||
a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
|
||||
r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
|
||||
r2 = np.clip(r2, 0., float(np.inf))
|
||||
return r2
|
||||
|
||||
|
||||
def _cosine_distance(a, b, data_is_normalized=False):
|
||||
"""Compute pair-wise cosine distance between points in `a` and `b`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : array_like
|
||||
An NxM matrix of N samples of dimensionality M.
|
||||
b : array_like
|
||||
An LxM matrix of L samples of dimensionality M.
|
||||
data_is_normalized : Optional[bool]
|
||||
If True, assumes rows in a and b are unit length vectors.
|
||||
Otherwise, a and b are explicitly normalized to lenght 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns a matrix of size len(a), len(b) such that eleement (i, j)
|
||||
contains the squared distance between `a[i]` and `b[j]`.
|
||||
|
||||
"""
|
||||
if not data_is_normalized:
|
||||
a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
|
||||
b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
|
||||
return 1. - np.dot(a, b.T)
|
||||
|
||||
|
||||
def _nn_euclidean_distance(x, y):
|
||||
""" Helper function for nearest neighbor distance metric (Euclidean).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : ndarray
|
||||
A matrix of N row-vectors (sample points).
|
||||
y : ndarray
|
||||
A matrix of M row-vectors (query points).
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
A vector of length M that contains for each entry in `y` the
|
||||
smallest Euclidean distance to a sample in `x`.
|
||||
|
||||
"""
|
||||
distances = _pdist(x, y)
|
||||
return np.maximum(0.0, distances.min(axis=0))
|
||||
|
||||
|
||||
def _nn_cosine_distance(x, y):
|
||||
""" Helper function for nearest neighbor distance metric (cosine).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : ndarray
|
||||
A matrix of N row-vectors (sample points).
|
||||
y : ndarray
|
||||
A matrix of M row-vectors (query points).
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
A vector of length M that contains for each entry in `y` the
|
||||
smallest cosine distance to a sample in `x`.
|
||||
|
||||
"""
|
||||
distances = _cosine_distance(x, y)
|
||||
return distances.min(axis=0)
|
||||
|
||||
|
||||
class NearestNeighborDistanceMetric(object):
|
||||
"""
|
||||
A nearest neighbor distance metric that, for each target, returns
|
||||
the closest distance to any sample that has been observed so far.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric : str
|
||||
Either "euclidean" or "cosine".
|
||||
matching_threshold: float
|
||||
The matching threshold. Samples with larger distance are considered an
|
||||
invalid match.
|
||||
budget : Optional[int]
|
||||
If not None, fix samples per class to at most this number. Removes
|
||||
the oldest samples when the budget is reached.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
samples : Dict[int -> List[ndarray]]
|
||||
A dictionary that maps from target identities to the list of samples
|
||||
that have been observed so far.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, metric, matching_threshold, budget=None):
|
||||
|
||||
if metric == "euclidean":
|
||||
self._metric = _nn_euclidean_distance
|
||||
elif metric == "cosine":
|
||||
self._metric = _nn_cosine_distance
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid metric; must be either 'euclidean' or 'cosine'")
|
||||
self.matching_threshold = matching_threshold
|
||||
self.budget = budget
|
||||
self.samples = {}
|
||||
|
||||
def partial_fit(self, features, targets, active_targets):
|
||||
"""Update the distance metric with new data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : ndarray
|
||||
An NxM matrix of N features of dimensionality M.
|
||||
targets : ndarray
|
||||
An integer array of associated target identities.
|
||||
active_targets : List[int]
|
||||
A list of targets that are currently present in the scene.
|
||||
|
||||
"""
|
||||
for feature, target in zip(features, targets):
|
||||
self.samples.setdefault(target, []).append(feature)
|
||||
if self.budget is not None:
|
||||
self.samples[target] = self.samples[target][-self.budget:]
|
||||
self.samples = {k: self.samples[k] for k in active_targets}
|
||||
|
||||
def distance(self, features, targets):
|
||||
"""Compute distance between features and targets.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : ndarray
|
||||
An NxM matrix of N features of dimensionality M.
|
||||
targets : List[int]
|
||||
A list of targets to match the given `features` against.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns a cost matrix of shape len(targets), len(features), where
|
||||
element (i, j) contains the closest squared distance between
|
||||
`targets[i]` and `features[j]`.
|
||||
|
||||
"""
|
||||
cost_matrix = np.zeros((len(targets), len(features)))
|
||||
for i, target in enumerate(targets):
|
||||
cost_matrix[i, :] = self._metric(self.samples[target], features)
|
||||
return cost_matrix
|
@ -1,73 +0,0 @@
|
||||
# vim: expandtab:ts=4:sw=4
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def non_max_suppression(boxes, max_bbox_overlap, scores=None):
|
||||
"""Suppress overlapping detections.
|
||||
|
||||
Original code from [1]_ has been adapted to include confidence score.
|
||||
|
||||
.. [1] http://www.pyimagesearch.com/2015/02/16/
|
||||
faster-non-maximum-suppression-python/
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> boxes = [d.roi for d in detections]
|
||||
>>> scores = [d.confidence for d in detections]
|
||||
>>> indices = non_max_suppression(boxes, max_bbox_overlap, scores)
|
||||
>>> detections = [detections[i] for i in indices]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
boxes : ndarray
|
||||
Array of ROIs (x, y, width, height).
|
||||
max_bbox_overlap : float
|
||||
ROIs that overlap more than this values are suppressed.
|
||||
scores : Optional[array_like]
|
||||
Detector confidence score.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[int]
|
||||
Returns indices of detections that have survived non-maxima suppression.
|
||||
|
||||
"""
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
boxes = boxes.astype(np.float32)
|
||||
pick = []
|
||||
|
||||
x1 = boxes[:, 0]
|
||||
y1 = boxes[:, 1]
|
||||
x2 = boxes[:, 2] + boxes[:, 0]
|
||||
y2 = boxes[:, 3] + boxes[:, 1]
|
||||
|
||||
area = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
if scores is not None:
|
||||
idxs = np.argsort(scores)
|
||||
else:
|
||||
idxs = np.argsort(y2)
|
||||
|
||||
while len(idxs) > 0:
|
||||
last = len(idxs) - 1
|
||||
i = idxs[last]
|
||||
pick.append(i)
|
||||
|
||||
xx1 = np.maximum(x1[i], x1[idxs[:last]])
|
||||
yy1 = np.maximum(y1[i], y1[idxs[:last]])
|
||||
xx2 = np.minimum(x2[i], x2[idxs[:last]])
|
||||
yy2 = np.minimum(y2[i], y2[idxs[:last]])
|
||||
|
||||
w = np.maximum(0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0, yy2 - yy1 + 1)
|
||||
|
||||
overlap = (w * h) / (area[idxs[:last]] + area[idxs[last]] - w * h)
|
||||
|
||||
idxs = np.delete(
|
||||
idxs, np.concatenate(
|
||||
([last], np.where(overlap > max_bbox_overlap)[0])))
|
||||
|
||||
return pick
|
@ -1,169 +0,0 @@
|
||||
# vim: expandtab:ts=4:sw=4
|
||||
|
||||
|
||||
class TrackState:
|
||||
"""
|
||||
Enumeration type for the single target track state. Newly created tracks are
|
||||
classified as `tentative` until enough evidence has been collected. Then,
|
||||
the track state is changed to `confirmed`. Tracks that are no longer alive
|
||||
are classified as `deleted` to mark them for removal from the set of active
|
||||
tracks.
|
||||
|
||||
"""
|
||||
|
||||
Tentative = 1
|
||||
Confirmed = 2
|
||||
Deleted = 3
|
||||
|
||||
|
||||
class Track:
|
||||
"""
|
||||
A single target track with state space `(x, y, a, h)` and associated
|
||||
velocities, where `(x, y)` is the center of the bounding box, `a` is the
|
||||
aspect ratio and `h` is the height.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector of the initial state distribution.
|
||||
covariance : ndarray
|
||||
Covariance matrix of the initial state distribution.
|
||||
track_id : int
|
||||
A unique track identifier.
|
||||
n_init : int
|
||||
Number of consecutive detections before the track is confirmed. The
|
||||
track state is set to `Deleted` if a miss occurs within the first
|
||||
`n_init` frames.
|
||||
max_age : int
|
||||
The maximum number of consecutive misses before the track state is
|
||||
set to `Deleted`.
|
||||
feature : Optional[ndarray]
|
||||
Feature vector of the detection this track originates from. If not None,
|
||||
this feature is added to the `features` cache.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector of the initial state distribution.
|
||||
covariance : ndarray
|
||||
Covariance matrix of the initial state distribution.
|
||||
track_id : int
|
||||
A unique track identifier.
|
||||
hits : int
|
||||
Total number of measurement updates.
|
||||
age : int
|
||||
Total number of frames since first occurance.
|
||||
time_since_update : int
|
||||
Total number of frames since last measurement update.
|
||||
state : TrackState
|
||||
The current track state.
|
||||
features : List[ndarray]
|
||||
A cache of features. On each measurement update, the associated feature
|
||||
vector is added to this list.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, mean, covariance, track_id, n_init, max_age,
|
||||
feature=None, cls=None, mask=None):
|
||||
self.mean = mean
|
||||
self.covariance = covariance
|
||||
self.track_id = track_id
|
||||
self.hits = 1
|
||||
self.age = 1
|
||||
self.time_since_update = 0
|
||||
|
||||
self.state = TrackState.Tentative
|
||||
self.cls = cls
|
||||
self.mask = mask
|
||||
self.features = []
|
||||
if feature is not None:
|
||||
self.features.append(feature)
|
||||
|
||||
self._n_init = n_init
|
||||
self._max_age = max_age
|
||||
|
||||
def to_tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
The bounding box.
|
||||
|
||||
"""
|
||||
ret = self.mean[:4].copy()
|
||||
ret[2] *= ret[3]
|
||||
ret[:2] -= ret[2:] / 2
|
||||
return ret
|
||||
|
||||
def to_tlbr(self):
|
||||
"""Get current position in bounding box format `(min x, miny, max x,
|
||||
max y)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
The bounding box.
|
||||
|
||||
"""
|
||||
ret = self.to_tlwh()
|
||||
ret[2:] = ret[:2] + ret[2:]
|
||||
return ret
|
||||
|
||||
def predict(self, kf):
|
||||
"""Propagate the state distribution to the current time step using a
|
||||
Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kf : kalman_filter.KalmanFilter
|
||||
The Kalman filter.
|
||||
|
||||
"""
|
||||
self.mean, self.covariance = kf.predict(self.mean, self.covariance)
|
||||
self.age += 1
|
||||
self.time_since_update += 1
|
||||
|
||||
def update(self, kf, detection):
|
||||
"""Perform Kalman filter measurement update step and update the feature
|
||||
cache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kf : kalman_filter.KalmanFilter
|
||||
The Kalman filter.
|
||||
detection : Detection
|
||||
The associated detection.
|
||||
|
||||
"""
|
||||
self.mask = detection.mask
|
||||
self.mean, self.covariance = kf.update(
|
||||
self.mean, self.covariance, detection.to_xyah())
|
||||
self.features.append(detection.feature)
|
||||
|
||||
self.hits += 1
|
||||
self.time_since_update = 0
|
||||
if self.state == TrackState.Tentative and self.hits >= self._n_init:
|
||||
self.state = TrackState.Confirmed
|
||||
|
||||
def mark_missed(self):
|
||||
"""Mark this track as missed (no association at the current time step).
|
||||
"""
|
||||
if self.state == TrackState.Tentative:
|
||||
self.state = TrackState.Deleted
|
||||
elif self.time_since_update > self._max_age:
|
||||
self.state = TrackState.Deleted
|
||||
|
||||
def is_tentative(self):
|
||||
"""Returns True if this track is tentative (unconfirmed).
|
||||
"""
|
||||
return self.state == TrackState.Tentative
|
||||
|
||||
def is_confirmed(self):
|
||||
"""Returns True if this track is confirmed."""
|
||||
return self.state == TrackState.Confirmed
|
||||
|
||||
def is_deleted(self):
|
||||
"""Returns True if this track is dead and should be deleted."""
|
||||
return self.state == TrackState.Deleted
|
@ -1,138 +0,0 @@
|
||||
# vim: expandtab:ts=4:sw=4
|
||||
from __future__ import absolute_import
|
||||
import numpy as np
|
||||
from . import kalman_filter
|
||||
from . import linear_assignment
|
||||
from . import iou_matching
|
||||
from .track import Track
|
||||
|
||||
|
||||
class Tracker:
|
||||
"""
|
||||
This is the multi-target tracker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric : nn_matching.NearestNeighborDistanceMetric
|
||||
A distance metric for measurement-to-track association.
|
||||
max_age : int
|
||||
Maximum number of missed misses before a track is deleted.
|
||||
n_init : int
|
||||
Number of consecutive detections before the track is confirmed. The
|
||||
track state is set to `Deleted` if a miss occurs within the first
|
||||
`n_init` frames.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
metric : nn_matching.NearestNeighborDistanceMetric
|
||||
The distance metric used for measurement to track association.
|
||||
max_age : int
|
||||
Maximum number of missed misses before a track is deleted.
|
||||
n_init : int
|
||||
Number of frames that a track remains in initialization phase.
|
||||
kf : kalman_filter.KalmanFilter
|
||||
A Kalman filter to filter target trajectories in image space.
|
||||
tracks : List[Track]
|
||||
The list of active tracks at the current time step.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3):
|
||||
self.metric = metric
|
||||
self.max_iou_distance = max_iou_distance
|
||||
self.max_age = max_age
|
||||
self.n_init = n_init
|
||||
|
||||
self.kf = kalman_filter.KalmanFilter()
|
||||
self.tracks = []
|
||||
self._next_id = 1
|
||||
|
||||
def predict(self):
|
||||
"""Propagate track state distributions one time step forward.
|
||||
|
||||
This function should be called once every time step, before `update`.
|
||||
"""
|
||||
for track in self.tracks:
|
||||
track.predict(self.kf)
|
||||
|
||||
def update(self, detections):
|
||||
"""Perform measurement update and track management.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detections : List[deep_sort.detection.Detection]
|
||||
A list of detections at the current time step.
|
||||
|
||||
"""
|
||||
# Run matching cascade.
|
||||
matches, unmatched_tracks, unmatched_detections = \
|
||||
self._match(detections)
|
||||
|
||||
# Update track set.
|
||||
for track_idx, detection_idx in matches:
|
||||
self.tracks[track_idx].update(
|
||||
self.kf, detections[detection_idx])
|
||||
for track_idx in unmatched_tracks:
|
||||
self.tracks[track_idx].mark_missed()
|
||||
for detection_idx in unmatched_detections:
|
||||
self._initiate_track(detections[detection_idx])
|
||||
self.tracks = [t for t in self.tracks if not t.is_deleted()]
|
||||
|
||||
# Update distance metric.
|
||||
active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
|
||||
features, targets = [], []
|
||||
for track in self.tracks:
|
||||
if not track.is_confirmed():
|
||||
continue
|
||||
features += track.features
|
||||
targets += [track.track_id for _ in track.features]
|
||||
track.features = []
|
||||
self.metric.partial_fit(
|
||||
np.asarray(features), np.asarray(targets), active_targets)
|
||||
|
||||
def _match(self, detections):
|
||||
|
||||
def gated_metric(tracks, dets, track_indices, detection_indices):
|
||||
features = np.array([dets[i].feature for i in detection_indices])
|
||||
targets = np.array([tracks[i].track_id for i in track_indices])
|
||||
cost_matrix = self.metric.distance(features, targets)
|
||||
cost_matrix = linear_assignment.gate_cost_matrix(
|
||||
self.kf, cost_matrix, tracks, dets, track_indices,
|
||||
detection_indices)
|
||||
|
||||
return cost_matrix
|
||||
|
||||
# Split track set into confirmed and unconfirmed tracks.
|
||||
confirmed_tracks = [
|
||||
i for i, t in enumerate(self.tracks) if t.is_confirmed()]
|
||||
unconfirmed_tracks = [
|
||||
i for i, t in enumerate(self.tracks) if not t.is_confirmed()]
|
||||
|
||||
# Associate confirmed tracks using appearance features.
|
||||
matches_a, unmatched_tracks_a, unmatched_detections = \
|
||||
linear_assignment.matching_cascade(
|
||||
gated_metric, self.metric.matching_threshold, self.max_age,
|
||||
self.tracks, detections, confirmed_tracks)
|
||||
|
||||
# Associate remaining tracks together with unconfirmed tracks using IOU.
|
||||
iou_track_candidates = unconfirmed_tracks + [
|
||||
k for k in unmatched_tracks_a if
|
||||
self.tracks[k].time_since_update == 1]
|
||||
unmatched_tracks_a = [
|
||||
k for k in unmatched_tracks_a if
|
||||
self.tracks[k].time_since_update != 1]
|
||||
matches_b, unmatched_tracks_b, unmatched_detections = \
|
||||
linear_assignment.min_cost_matching(
|
||||
iou_matching.iou_cost, self.max_iou_distance, self.tracks,
|
||||
detections, iou_track_candidates, unmatched_detections)
|
||||
|
||||
matches = matches_a + matches_b
|
||||
unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
|
||||
return matches, unmatched_tracks, unmatched_detections
|
||||
|
||||
def _initiate_track(self, detection):
|
||||
mean, covariance = self.kf.initiate(detection.to_xyah())
|
||||
self.tracks.append(Track(
|
||||
mean, covariance, self._next_id, self.n_init, self.max_age,
|
||||
detection.feature, detection.cls, detection.mask))
|
||||
self._next_id += 1
|
@ -1,2 +0,0 @@
|
||||
def datasets():
|
||||
return None
|
@ -1,13 +0,0 @@
|
||||
from os import environ
|
||||
|
||||
|
||||
def assert_in(file, files_to_check):
|
||||
if file not in files_to_check:
|
||||
raise AssertionError("{} does not exist in the list".format(str(file)))
|
||||
return True
|
||||
|
||||
|
||||
def assert_in_env(check_list: list):
|
||||
for item in check_list:
|
||||
assert_in(item, environ.keys())
|
||||
return True
|
@ -1,51 +0,0 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
|
||||
|
||||
|
||||
def compute_color_for_labels(label):
|
||||
"""
|
||||
Simple function that adds fixed color depending on the class
|
||||
"""
|
||||
color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
|
||||
return tuple(color)
|
||||
|
||||
|
||||
def draw_masks(image, mask, color, thresh: float = 0.7, alpha: float = 0.5):
|
||||
np_image = np.asarray(image)
|
||||
mask = mask > thresh
|
||||
|
||||
color = np.asarray(color)
|
||||
img_to_draw = np.copy(np_image)
|
||||
# TODO: There might be a way to vectorize this
|
||||
img_to_draw[mask] = color
|
||||
|
||||
out = np_image * (1 - alpha) + img_to_draw * alpha
|
||||
return out.astype(np.uint8)
|
||||
|
||||
|
||||
def draw_boxes(img, bbox, names=None, identities=None, masks=None, offset=(0, 0)):
|
||||
for i, box in enumerate(bbox):
|
||||
x1, y1, x2, y2 = [int(i) for i in box]
|
||||
x1 += offset[0]
|
||||
x2 += offset[0]
|
||||
y1 += offset[1]
|
||||
y2 += offset[1]
|
||||
# box text and bar
|
||||
id = int(identities[i]) if identities is not None else 0
|
||||
color = compute_color_for_labels(id)
|
||||
label = '{:}{:d}'.format(names[i], id)
|
||||
t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
|
||||
if masks is not None:
|
||||
mask = masks[i]
|
||||
img = draw_masks(img, mask, color)
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
|
||||
cv2.rectangle(img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
|
||||
cv2.putText(img, label, (x1, y1 + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
|
||||
return img
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
for i in range(82):
|
||||
print(compute_color_for_labels(i))
|
@ -1,103 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import copy
|
||||
import motmetrics as mm
|
||||
mm.lap.default_solver = 'lap'
|
||||
from utils.io import read_results, unzip_objs
|
||||
|
||||
|
||||
class Evaluator(object):
|
||||
|
||||
def __init__(self, data_root, seq_name, data_type):
|
||||
self.data_root = data_root
|
||||
self.seq_name = seq_name
|
||||
self.data_type = data_type
|
||||
|
||||
self.load_annotations()
|
||||
self.reset_accumulator()
|
||||
|
||||
def load_annotations(self):
|
||||
assert self.data_type == 'mot'
|
||||
|
||||
gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
|
||||
self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
|
||||
self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)
|
||||
|
||||
def reset_accumulator(self):
|
||||
self.acc = mm.MOTAccumulator(auto_id=True)
|
||||
|
||||
def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
|
||||
# results
|
||||
trk_tlwhs = np.copy(trk_tlwhs)
|
||||
trk_ids = np.copy(trk_ids)
|
||||
|
||||
# gts
|
||||
gt_objs = self.gt_frame_dict.get(frame_id, [])
|
||||
gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
|
||||
|
||||
# ignore boxes
|
||||
ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
|
||||
ignore_tlwhs = unzip_objs(ignore_objs)[0]
|
||||
|
||||
|
||||
# remove ignored results
|
||||
keep = np.ones(len(trk_tlwhs), dtype=bool)
|
||||
iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
|
||||
if len(iou_distance) > 0:
|
||||
match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
|
||||
match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
|
||||
match_ious = iou_distance[match_is, match_js]
|
||||
|
||||
match_js = np.asarray(match_js, dtype=int)
|
||||
match_js = match_js[np.logical_not(np.isnan(match_ious))]
|
||||
keep[match_js] = False
|
||||
trk_tlwhs = trk_tlwhs[keep]
|
||||
trk_ids = trk_ids[keep]
|
||||
|
||||
# get distance matrix
|
||||
iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)
|
||||
|
||||
# acc
|
||||
self.acc.update(gt_ids, trk_ids, iou_distance)
|
||||
|
||||
if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'):
|
||||
events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics
|
||||
else:
|
||||
events = None
|
||||
return events
|
||||
|
||||
def eval_file(self, filename):
|
||||
self.reset_accumulator()
|
||||
|
||||
result_frame_dict = read_results(filename, self.data_type, is_gt=False)
|
||||
frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys())))
|
||||
for frame_id in frames:
|
||||
trk_objs = result_frame_dict.get(frame_id, [])
|
||||
trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
|
||||
self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
|
||||
|
||||
return self.acc
|
||||
|
||||
@staticmethod
|
||||
def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')):
|
||||
names = copy.deepcopy(names)
|
||||
if metrics is None:
|
||||
metrics = mm.metrics.motchallenge_metrics
|
||||
metrics = copy.deepcopy(metrics)
|
||||
|
||||
mh = mm.metrics.create()
|
||||
summary = mh.compute_many(
|
||||
accs,
|
||||
metrics=metrics,
|
||||
names=names,
|
||||
generate_overall=True
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
@staticmethod
|
||||
def save_summary(summary, filename):
|
||||
import pandas as pd
|
||||
writer = pd.ExcelWriter(filename)
|
||||
summary.to_excel(writer)
|
||||
writer.save()
|
@ -1,133 +0,0 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
# from utils.log import get_logger
|
||||
|
||||
|
||||
def write_results(filename, results, data_type):
|
||||
if data_type == 'mot':
|
||||
save_format = '{frame},{id},{cls},{x1},{y1},{w},{h},-1,-1,-1,-1\n'
|
||||
elif data_type == 'kitti':
|
||||
save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
|
||||
else:
|
||||
raise ValueError(data_type)
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
for frame_id, tlwhs, track_ids, classes in results:
|
||||
if data_type == 'kitti':
|
||||
frame_id -= 1
|
||||
for tlwh, track_id, cls_id in zip(tlwhs, track_ids, classes):
|
||||
if track_id < 0:
|
||||
continue
|
||||
x1, y1, w, h = tlwh
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
line = save_format.format(frame=frame_id, id=track_id, cls=cls_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
|
||||
f.write(line)
|
||||
|
||||
|
||||
# def write_results(filename, results_dict: Dict, data_type: str):
|
||||
# if not filename:
|
||||
# return
|
||||
# path = os.path.dirname(filename)
|
||||
# if not os.path.exists(path):
|
||||
# os.makedirs(path)
|
||||
|
||||
# if data_type in ('mot', 'mcmot', 'lab'):
|
||||
# save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
|
||||
# elif data_type == 'kitti':
|
||||
# save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n'
|
||||
# else:
|
||||
# raise ValueError(data_type)
|
||||
|
||||
# with open(filename, 'w') as f:
|
||||
# for frame_id, frame_data in results_dict.items():
|
||||
# if data_type == 'kitti':
|
||||
# frame_id -= 1
|
||||
# for tlwh, track_id in frame_data:
|
||||
# if track_id < 0:
|
||||
# continue
|
||||
# x1, y1, w, h = tlwh
|
||||
# x2, y2 = x1 + w, y1 + h
|
||||
# line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0)
|
||||
# f.write(line)
|
||||
# logger.info('Save results to {}'.format(filename))
|
||||
|
||||
|
||||
def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
|
||||
if data_type in ('mot', 'lab'):
|
||||
read_fun = read_mot_results
|
||||
else:
|
||||
raise ValueError('Unknown data type: {}'.format(data_type))
|
||||
|
||||
return read_fun(filename, is_gt, is_ignore)
|
||||
|
||||
|
||||
"""
|
||||
labels={'ped', ... % 1
|
||||
'person_on_vhcl', ... % 2
|
||||
'car', ... % 3
|
||||
'bicycle', ... % 4
|
||||
'mbike', ... % 5
|
||||
'non_mot_vhcl', ... % 6
|
||||
'static_person', ... % 7
|
||||
'distractor', ... % 8
|
||||
'occluder', ... % 9
|
||||
'occluder_on_grnd', ... %10
|
||||
'occluder_full', ... % 11
|
||||
'reflection', ... % 12
|
||||
'crowd' ... % 13
|
||||
};
|
||||
"""
|
||||
|
||||
|
||||
def read_mot_results(filename, is_gt, is_ignore):
|
||||
valid_labels = {1}
|
||||
ignore_labels = {2, 7, 8, 12}
|
||||
results_dict = dict()
|
||||
if os.path.isfile(filename):
|
||||
with open(filename, 'r') as f:
|
||||
for line in f.readlines():
|
||||
linelist = line.split(',')
|
||||
if len(linelist) < 7:
|
||||
continue
|
||||
fid = int(linelist[0])
|
||||
if fid < 1:
|
||||
continue
|
||||
results_dict.setdefault(fid, list())
|
||||
|
||||
if is_gt:
|
||||
if 'MOT16-' in filename or 'MOT17-' in filename:
|
||||
label = int(float(linelist[7]))
|
||||
mark = int(float(linelist[6]))
|
||||
if mark == 0 or label not in valid_labels:
|
||||
continue
|
||||
score = 1
|
||||
elif is_ignore:
|
||||
if 'MOT16-' in filename or 'MOT17-' in filename:
|
||||
label = int(float(linelist[7]))
|
||||
vis_ratio = float(linelist[8])
|
||||
if label not in ignore_labels and vis_ratio >= 0:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
score = 1
|
||||
else:
|
||||
score = float(linelist[6])
|
||||
|
||||
tlwh = tuple(map(float, linelist[2:6]))
|
||||
target_id = int(linelist[1])
|
||||
|
||||
results_dict[fid].append((tlwh, target_id, score))
|
||||
|
||||
return results_dict
|
||||
|
||||
|
||||
def unzip_objs(objs):
|
||||
if len(objs) > 0:
|
||||
tlwhs, ids, scores = zip(*objs)
|
||||
else:
|
||||
tlwhs, ids, scores = [], [], []
|
||||
tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
|
||||
|
||||
return tlwhs, ids, scores
|
@ -1,383 +0,0 @@
|
||||
"""
|
||||
References:
|
||||
https://medium.com/analytics-vidhya/creating-a-custom-logging-mechanism-for-real-time-object-detection-using-tdd-4ca2cfcd0a2f
|
||||
"""
|
||||
import json
|
||||
from os import makedirs
|
||||
from os.path import exists, join
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class JsonMeta(object):
|
||||
HOURS = 3
|
||||
MINUTES = 59
|
||||
SECONDS = 59
|
||||
PATH_TO_SAVE = 'LOGS'
|
||||
DEFAULT_FILE_NAME = 'remaining'
|
||||
|
||||
|
||||
class BaseJsonLogger(object):
|
||||
"""
|
||||
This is the base class that returns __dict__ of its own
|
||||
it also returns the dicts of objects in the attributes that are list instances
|
||||
|
||||
"""
|
||||
|
||||
def dic(self):
|
||||
# returns dicts of objects
|
||||
out = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if hasattr(v, 'dic'):
|
||||
out[k] = v.dic()
|
||||
elif isinstance(v, list):
|
||||
out[k] = self.list(v)
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def list(values):
|
||||
# applies the dic method on items in the list
|
||||
return [v.dic() if hasattr(v, 'dic') else v for v in values]
|
||||
|
||||
|
||||
class Label(BaseJsonLogger):
|
||||
"""
|
||||
For each bounding box there are various categories with confidences. Label class keeps track of that information.
|
||||
"""
|
||||
|
||||
def __init__(self, category: str, confidence: float):
|
||||
self.category = category
|
||||
self.confidence = confidence
|
||||
|
||||
|
||||
class Bbox(BaseJsonLogger):
|
||||
"""
|
||||
This module stores the information for each frame and use them in JsonParser
|
||||
Attributes:
|
||||
labels (list): List of label module.
|
||||
top (int):
|
||||
left (int):
|
||||
width (int):
|
||||
height (int):
|
||||
|
||||
Args:
|
||||
bbox_id (float):
|
||||
top (int):
|
||||
left (int):
|
||||
width (int):
|
||||
height (int):
|
||||
|
||||
References:
|
||||
Check Label module for better understanding.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, bbox_id, top, left, width, height):
|
||||
self.labels = []
|
||||
self.bbox_id = bbox_id
|
||||
self.top = top
|
||||
self.left = left
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
def add_label(self, category, confidence):
|
||||
# adds category and confidence only if top_k is not exceeded.
|
||||
self.labels.append(Label(category, confidence))
|
||||
|
||||
def labels_full(self, value):
|
||||
return len(self.labels) == value
|
||||
|
||||
|
||||
class Frame(BaseJsonLogger):
|
||||
"""
|
||||
This module stores the information for each frame and use them in JsonParser
|
||||
Attributes:
|
||||
timestamp (float): The elapsed time of captured frame
|
||||
frame_id (int): The frame number of the captured video
|
||||
bboxes (list of Bbox objects): Stores the list of bbox objects.
|
||||
|
||||
References:
|
||||
Check Bbox class for better information
|
||||
|
||||
Args:
|
||||
timestamp (float):
|
||||
frame_id (int):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, frame_id: int, timestamp: float = None):
|
||||
self.frame_id = frame_id
|
||||
self.timestamp = timestamp
|
||||
self.bboxes = []
|
||||
|
||||
def add_bbox(self, bbox_id: int, top: int, left: int, width: int, height: int):
|
||||
bboxes_ids = [bbox.bbox_id for bbox in self.bboxes]
|
||||
if bbox_id not in bboxes_ids:
|
||||
self.bboxes.append(Bbox(bbox_id, top, left, width, height))
|
||||
else:
|
||||
raise ValueError("Frame with id: {} already has a Bbox with id: {}".format(self.frame_id, bbox_id))
|
||||
|
||||
def add_label_to_bbox(self, bbox_id: int, category: str, confidence: float):
|
||||
bboxes = {bbox.id: bbox for bbox in self.bboxes}
|
||||
if bbox_id in bboxes.keys():
|
||||
res = bboxes.get(bbox_id)
|
||||
res.add_label(category, confidence)
|
||||
else:
|
||||
raise ValueError('the bbox with id: {} does not exists!'.format(bbox_id))
|
||||
|
||||
|
||||
class BboxToJsonLogger(BaseJsonLogger):
|
||||
"""
|
||||
ُ This module is designed to automate the task of logging jsons. An example json is used
|
||||
to show the contents of json file shortly
|
||||
Example:
|
||||
{
|
||||
"video_details": {
|
||||
"frame_width": 1920,
|
||||
"frame_height": 1080,
|
||||
"frame_rate": 20,
|
||||
"video_name": "/home/gpu/codes/MSD/pedestrian_2/project/public/camera1.avi"
|
||||
},
|
||||
"frames": [
|
||||
{
|
||||
"frame_id": 329,
|
||||
"timestamp": 3365.1254
|
||||
"bboxes": [
|
||||
{
|
||||
"labels": [
|
||||
{
|
||||
"category": "pedestrian",
|
||||
"confidence": 0.9
|
||||
}
|
||||
],
|
||||
"bbox_id": 0,
|
||||
"top": 1257,
|
||||
"left": 138,
|
||||
"width": 68,
|
||||
"height": 109
|
||||
}
|
||||
]
|
||||
}],
|
||||
|
||||
Attributes:
|
||||
frames (dict): It's a dictionary that maps each frame_id to json attributes.
|
||||
video_details (dict): information about video file.
|
||||
top_k_labels (int): shows the allowed number of labels
|
||||
start_time (datetime object): we use it to automate the json output by time.
|
||||
|
||||
Args:
|
||||
top_k_labels (int): shows the allowed number of labels
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, top_k_labels: int = 1):
|
||||
self.frames = {}
|
||||
self.video_details = self.video_details = dict(frame_width=None, frame_height=None, frame_rate=None,
|
||||
video_name=None)
|
||||
self.top_k_labels = top_k_labels
|
||||
self.start_time = datetime.now()
|
||||
|
||||
def set_top_k(self, value):
|
||||
self.top_k_labels = value
|
||||
|
||||
def frame_exists(self, frame_id: int) -> bool:
|
||||
"""
|
||||
Args:
|
||||
frame_id (int):
|
||||
|
||||
Returns:
|
||||
bool: true if frame_id is recognized
|
||||
"""
|
||||
return frame_id in self.frames.keys()
|
||||
|
||||
def add_frame(self, frame_id: int, timestamp: float = None) -> None:
|
||||
"""
|
||||
Args:
|
||||
frame_id (int):
|
||||
timestamp (float): opencv captured frame time property
|
||||
|
||||
Raises:
|
||||
ValueError: if frame_id would not exist in class frames attribute
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
if not self.frame_exists(frame_id):
|
||||
self.frames[frame_id] = Frame(frame_id, timestamp)
|
||||
else:
|
||||
raise ValueError("Frame id: {} already exists".format(frame_id))
|
||||
|
||||
def bbox_exists(self, frame_id: int, bbox_id: int) -> bool:
|
||||
"""
|
||||
Args:
|
||||
frame_id:
|
||||
bbox_id:
|
||||
|
||||
Returns:
|
||||
bool: if bbox exists in frame bboxes list
|
||||
"""
|
||||
bboxes = []
|
||||
if self.frame_exists(frame_id=frame_id):
|
||||
bboxes = [bbox.bbox_id for bbox in self.frames[frame_id].bboxes]
|
||||
return bbox_id in bboxes
|
||||
|
||||
def find_bbox(self, frame_id: int, bbox_id: int):
|
||||
"""
|
||||
|
||||
Args:
|
||||
frame_id:
|
||||
bbox_id:
|
||||
|
||||
Returns:
|
||||
bbox_id (int):
|
||||
|
||||
Raises:
|
||||
ValueError: if bbox_id does not exist in the bbox list of specific frame.
|
||||
"""
|
||||
if not self.bbox_exists(frame_id, bbox_id):
|
||||
raise ValueError("frame with id: {} does not contain bbox with id: {}".format(frame_id, bbox_id))
|
||||
bboxes = {bbox.bbox_id: bbox for bbox in self.frames[frame_id].bboxes}
|
||||
return bboxes.get(bbox_id)
|
||||
|
||||
def add_bbox_to_frame(self, frame_id: int, bbox_id: int, top: int, left: int, width: int, height: int) -> None:
|
||||
"""
|
||||
|
||||
Args:
|
||||
frame_id (int):
|
||||
bbox_id (int):
|
||||
top (int):
|
||||
left (int):
|
||||
width (int):
|
||||
height (int):
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ValueError: if bbox_id already exist in frame information with frame_id
|
||||
ValueError: if frame_id does not exist in frames attribute
|
||||
"""
|
||||
if self.frame_exists(frame_id):
|
||||
frame = self.frames[frame_id]
|
||||
if not self.bbox_exists(frame_id, bbox_id):
|
||||
frame.add_bbox(bbox_id, top, left, width, height)
|
||||
else:
|
||||
raise ValueError(
|
||||
"frame with frame_id: {} already contains the bbox with id: {} ".format(frame_id, bbox_id))
|
||||
else:
|
||||
raise ValueError("frame with frame_id: {} does not exist".format(frame_id))
|
||||
|
||||
def add_label_to_bbox(self, frame_id: int, bbox_id: int, category: str, confidence: float):
|
||||
"""
|
||||
Args:
|
||||
frame_id:
|
||||
bbox_id:
|
||||
category:
|
||||
confidence: the confidence value returned from yolo detection
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ValueError: if labels quota (top_k_labels) exceeds.
|
||||
"""
|
||||
bbox = self.find_bbox(frame_id, bbox_id)
|
||||
if not bbox.labels_full(self.top_k_labels):
|
||||
bbox.add_label(category, confidence)
|
||||
else:
|
||||
raise ValueError("labels in frame_id: {}, bbox_id: {} is fulled".format(frame_id, bbox_id))
|
||||
|
||||
def add_video_details(self, frame_width: int = None, frame_height: int = None, frame_rate: int = None,
|
||||
video_name: str = None):
|
||||
self.video_details['frame_width'] = frame_width
|
||||
self.video_details['frame_height'] = frame_height
|
||||
self.video_details['frame_rate'] = frame_rate
|
||||
self.video_details['video_name'] = video_name
|
||||
|
||||
def output(self):
|
||||
output = {'video_details': self.video_details}
|
||||
result = list(self.frames.values())
|
||||
output['frames'] = [item.dic() for item in result]
|
||||
return output
|
||||
|
||||
def json_output(self, output_name):
|
||||
"""
|
||||
Args:
|
||||
output_name:
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Notes:
|
||||
It creates the json output with `output_name` name.
|
||||
"""
|
||||
if not output_name.endswith('.json'):
|
||||
output_name += '.json'
|
||||
with open(output_name, 'w') as file:
|
||||
json.dump(self.output(), file)
|
||||
file.close()
|
||||
|
||||
def set_start(self):
|
||||
self.start_time = datetime.now()
|
||||
|
||||
def schedule_output_by_time(self, output_dir=JsonMeta.PATH_TO_SAVE, hours: int = 0, minutes: int = 0,
|
||||
seconds: int = 60) -> None:
|
||||
"""
|
||||
Notes:
|
||||
Creates folder and then periodically stores the jsons on that address.
|
||||
|
||||
Args:
|
||||
output_dir (str): the directory where output files will be stored
|
||||
hours (int):
|
||||
minutes (int):
|
||||
seconds (int):
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
end = datetime.now()
|
||||
interval = 0
|
||||
interval += abs(min([hours, JsonMeta.HOURS]) * 3600)
|
||||
interval += abs(min([minutes, JsonMeta.MINUTES]) * 60)
|
||||
interval += abs(min([seconds, JsonMeta.SECONDS]))
|
||||
diff = (end - self.start_time).seconds
|
||||
|
||||
if diff > interval:
|
||||
output_name = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '.json'
|
||||
if not exists(output_dir):
|
||||
makedirs(output_dir)
|
||||
output = join(output_dir, output_name)
|
||||
self.json_output(output_name=output)
|
||||
self.frames = {}
|
||||
self.start_time = datetime.now()
|
||||
|
||||
def schedule_output_by_frames(self, frames_quota, frame_counter, output_dir=JsonMeta.PATH_TO_SAVE):
|
||||
"""
|
||||
saves as the number of frames quota increases higher.
|
||||
:param frames_quota:
|
||||
:param frame_counter:
|
||||
:param output_dir:
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def flush(self, output_dir):
|
||||
"""
|
||||
Notes:
|
||||
We use this function to output jsons whenever possible.
|
||||
like the time that we exit the while loop of opencv.
|
||||
|
||||
Args:
|
||||
output_dir:
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
filename = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '-remaining.json'
|
||||
output = join(output_dir, filename)
|
||||
self.json_output(output_name=output)
|
@ -1,17 +0,0 @@
|
||||
import logging
|
||||
|
||||
|
||||
def get_logger(name='root'):
|
||||
formatter = logging.Formatter(
|
||||
# fmt='%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s')
|
||||
fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.addHandler(handler)
|
||||
return logger
|
||||
|
||||
|
@ -1,38 +0,0 @@
|
||||
import os
|
||||
import yaml
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
class YamlParser(edict):
|
||||
"""
|
||||
This is yaml parser based on EasyDict.
|
||||
"""
|
||||
def __init__(self, cfg_dict=None, config_file=None):
|
||||
if cfg_dict is None:
|
||||
cfg_dict = {}
|
||||
|
||||
if config_file is not None:
|
||||
assert (os.path.isfile(config_file))
|
||||
with open(config_file, 'r') as fo:
|
||||
cfg_dict.update(yaml.safe_load(fo.read()))
|
||||
|
||||
super(YamlParser, self).__init__(cfg_dict)
|
||||
|
||||
|
||||
def merge_from_file(self, config_file):
|
||||
with open(config_file, 'r') as fo:
|
||||
self.update(yaml.safe_load(fo.read()))
|
||||
|
||||
|
||||
def merge_from_dict(self, config_dict):
|
||||
self.update(config_dict)
|
||||
|
||||
|
||||
def get_config(config_file=None):
|
||||
return YamlParser(config_file=config_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = YamlParser(config_file="../configs/yolov3.yaml")
|
||||
cfg.merge_from_file("../configs/deep_sort.yaml")
|
||||
|
||||
import ipdb; ipdb.set_trace()
|
@ -1,39 +0,0 @@
|
||||
from functools import wraps
|
||||
from time import time
|
||||
|
||||
|
||||
def is_video(ext: str):
|
||||
"""
|
||||
Returns true if ext exists in
|
||||
allowed_exts for video files.
|
||||
|
||||
Args:
|
||||
ext:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
allowed_exts = ('.mp4', '.webm', '.ogg', '.avi', '.wmv', '.mkv', '.3gp')
|
||||
return any((ext.endswith(x) for x in allowed_exts))
|
||||
|
||||
|
||||
def tik_tok(func):
|
||||
"""
|
||||
keep track of time for each process.
|
||||
Args:
|
||||
func:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
@wraps(func)
|
||||
def _time_it(*args, **kwargs):
|
||||
start = time()
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
end_ = time()
|
||||
print("time: {:.03f}s, fps: {:.03f}".format(end_ - start, 1 / (end_ - start)))
|
||||
|
||||
return _time_it
|
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
@ -1,181 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from yolov5.models.common import DetectMultiBackend
|
||||
from yolov5.utils.torch_utils import select_device
|
||||
from yolov5.utils.dataloaders import LoadImages
|
||||
from yolov5.utils.general import check_img_size, non_max_suppression, cv2, scale_coords, xyxy2xywh
|
||||
from deep_sort.deep_sort import DeepSort
|
||||
|
||||
|
||||
class VideoTracker(object):
|
||||
def __init__(self, weights_pt, data, video_path, save_path, idx_to_class):
|
||||
self.video_path = video_path
|
||||
self.save_path = save_path
|
||||
self.idx_to_class = idx_to_class
|
||||
|
||||
# 选择设备(CPU 或 GPU)
|
||||
device = select_device('cpu')
|
||||
|
||||
self.vdo = cv2.VideoCapture()
|
||||
self.detector = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)
|
||||
self.deepsort = DeepSort(
|
||||
model_path="deep_sort/deep/checkpoint/ckpt.t7", # ReID 模型路径
|
||||
max_dist=0.2, # 外观特征匹配阈值(越小越严格)
|
||||
max_iou_distance=0.7, # 最大IoU距离阈值
|
||||
max_age=70, # 目标最大存活帧数(未匹配时保留的帧数)
|
||||
n_init=3 # 初始确认帧数(连续匹配到n_init次后确认跟踪)
|
||||
)
|
||||
self.class_names = self.detector.class_names
|
||||
|
||||
def __enter__(self):
|
||||
self.vdo.open(self.video_path)
|
||||
self.im_width = int(self.vdo.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
self.im_height = int(self.vdo.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
assert self.vdo.isOpened()
|
||||
|
||||
if self.save_path:
|
||||
os.makedirs(self.args.save_path, exist_ok=True)
|
||||
|
||||
# path of saved video and results
|
||||
self.save_video_path = os.path.join(self.save_path, "results.avi")
|
||||
self.save_results_path = os.path.join(self.save_path, "results.txt")
|
||||
|
||||
# create video writer
|
||||
fourcc = cv2.VideoWriter_fourcc(*'MJPG')
|
||||
self.writer = cv2.VideoWriter(self.save_video_path, fourcc, 20, (self.im_width, self.im_height))
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
if exc_type:
|
||||
print(exc_type, exc_value, exc_traceback)
|
||||
|
||||
def run(self):
|
||||
stride, names, pt = self.model.stride, self.model.names, self.model.pt
|
||||
imgsz = check_img_size((640, 640), s=stride) # check image size
|
||||
dataset = LoadImages(self.video_path, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
|
||||
bs = len(dataset)
|
||||
|
||||
self.model.warmup(imgsz=(1 if pt or self.model.triton else bs, 3, *imgsz))
|
||||
for path, im, im0s, vid_cap, s in dataset:
|
||||
im = torch.from_numpy(im).to(self.model.device)
|
||||
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
if self.model.xml and im.shape[0] > 1:
|
||||
ims = torch.chunk(im, im.shape[0], 0)
|
||||
|
||||
# Inference
|
||||
if self.model.xml and im.shape[0] > 1:
|
||||
pred = None
|
||||
for image in ims:
|
||||
if pred is None:
|
||||
pred = self.model(image, augment=False, visualize=False).unsqueeze(0)
|
||||
else:
|
||||
pred = torch.cat(
|
||||
(pred, self.model(image, augment=False, visualize=False).unsqueeze(0)),
|
||||
dim=0
|
||||
)
|
||||
pred = [pred, None]
|
||||
else:
|
||||
pred = self.model(im, augment=False, visualize=False)
|
||||
# NMS
|
||||
pred = non_max_suppression(pred, 0.40, 0.45, None, False, max_det=1000)[0]
|
||||
|
||||
image = im0s[0]
|
||||
|
||||
pred[:, :4] = scale_coords(im.shape[2:], pred[:, :4], image.shape).round()
|
||||
|
||||
# 使用YOLOv5进行检测后得到的pred
|
||||
bbox_xywh, cls_conf, cls_ids = yolov5_to_deepsort_format(pred)
|
||||
# select person class
|
||||
mask = cls_ids == 0
|
||||
|
||||
bbox_xywh = bbox_xywh[mask]
|
||||
# bbox dilation just in case bbox too small, delete this line if using a better pedestrian detector
|
||||
bbox_xywh[:, 2:] *= 1.2
|
||||
cls_conf = cls_conf[mask]
|
||||
cls_ids = cls_ids[mask]
|
||||
|
||||
# 调用Deep SORT更新方法
|
||||
outputs, _ = self.deepsort.update(bbox_xywh, cls_conf, cls_ids, image)
|
||||
|
||||
count_result = {}
|
||||
|
||||
for key in self.idx_to_class.keys():
|
||||
count_result[key] = set()
|
||||
|
||||
# draw boxes for visualization
|
||||
if len(outputs) > 0:
|
||||
bbox_xyxy = outputs[:, :4] # 这个是检测所在框的坐标的数组
|
||||
identities = outputs[:, -1] # 这个是每个元素的计数的数组
|
||||
cls = outputs[:, -2] # 这个是标签数组id的数组
|
||||
names = [self.idx_to_class[str(label)] for label in cls]
|
||||
image = draw_boxes(image, bbox_xyxy, names, identities)
|
||||
for i in range(len(cls)):
|
||||
count_result[str(cls[i])].add(identities[i])
|
||||
|
||||
|
||||
def draw_boxes(img, bbox, names=None, identities=None, offset=(0, 0)):
|
||||
for i, box in enumerate(bbox):
|
||||
x1, y1, x2, y2 = [int(i) for i in box]
|
||||
x1 += offset[0]
|
||||
x2 += offset[0]
|
||||
y1 += offset[1]
|
||||
y2 += offset[1]
|
||||
# box text and bar
|
||||
id = int(identities[i]) if identities is not None else 0
|
||||
color = compute_color_for_labels(id)
|
||||
label = '{:}{:d}'.format(names[i], id)
|
||||
t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
|
||||
cv2.rectangle(img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
|
||||
cv2.putText(img, label, (x1, y1 + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
|
||||
return img
|
||||
|
||||
|
||||
def compute_color_for_labels(label):
|
||||
"""
|
||||
Simple function that adds fixed color depending on the class
|
||||
"""
|
||||
color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
|
||||
return tuple(color)
|
||||
|
||||
|
||||
def yolov5_to_deepsort_format(pred):
|
||||
"""
|
||||
将YOLOv5的预测结果转换为Deep SORT所需的格式
|
||||
:param pred: YOLOv5的预测结果
|
||||
:return: 转换后的bbox_xywh, confs, class_ids
|
||||
"""
|
||||
pred[:, :4] = xyxy2xywh(pred[:, :4])
|
||||
xywh = pred[:, :4].cpu().numpy()
|
||||
conf = pred[:, 4].cpu().numpy()
|
||||
cls = pred[:, 5].cpu().numpy()
|
||||
return xywh, conf, cls
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
cfg = get_config()
|
||||
if args.segment:
|
||||
cfg.USE_SEGMENT = True
|
||||
else:
|
||||
cfg.USE_SEGMENT = False
|
||||
if args.mmdet:
|
||||
cfg.merge_from_file(args.config_mmdetection)
|
||||
cfg.USE_MMDET = True
|
||||
else:
|
||||
cfg.merge_from_file(args.config_detection)
|
||||
cfg.USE_MMDET = False
|
||||
cfg.merge_from_file(args.config_deepsort)
|
||||
if args.fastreid:
|
||||
cfg.merge_from_file(args.config_fastreid)
|
||||
cfg.USE_FASTREID = True
|
||||
else:
|
||||
cfg.USE_FASTREID = False
|
||||
|
||||
with VideoTracker(cfg, args, video_path=args.VIDEO_PATH) as vdo_trk:
|
||||
vdo_trk.run()
|
Reference in New Issue
Block a user