Complete the migration of the training module
92  deep_sort/deep/datasets.py  (new file)
@@ -0,0 +1,92 @@
import json
import os
import random

import cv2
from PIL import Image
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt


class ClsDataset(Dataset):
    """Simple classification dataset: reads images from disk and returns (image, label) pairs."""

    def __init__(self, images_path, images_labels, transform=None):
        self.images_path = images_path
        self.images_labels = images_labels
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, idx):
        # OpenCV loads images as BGR; convert to RGB before wrapping in a PIL image
        # so that torchvision-style transforms behave as expected.
        img = cv2.imread(self.images_path[idx])
        assert img is not None, 'failed to read image: {}'.format(self.images_path[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        label = self.images_labels[idx]

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        # Stack individual (image, label) samples into batch tensors for the DataLoader.
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels


def read_split_data(root, valid_rate=0.2):
    """Scan a class-per-folder dataset under `root` and split it into train/val sets.

    Returns ([train_paths, train_labels], [val_paths, val_labels], num_classes).
    """
    assert os.path.exists(root), 'dataset root: {} does not exist.'.format(root)

    # Each sub-directory of `root` is treated as one class.
    class_names = [cls for cls in os.listdir(root) if os.path.isdir(os.path.join(root, cls))]
    class_names.sort()

    # Save the index-to-name mapping so predictions can be decoded later.
    class_indices = {name: i for i, name in enumerate(class_names)}
    json_str = json.dumps({v: k for k, v in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as f:
        f.write(json_str)

    train_images_path = []
    train_labels = []
    val_images_path = []
    val_labels = []
    per_class_num = []

    supported = ['.jpg', '.JPG', '.png', '.PNG']
    for cls in class_names:
        cls_path = os.path.join(root, cls)
        images_path = [os.path.join(cls_path, i) for i in os.listdir(cls_path)
                       if os.path.splitext(i)[-1] in supported]
        images_label = class_indices[cls]
        per_class_num.append(len(images_path))

        # Randomly hold out `valid_rate` of each class's images for validation.
        val_path = random.sample(images_path, int(len(images_path) * valid_rate))
        for img_path in images_path:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_labels.append(images_label)
            else:
                train_images_path.append(img_path)
                train_labels.append(images_label)

    print("{} images were found in the dataset.".format(sum(per_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    assert len(train_images_path) > 0, "number of training images must be greater than zero"
    assert len(val_images_path) > 0, "number of validation images must be greater than zero"

    # Optionally plot how many images each class contains.
    plot_distribution = False
    if plot_distribution:
        plt.bar(range(len(class_names)), per_class_num, align='center')
        plt.xticks(range(len(class_names)), class_names)

        for i, v in enumerate(per_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')

        plt.xlabel('classes')
        plt.ylabel('numbers')
        plt.title('the distribution of dataset')
        plt.show()
    return [train_images_path, train_labels], [val_images_path, val_labels], len(class_names)