完成训练模块的转移
This commit is contained in:
133
deep_sort/utils/io.py
Normal file
133
deep_sort/utils/io.py
Normal file
@ -0,0 +1,133 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
# from utils.log import get_logger
|
||||
|
||||
|
||||
def write_results(filename, results, data_type):
|
||||
if data_type == 'mot':
|
||||
save_format = '{frame},{id},{cls},{x1},{y1},{w},{h},-1,-1,-1,-1\n'
|
||||
elif data_type == 'kitti':
|
||||
save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
|
||||
else:
|
||||
raise ValueError(data_type)
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
for frame_id, tlwhs, track_ids, classes in results:
|
||||
if data_type == 'kitti':
|
||||
frame_id -= 1
|
||||
for tlwh, track_id, cls_id in zip(tlwhs, track_ids, classes):
|
||||
if track_id < 0:
|
||||
continue
|
||||
x1, y1, w, h = tlwh
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
line = save_format.format(frame=frame_id, id=track_id, cls=cls_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
|
||||
f.write(line)
|
||||
|
||||
|
||||
# def write_results(filename, results_dict: Dict, data_type: str):
|
||||
# if not filename:
|
||||
# return
|
||||
# path = os.path.dirname(filename)
|
||||
# if not os.path.exists(path):
|
||||
# os.makedirs(path)
|
||||
|
||||
# if data_type in ('mot', 'mcmot', 'lab'):
|
||||
# save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
|
||||
# elif data_type == 'kitti':
|
||||
# save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n'
|
||||
# else:
|
||||
# raise ValueError(data_type)
|
||||
|
||||
# with open(filename, 'w') as f:
|
||||
# for frame_id, frame_data in results_dict.items():
|
||||
# if data_type == 'kitti':
|
||||
# frame_id -= 1
|
||||
# for tlwh, track_id in frame_data:
|
||||
# if track_id < 0:
|
||||
# continue
|
||||
# x1, y1, w, h = tlwh
|
||||
# x2, y2 = x1 + w, y1 + h
|
||||
# line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0)
|
||||
# f.write(line)
|
||||
# logger.info('Save results to {}'.format(filename))
|
||||
|
||||
|
||||
def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
|
||||
if data_type in ('mot', 'lab'):
|
||||
read_fun = read_mot_results
|
||||
else:
|
||||
raise ValueError('Unknown data type: {}'.format(data_type))
|
||||
|
||||
return read_fun(filename, is_gt, is_ignore)
|
||||
|
||||
|
||||
"""
|
||||
labels={'ped', ... % 1
|
||||
'person_on_vhcl', ... % 2
|
||||
'car', ... % 3
|
||||
'bicycle', ... % 4
|
||||
'mbike', ... % 5
|
||||
'non_mot_vhcl', ... % 6
|
||||
'static_person', ... % 7
|
||||
'distractor', ... % 8
|
||||
'occluder', ... % 9
|
||||
'occluder_on_grnd', ... %10
|
||||
'occluder_full', ... % 11
|
||||
'reflection', ... % 12
|
||||
'crowd' ... % 13
|
||||
};
|
||||
"""
|
||||
|
||||
|
||||
def read_mot_results(filename, is_gt, is_ignore):
|
||||
valid_labels = {1}
|
||||
ignore_labels = {2, 7, 8, 12}
|
||||
results_dict = dict()
|
||||
if os.path.isfile(filename):
|
||||
with open(filename, 'r') as f:
|
||||
for line in f.readlines():
|
||||
linelist = line.split(',')
|
||||
if len(linelist) < 7:
|
||||
continue
|
||||
fid = int(linelist[0])
|
||||
if fid < 1:
|
||||
continue
|
||||
results_dict.setdefault(fid, list())
|
||||
|
||||
if is_gt:
|
||||
if 'MOT16-' in filename or 'MOT17-' in filename:
|
||||
label = int(float(linelist[7]))
|
||||
mark = int(float(linelist[6]))
|
||||
if mark == 0 or label not in valid_labels:
|
||||
continue
|
||||
score = 1
|
||||
elif is_ignore:
|
||||
if 'MOT16-' in filename or 'MOT17-' in filename:
|
||||
label = int(float(linelist[7]))
|
||||
vis_ratio = float(linelist[8])
|
||||
if label not in ignore_labels and vis_ratio >= 0:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
score = 1
|
||||
else:
|
||||
score = float(linelist[6])
|
||||
|
||||
tlwh = tuple(map(float, linelist[2:6]))
|
||||
target_id = int(linelist[1])
|
||||
|
||||
results_dict[fid].append((tlwh, target_id, score))
|
||||
|
||||
return results_dict
|
||||
|
||||
|
||||
def unzip_objs(objs):
|
||||
if len(objs) > 0:
|
||||
tlwhs, ids, scores = zip(*objs)
|
||||
else:
|
||||
tlwhs, ids, scores = [], [], []
|
||||
tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
|
||||
|
||||
return tlwhs, ids, scores
|
Reference in New Issue
Block a user