完成训练模块的转移
This commit is contained in:
82
deep_sort/deep/GETTING_STARTED.md
Normal file
82
deep_sort/deep/GETTING_STARTED.md
Normal file
@ -0,0 +1,82 @@
|
||||
In deepsort algorithm, appearance feature extraction network used to extract features from **image_crops** for matching purpose.The original model used in paper is in `model.py`, and its parameter here [ckpt.t7](https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6). This repository also provides a `resnet.py` script and its pre-training weights on Imagenet here.
|
||||
|
||||
```
|
||||
# resnet18
|
||||
https://download.pytorch.org/models/resnet18-5c106cde.pth
|
||||
# resnet34
|
||||
https://download.pytorch.org/models/resnet34-333f7ec4.pth
|
||||
# resnet50
|
||||
https://download.pytorch.org/models/resnet50-19c8e357.pth
|
||||
# resnext50_32x4d
|
||||
https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
|
||||
```
|
||||
|
||||
## Dataset PrePare
|
||||
|
||||
To train the model, first you need download [Market1501](http://www.liangzheng.com.cn/Project/project_reid.html) dataset or [Mars](http://www.liangzheng.com.cn/Project/project_mars.html) dataset.
|
||||
|
||||
If you want to train on your **own dataset**, assuming you have already downloaded the dataset.The dataset should be arranged in the following way.
|
||||
|
||||
```
|
||||
├── dataset_root: The root dir of the dataset.
|
||||
├── class1: Category 1 is located in the folder dir.
|
||||
├── xxx1.jpg: Image belonging to category 1.
|
||||
├── xxx2.jpg: Image belonging to category 1.
|
||||
├── class2: Category 2 is located in the folder dir.
|
||||
├── xxx3.jpg: Image belonging to category 2.
|
||||
├── xxx4.jpg: Image belonging to category 2.
|
||||
├── class3: Category 3 is located in the folder dir.
|
||||
...
|
||||
...
|
||||
```
|
||||
|
||||
## Training the RE-ID model
|
||||
|
||||
Assuming you have already prepare the dataset. Then you can use the following command to start your training progress.
|
||||
|
||||
#### training on a single GPU
|
||||
|
||||
```python
|
||||
usage: train.py [--data-dir]
|
||||
[--epochs]
|
||||
[--batch_size]
|
||||
[--lr]
|
||||
[--lrf]
|
||||
[--weights]
|
||||
[--freeze-layers]
|
||||
[--gpu_id]
|
||||
|
||||
# default use cuda:0, use Net in `model.py`
|
||||
python train.py --data-dir [dataset/root/path] --weights [(optional)pre-train/weight/path]
|
||||
# you can use `--freeze-layers` option to freeze full convolutional layer parameters except fc layers parameters
|
||||
python train.py --data-dir [dataset/root/path] --weights [(optional)pre-train/weight/path] --freeze-layers
|
||||
```
|
||||
|
||||
#### training on multiple GPU
|
||||
|
||||
```python
|
||||
usage: train_multiGPU.py [--data-dir]
|
||||
[--epochs]
|
||||
[--batch_size]
|
||||
[--lr]
|
||||
[--lrf]
|
||||
[--syncBN]
|
||||
[--weights]
|
||||
[--freeze-layers]
|
||||
# not change the following parameters, the system will automatically assignment
|
||||
[--device]
|
||||
[--world_size]
|
||||
[--dist_url]
|
||||
|
||||
# default use cuda:0, cuda:1, cuda:2, cuda:3, use resnet18 in `resnet.py`
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train_multiGPU.py --data-dir [dataset/root/path] --weights [(optional)pre-train/weight/path]
|
||||
# you can use `--freeze-layers` option to freeze full convolutional layer parameters except fc layers parameters
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train_multiGPU.py --data-dir [dataset/root/path] --weights [(optional)pre-train/weight/path] --freeze-layers
|
||||
```
|
||||
|
||||
An example of training progress is as follows:
|
||||
|
||||

|
||||
|
||||
The last, you can evaluate it using [test.py](deep_sort/deep/test.py) and [evaluate.py](deep_sort/deep/evalute.py).
|
||||
|
0
deep_sort/deep/__init__.py
Normal file
0
deep_sort/deep/__init__.py
Normal file
0
deep_sort/deep/checkpoint/.gitkeep
Normal file
0
deep_sort/deep/checkpoint/.gitkeep
Normal file
BIN
deep_sort/deep/checkpoint/ckpt.t7
Normal file
BIN
deep_sort/deep/checkpoint/ckpt.t7
Normal file
Binary file not shown.
92
deep_sort/deep/datasets.py
Normal file
92
deep_sort/deep/datasets.py
Normal file
@ -0,0 +1,92 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class ClsDataset(Dataset):
|
||||
def __init__(self, images_path, images_labels, transform=None):
|
||||
self.images_path = images_path
|
||||
self.images_labels = images_labels
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images_path)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img = cv2.imread(self.images_path[idx])
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = Image.fromarray(img)
|
||||
label = self.images_labels[idx]
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
return img, label
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
images, labels = tuple(zip(*batch))
|
||||
images = torch.stack(images, dim=0)
|
||||
labels = torch.as_tensor(labels)
|
||||
return images, labels
|
||||
|
||||
|
||||
def read_split_data(root, valid_rate=0.2):
|
||||
assert os.path.exists(root), 'dataset root: {} does not exist.'.format(root)
|
||||
|
||||
class_names = [cls for cls in os.listdir(root) if os.path.isdir(os.path.join(root, cls))]
|
||||
class_names.sort()
|
||||
|
||||
class_indices = {name: i for i, name in enumerate(class_names)}
|
||||
json_str = json.dumps({v: k for k, v in class_indices.items()}, indent=4)
|
||||
with open('class_indices.json', 'w') as f:
|
||||
f.write(json_str)
|
||||
|
||||
train_images_path = []
|
||||
train_labels = []
|
||||
val_images_path = []
|
||||
val_labels = []
|
||||
per_class_num = []
|
||||
|
||||
supported = ['.jpg', '.JPG', '.png', '.PNG']
|
||||
for cls in class_names:
|
||||
cls_path = os.path.join(root, cls)
|
||||
images_path = [os.path.join(cls_path, i) for i in os.listdir(cls_path)
|
||||
if os.path.splitext(i)[-1] in supported]
|
||||
images_label = class_indices[cls]
|
||||
per_class_num.append(len(images_path))
|
||||
|
||||
val_path = random.sample(images_path, int(len(images_path) * valid_rate))
|
||||
for img_path in images_path:
|
||||
if img_path in val_path:
|
||||
val_images_path.append(img_path)
|
||||
val_labels.append(images_label)
|
||||
else:
|
||||
train_images_path.append(img_path)
|
||||
train_labels.append(images_label)
|
||||
|
||||
print("{} images were found in the dataset.".format(sum(per_class_num)))
|
||||
print("{} images for training.".format(len(train_images_path)))
|
||||
print("{} images for validation.".format(len(val_images_path)))
|
||||
|
||||
assert len(train_images_path) > 0, "number of training images must greater than zero"
|
||||
assert len(val_images_path) > 0, "number of validation images must greater than zero"
|
||||
|
||||
plot_distribution = False
|
||||
if plot_distribution:
|
||||
plt.bar(range(len(class_names)), per_class_num, align='center')
|
||||
plt.xticks(range(len(class_names)), class_names)
|
||||
|
||||
for i, v in enumerate(per_class_num):
|
||||
plt.text(x=i, y=v + 5, s=str(v), ha='center')
|
||||
|
||||
plt.xlabel('classes')
|
||||
plt.ylabel('numbers')
|
||||
plt.title('the distribution of dataset')
|
||||
plt.show()
|
||||
return [train_images_path, train_labels], [val_images_path, val_labels], len(class_names)
|
15
deep_sort/deep/evaluate.py
Normal file
15
deep_sort/deep/evaluate.py
Normal file
@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
features = torch.load("features.pth")
|
||||
qf = features["qf"]
|
||||
ql = features["ql"]
|
||||
gf = features["gf"]
|
||||
gl = features["gl"]
|
||||
|
||||
scores = qf.mm(gf.t())
|
||||
res = scores.topk(5, dim=1)[1][:,0]
|
||||
top1correct = gl[res].eq(ql).sum().item()
|
||||
|
||||
print("Acc top1:{:.3f}".format(top1correct/ql.size(0)))
|
||||
|
||||
|
93
deep_sort/deep/feature_extractor.py
Normal file
93
deep_sort/deep/feature_extractor.py
Normal file
@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
import cv2
|
||||
import logging
|
||||
|
||||
from .model import Net
|
||||
from .resnet import resnet18
|
||||
# from fastreid.config import get_cfg
|
||||
# from fastreid.engine import DefaultTrainer
|
||||
# from fastreid.utils.checkpoint import Checkpointer
|
||||
|
||||
|
||||
class Extractor(object):
|
||||
def __init__(self, model_path, use_cuda=True):
|
||||
self.net = Net(reid=True)
|
||||
# self.net = resnet18(reid=True)
|
||||
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
|
||||
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
|
||||
self.net.load_state_dict(state_dict if 'net_dict' not in state_dict else state_dict['net_dict'], strict=False)
|
||||
logger = logging.getLogger("root.tracker")
|
||||
logger.info("Loading weights from {}... Done!".format(model_path))
|
||||
self.net.to(self.device)
|
||||
self.size = (64, 128)
|
||||
self.norm = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def _preprocess(self, im_crops):
|
||||
"""
|
||||
TODO:
|
||||
1. to float with scale from 0 to 1
|
||||
2. resize to (64, 128) as Market1501 dataset did
|
||||
3. concatenate to a numpy array
|
||||
3. to torch Tensor
|
||||
4. normalize
|
||||
"""
|
||||
|
||||
def _resize(im, size):
|
||||
return cv2.resize(im.astype(np.float32) / 255., size)
|
||||
|
||||
im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
|
||||
return im_batch
|
||||
|
||||
def __call__(self, im_crops):
|
||||
im_batch = self._preprocess(im_crops)
|
||||
with torch.no_grad():
|
||||
im_batch = im_batch.to(self.device)
|
||||
features = self.net(im_batch)
|
||||
return features.cpu().numpy()
|
||||
|
||||
|
||||
class FastReIDExtractor(object):
|
||||
def __init__(self, model_config, model_path, use_cuda=True):
|
||||
cfg = get_cfg()
|
||||
cfg.merge_from_file(model_config)
|
||||
cfg.MODEL.BACKBONE.PRETRAIN = False
|
||||
self.net = DefaultTrainer.build_model(cfg)
|
||||
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
|
||||
|
||||
Checkpointer(self.net).load(model_path)
|
||||
logger = logging.getLogger("root.tracker")
|
||||
logger.info("Loading weights from {}... Done!".format(model_path))
|
||||
self.net.to(self.device)
|
||||
self.net.eval()
|
||||
height, width = cfg.INPUT.SIZE_TEST
|
||||
self.size = (width, height)
|
||||
self.norm = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def _preprocess(self, im_crops):
|
||||
def _resize(im, size):
|
||||
return cv2.resize(im.astype(np.float32) / 255., size)
|
||||
|
||||
im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
|
||||
return im_batch
|
||||
|
||||
def __call__(self, im_crops):
|
||||
im_batch = self._preprocess(im_crops)
|
||||
with torch.no_grad():
|
||||
im_batch = im_batch.to(self.device)
|
||||
features = self.net(im_batch)
|
||||
return features.cpu().numpy()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]
|
||||
extr = Extractor("checkpoint/ckpt.t7")
|
||||
feature = extr(img)
|
||||
print(feature.shape)
|
105
deep_sort/deep/model.py
Normal file
105
deep_sort/deep/model.py
Normal file
@ -0,0 +1,105 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, c_in, c_out, is_downsample=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.is_downsample = is_downsample
|
||||
if is_downsample:
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(c_out)
|
||||
self.relu = nn.ReLU(True)
|
||||
self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(c_out)
|
||||
if is_downsample:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
|
||||
nn.BatchNorm2d(c_out)
|
||||
)
|
||||
elif c_in != c_out:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
|
||||
nn.BatchNorm2d(c_out)
|
||||
)
|
||||
self.is_downsample = True
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv1(x)
|
||||
y = self.bn1(y)
|
||||
y = self.relu(y)
|
||||
y = self.conv2(y)
|
||||
y = self.bn2(y)
|
||||
if self.is_downsample:
|
||||
x = self.downsample(x)
|
||||
return F.relu(x.add(y), True)
|
||||
|
||||
|
||||
def make_layers(c_in, c_out, repeat_times, is_downsample=False):
|
||||
blocks = []
|
||||
for i in range(repeat_times):
|
||||
if i == 0:
|
||||
blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
|
||||
else:
|
||||
blocks += [BasicBlock(c_out, c_out), ]
|
||||
return nn.Sequential(*blocks)
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, num_classes=751, reid=False):
|
||||
super(Net, self).__init__()
|
||||
# 3 128 64
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(3, 64, 3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
# nn.Conv2d(32,32,3,stride=1,padding=1),
|
||||
# nn.BatchNorm2d(32),
|
||||
# nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(3, 2, padding=1),
|
||||
)
|
||||
# 32 64 32
|
||||
self.layer1 = make_layers(64, 64, 2, False)
|
||||
# 32 64 32
|
||||
self.layer2 = make_layers(64, 128, 2, True)
|
||||
# 64 32 16
|
||||
self.layer3 = make_layers(128, 256, 2, True)
|
||||
# 128 16 8
|
||||
self.layer4 = make_layers(256, 512, 2, True)
|
||||
# 256 8 4
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
# 256 1 1
|
||||
self.reid = reid
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(512, 256),
|
||||
nn.BatchNorm1d(256),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(),
|
||||
nn.Linear(256, num_classes),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
# B x 128
|
||||
if self.reid:
|
||||
x = x.div(x.norm(p=2, dim=1, keepdim=True))
|
||||
return x
|
||||
# classifier
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = Net()
|
||||
x = torch.randn(4, 3, 128, 64)
|
||||
y = net(x)
|
||||
|
67
deep_sort/deep/multi_train_utils/distributed_utils.py
Normal file
67
deep_sort/deep/multi_train_utils/distributed_utils.py
Normal file
@ -0,0 +1,67 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ['RANK'])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
|
||||
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def cleanup():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def reduce_value(value, average=True):
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return value
|
||||
with torch.no_grad():
|
||||
dist.all_reduce(value)
|
||||
if average:
|
||||
value /= world_size
|
||||
|
||||
return value
|
90
deep_sort/deep/multi_train_utils/train_eval_utils.py
Normal file
90
deep_sort/deep/multi_train_utils/train_eval_utils.py
Normal file
@ -0,0 +1,90 @@
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from .distributed_utils import reduce_value, is_main_process
|
||||
|
||||
|
||||
def load_model(state_dict, model_state_dict, model):
|
||||
for k in state_dict:
|
||||
if k in model_state_dict:
|
||||
if state_dict[k].shape != model_state_dict[k].shape:
|
||||
print('Skip loading parameter {}, required shape {}, ' \
|
||||
'loaded shape {}.'.format(
|
||||
k, model_state_dict[k].shape, state_dict[k].shape))
|
||||
state_dict[k] = model_state_dict[k]
|
||||
else:
|
||||
print('Drop parameter {}.'.format(k))
|
||||
for k in model_state_dict:
|
||||
if not (k in state_dict):
|
||||
print('No param {}.'.format(k))
|
||||
state_dict[k] = model_state_dict[k]
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
def train_one_epoch(model, optimizer, data_loader, device, epoch):
|
||||
model.train()
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
mean_loss = torch.zeros(1).to(device)
|
||||
sum_num = torch.zeros(1).to(device)
|
||||
optimizer.zero_grad()
|
||||
|
||||
if is_main_process():
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
|
||||
for idx, (images, labels) in enumerate(data_loader):
|
||||
# forward
|
||||
images, labels = images.to(device), labels.to(device)
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
# backward
|
||||
loss.backward()
|
||||
loss = reduce_value(loss, average=True)
|
||||
mean_loss = (mean_loss * idx + loss.detach()) / (idx + 1)
|
||||
pred = torch.max(outputs, dim=1)[1]
|
||||
sum_num += torch.eq(pred, labels).sum()
|
||||
|
||||
if is_main_process():
|
||||
data_loader.desc = '[epoch {}] mean loss {}'.format(epoch, mean_loss.item())
|
||||
|
||||
if not torch.isfinite(loss):
|
||||
print('loss is infinite, ending training')
|
||||
sys.exit(1)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
if device != torch.device('cpu'):
|
||||
torch.cuda.synchronize(device)
|
||||
sum_num = reduce_value(sum_num, average=False)
|
||||
|
||||
return sum_num.item(), mean_loss.item()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(model, data_loader, device):
|
||||
model.eval()
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
test_loss = torch.zeros(1).to(device)
|
||||
sum_num = torch.zeros(1).to(device)
|
||||
if is_main_process():
|
||||
data_loader = tqdm(data_loader, file=sys.stdout)
|
||||
|
||||
for idx, (inputs, labels) in enumerate(data_loader):
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss = reduce_value(loss, average=True)
|
||||
|
||||
test_loss = (test_loss * idx + loss.detach()) / (idx + 1)
|
||||
pred = torch.max(outputs, dim=1)[1]
|
||||
sum_num += torch.eq(pred, labels).sum()
|
||||
|
||||
if device != torch.device('cpu'):
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
sum_num = reduce_value(sum_num, average=False)
|
||||
|
||||
return sum_num.item(), test_loss.item()
|
173
deep_sort/deep/resnet.py
Normal file
173
deep_sort/deep/resnet.py
Normal file
@ -0,0 +1,173 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3,
|
||||
stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channel)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3,
|
||||
stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channel)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_channel, out_channel, stride=1, downsample=None,
|
||||
groups=1, width_per_group=64):
|
||||
super(Bottleneck, self).__init__()
|
||||
width = int(out_channel * (width_per_group / 64.)) * groups
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width, kernel_size=1,
|
||||
stride=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width)
|
||||
self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3,
|
||||
stride=stride, padding=1, bias=False, groups=groups)
|
||||
self.bn2 = nn.BatchNorm2d(width)
|
||||
self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,
|
||||
kernel_size=1, stride=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, blocks_num, reid=False, num_classes=1000, groups=1, width_per_group=64):
|
||||
super(ResNet, self).__init__()
|
||||
self.reid = reid
|
||||
self.in_channel = 64
|
||||
|
||||
self.groups = groups
|
||||
self.width_per_group = width_per_group
|
||||
|
||||
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
|
||||
padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(self.in_channel)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layers(block, 64, blocks_num[0])
|
||||
self.layer2 = self._make_layers(block, 128, blocks_num[1], stride=2)
|
||||
self.layer3 = self._make_layers(block, 256, blocks_num[2], stride=2)
|
||||
# self.layer4 = self._make_layers(block, 512, blocks_num[3], stride=1)
|
||||
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(256 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layers(self, block, channel, block_num, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.in_channel != channel * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(channel * block.expansion)
|
||||
)
|
||||
layers = []
|
||||
layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride,
|
||||
groups=self.groups, width_per_group=self.width_per_group))
|
||||
self.in_channel = channel * block.expansion
|
||||
|
||||
for _ in range(1, block_num):
|
||||
layers.append(block(self.in_channel, channel, groups=self.groups, width_per_group=self.width_per_group))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
# x = self.layer4(x)
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
|
||||
# B x 512
|
||||
if self.reid:
|
||||
x = x.div(x.norm(p=2, dim=1, keepdim=True))
|
||||
return x
|
||||
# classifier
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def resnet18(num_classes=1000, reid=False):
|
||||
# https://download.pytorch.org/models/resnet18-5c106cde.pth
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, reid=reid)
|
||||
|
||||
|
||||
def resnet34(num_classes=1000, reid=False):
|
||||
# https://download.pytorch.org/models/resnet34-333f7ec4.pth
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, reid=reid)
|
||||
|
||||
|
||||
def resnet50(num_classes=1000, reid=False):
|
||||
# https://download.pytorch.org/models/resnet50-19c8e357.pth
|
||||
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, reid=reid)
|
||||
|
||||
|
||||
def resnext50_32x4d(num_classes=1000, reid=False):
|
||||
# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
|
||||
groups = 32
|
||||
width_per_group = 4
|
||||
return ResNet(Bottleneck, [3, 4, 6, 3], reid=reid,
|
||||
num_classes=num_classes, groups=groups, width_per_group=width_per_group)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = resnet18(reid=True)
|
||||
x = torch.randn(4, 3, 128, 64)
|
||||
y = net(x)
|
77
deep_sort/deep/test.py
Normal file
77
deep_sort/deep/test.py
Normal file
@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from model import Net
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train on market1501")
|
||||
parser.add_argument("--data-dir", default='data', type=str)
|
||||
parser.add_argument("--no-cuda", action="store_true")
|
||||
parser.add_argument("--gpu-id", default=0, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
# device
|
||||
device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
cudnn.benchmark = True
|
||||
|
||||
# data loader
|
||||
root = args.data_dir
|
||||
query_dir = os.path.join(root, "query")
|
||||
gallery_dir = os.path.join(root, "gallery")
|
||||
transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize((128, 64)),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
queryloader = torch.utils.data.DataLoader(
|
||||
torchvision.datasets.ImageFolder(query_dir, transform=transform),
|
||||
batch_size=64, shuffle=False
|
||||
)
|
||||
galleryloader = torch.utils.data.DataLoader(
|
||||
torchvision.datas0ets.ImageFolder(gallery_dir, transform=transform),
|
||||
batch_size=64, shuffle=False
|
||||
)
|
||||
|
||||
# net definition
|
||||
net = Net(reid=True)
|
||||
assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
|
||||
print('Loading from checkpoint/ckpt.t7')
|
||||
checkpoint = torch.load("./checkpoint/ckpt.t7")
|
||||
net_dict = checkpoint['net_dict']
|
||||
net.load_state_dict(net_dict, strict=False)
|
||||
net.eval()
|
||||
net.to(device)
|
||||
|
||||
# compute features
|
||||
query_features = torch.tensor([]).float()
|
||||
query_labels = torch.tensor([]).long()
|
||||
gallery_features = torch.tensor([]).float()
|
||||
gallery_labels = torch.tensor([]).long()
|
||||
|
||||
with torch.no_grad():
|
||||
for idx, (inputs, labels) in enumerate(queryloader):
|
||||
inputs = inputs.to(device)
|
||||
features = net(inputs).cpu()
|
||||
query_features = torch.cat((query_features, features), dim=0)
|
||||
query_labels = torch.cat((query_labels, labels))
|
||||
|
||||
for idx, (inputs, labels) in enumerate(galleryloader):
|
||||
inputs = inputs.to(device)
|
||||
features = net(inputs).cpu()
|
||||
gallery_features = torch.cat((gallery_features, features), dim=0)
|
||||
gallery_labels = torch.cat((gallery_labels, labels))
|
||||
|
||||
gallery_labels -= 2
|
||||
|
||||
# save features
|
||||
features = {
|
||||
"qf": query_features,
|
||||
"ql": query_labels,
|
||||
"gf": gallery_features,
|
||||
"gl": gallery_labels
|
||||
}
|
||||
torch.save(features, "features.pth")
|
BIN
deep_sort/deep/train.jpg
Normal file
BIN
deep_sort/deep/train.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 59 KiB |
151
deep_sort/deep/train.py
Normal file
151
deep_sort/deep/train.py
Normal file
@ -0,0 +1,151 @@
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
from multi_train_utils.distributed_utils import init_distributed_mode, cleanup
|
||||
from multi_train_utils.train_eval_utils import train_one_epoch, evaluate, load_model
|
||||
import torch.distributed as dist
|
||||
from datasets import ClsDataset, read_split_data
|
||||
|
||||
from model import Net
|
||||
from resnet import resnet18
|
||||
|
||||
# plot figure
|
||||
x_epoch = []
|
||||
record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
|
||||
fig = plt.figure()
|
||||
ax0 = fig.add_subplot(121, title="loss")
|
||||
ax1 = fig.add_subplot(122, title="top1_err")
|
||||
|
||||
|
||||
def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
|
||||
global record
|
||||
record['train_loss'].append(train_loss)
|
||||
record['train_err'].append(train_err)
|
||||
record['test_loss'].append(test_loss)
|
||||
record['test_err'].append(test_err)
|
||||
|
||||
x_epoch.append(epoch)
|
||||
ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
|
||||
ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
|
||||
ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
|
||||
ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
|
||||
if epoch == 0:
|
||||
ax0.legend()
|
||||
ax1.legend()
|
||||
fig.savefig("train.jpg")
|
||||
|
||||
|
||||
def main(args):
|
||||
batch_size = args.batch_size
|
||||
device = 'cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
train_info, val_info, num_classes = read_split_data(args.data_dir, valid_rate=0.2)
|
||||
train_images_path, train_labels = train_info
|
||||
val_images_path, val_labels = val_info
|
||||
|
||||
transform_train = torchvision.transforms.Compose([
|
||||
torchvision.transforms.RandomCrop((128, 64), padding=4),
|
||||
torchvision.transforms.RandomHorizontalFlip(),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
transform_val = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize((128, 64)),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
train_dataset = ClsDataset(
|
||||
images_path=train_images_path,
|
||||
images_labels=train_labels,
|
||||
transform=transform_train
|
||||
)
|
||||
val_dataset = ClsDataset(
|
||||
images_path=val_images_path,
|
||||
images_labels=val_labels,
|
||||
transform=transform_val
|
||||
)
|
||||
|
||||
number_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
|
||||
print('Using {} dataloader workers every process'.format(number_workers))
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
num_workers=number_workers
|
||||
)
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
num_workers=number_workers,
|
||||
)
|
||||
|
||||
# net definition
|
||||
start_epoch = 0
|
||||
net = Net(num_classes=num_classes)
|
||||
if args.weights:
|
||||
print('Loading from ', args.weights)
|
||||
checkpoint = torch.load(args.weights, map_location='cpu')
|
||||
net_dict = checkpoint if 'net_dict' not in checkpoint else checkpoint['net_dict']
|
||||
start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else start_epoch
|
||||
net = load_model(net_dict, net.state_dict(), net)
|
||||
|
||||
if args.freeze_layers:
|
||||
for name, param in net.named_parameters():
|
||||
if 'classifier' not in name:
|
||||
param.requires_grad = False
|
||||
|
||||
net.to(device)
|
||||
|
||||
# loss and optimizer
|
||||
pg = [p for p in net.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.SGD(pg, args.lr, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
lr = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr)
|
||||
for epoch in range(start_epoch, start_epoch + args.epochs):
|
||||
train_positive, train_loss = train_one_epoch(net, optimizer, train_loader, device, epoch)
|
||||
train_acc = train_positive / len(train_dataset)
|
||||
scheduler.step()
|
||||
|
||||
test_positive, test_loss = evaluate(net, val_loader, device)
|
||||
test_acc = test_positive / len(val_dataset)
|
||||
|
||||
print('[epoch {}] accuracy: {}'.format(epoch, test_acc))
|
||||
|
||||
state_dict = {
|
||||
'net_dict': net.state_dict(),
|
||||
'acc': test_acc,
|
||||
'epoch': epoch
|
||||
}
|
||||
torch.save(state_dict, './checkpoint/model_{}.pth'.format(epoch))
|
||||
draw_curve(epoch, train_loss, 1 - train_acc, test_loss, 1 - test_acc)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Train on market1501")
|
||||
parser.add_argument("--data-dir", default='data', type=str)
|
||||
parser.add_argument('--epochs', type=int, default=40)
|
||||
parser.add_argument('--batch_size', type=int, default=32)
|
||||
parser.add_argument("--lr", default=0.001, type=float)
|
||||
parser.add_argument('--lrf', default=0.1, type=float)
|
||||
|
||||
parser.add_argument('--weights', type=str, default='./checkpoint/resnet18.pth')
|
||||
parser.add_argument('--freeze-layers', action='store_true')
|
||||
|
||||
parser.add_argument('--gpu_id', default='0', help='gpu id')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
189
deep_sort/deep/train_multiGPU.py
Normal file
189
deep_sort/deep/train_multiGPU.py
Normal file
@ -0,0 +1,189 @@
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
from multi_train_utils.distributed_utils import init_distributed_mode, cleanup
|
||||
from multi_train_utils.train_eval_utils import train_one_epoch, evaluate, load_model
|
||||
import torch.distributed as dist
|
||||
from datasets import ClsDataset, read_split_data
|
||||
|
||||
from resnet import resnet18
|
||||
|
||||
|
||||
# plot figure
|
||||
x_epoch = []
|
||||
record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
|
||||
fig = plt.figure()
|
||||
ax0 = fig.add_subplot(121, title="loss")
|
||||
ax1 = fig.add_subplot(122, title="top1_err")
|
||||
|
||||
|
||||
def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
|
||||
global record
|
||||
record['train_loss'].append(train_loss)
|
||||
record['train_err'].append(train_err)
|
||||
record['test_loss'].append(test_loss)
|
||||
record['test_err'].append(test_err)
|
||||
|
||||
x_epoch.append(epoch)
|
||||
ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
|
||||
ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
|
||||
ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
|
||||
ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
|
||||
if epoch == 0:
|
||||
ax0.legend()
|
||||
ax1.legend()
|
||||
fig.savefig("train.jpg")
|
||||
|
||||
|
||||
def main(args):
|
||||
init_distributed_mode(args)
|
||||
|
||||
rank = args.rank
|
||||
device = torch.device(args.device)
|
||||
batch_size = args.batch_size
|
||||
weights_path = args.weights
|
||||
args.lr *= args.world_size
|
||||
checkpoint_path = ''
|
||||
|
||||
if rank == 0:
|
||||
print(args)
|
||||
if os.path.exists('./checkpoint') is False:
|
||||
os.mkdir('./checkpoint')
|
||||
|
||||
train_info, val_info, num_classes = read_split_data(args.data_dir, valid_rate=0.2)
|
||||
train_images_path, train_labels = train_info
|
||||
val_images_path, val_labels = val_info
|
||||
|
||||
transform_train = torchvision.transforms.Compose([
|
||||
torchvision.transforms.RandomCrop((128, 64), padding=4),
|
||||
torchvision.transforms.RandomHorizontalFlip(),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
transform_val = torchvision.transforms.Compose([
|
||||
torchvision.transforms.Resize((128, 64)),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
train_dataset = ClsDataset(
|
||||
images_path=train_images_path,
|
||||
images_labels=train_labels,
|
||||
transform=transform_train
|
||||
)
|
||||
val_dataset = ClsDataset(
|
||||
images_path=val_images_path,
|
||||
images_labels=val_labels,
|
||||
transform=transform_val
|
||||
)
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
||||
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
|
||||
|
||||
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
|
||||
|
||||
number_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
|
||||
|
||||
if rank == 0:
|
||||
print('Using {} dataloader workers every process'.format(number_workers))
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=train_batch_sampler,
|
||||
pin_memory=True,
|
||||
num_workers=number_workers
|
||||
)
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
sampler=val_sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=True,
|
||||
num_workers=number_workers,
|
||||
)
|
||||
|
||||
# net definition
|
||||
start_epoch = 0
|
||||
net = resnet18(num_classes=num_classes)
|
||||
if args.weights:
|
||||
print('Loading from ', args.weights)
|
||||
checkpoint = torch.load(args.weights, map_location='cpu')
|
||||
net_dict = checkpoint if 'net_dict' not in checkpoint else checkpoint['net_dict']
|
||||
start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else start_epoch
|
||||
net = load_model(net_dict, net.state_dict(), net)
|
||||
else:
|
||||
warnings.warn("better providing pretraining weights")
|
||||
checkpoint_path = os.path.join(tempfile.gettempdir(), 'initial_weights.pth')
|
||||
if rank == 0:
|
||||
torch.save(net.state_dict(), checkpoint_path)
|
||||
|
||||
dist.barrier()
|
||||
net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
|
||||
|
||||
if args.freeze_layers:
|
||||
for name, param in net.named_parameters():
|
||||
if 'fc' not in name:
|
||||
param.requires_grad = False
|
||||
else:
|
||||
if args.syncBN:
|
||||
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
|
||||
net.to(device)
|
||||
|
||||
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
|
||||
|
||||
# loss and optimizer
|
||||
pg = [p for p in net.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.SGD(pg, args.lr, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
lr = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr)
|
||||
for epoch in range(start_epoch, start_epoch + args.epochs):
|
||||
train_positive, train_loss = train_one_epoch(net, optimizer, train_loader, device, epoch)
|
||||
train_acc = train_positive / len(train_dataset)
|
||||
scheduler.step()
|
||||
|
||||
test_positive, test_loss = evaluate(net, val_loader, device)
|
||||
test_acc = test_positive / len(val_dataset)
|
||||
|
||||
if rank == 0:
|
||||
print('[epoch {}] accuracy: {}'.format(epoch, test_acc))
|
||||
|
||||
state_dict = {
|
||||
'net_dict': net.module.state_dict(),
|
||||
'acc': test_acc,
|
||||
'epoch': epoch
|
||||
}
|
||||
torch.save(state_dict, './checkpoint/model_{}.pth'.format(epoch))
|
||||
draw_curve(epoch, train_loss, 1 - train_acc, test_loss, 1 - test_acc)
|
||||
|
||||
if rank == 0:
|
||||
if os.path.exists(checkpoint_path) is True:
|
||||
os.remove(checkpoint_path)
|
||||
cleanup()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Train on market1501")
|
||||
parser.add_argument("--data-dir", default='data', type=str)
|
||||
parser.add_argument('--epochs', type=int, default=40)
|
||||
parser.add_argument('--batch_size', type=int, default=32)
|
||||
parser.add_argument("--lr", default=0.001, type=float)
|
||||
parser.add_argument('--lrf', default=0.1, type=float)
|
||||
parser.add_argument('--syncBN', type=bool, default=True)
|
||||
|
||||
parser.add_argument('--weights', type=str, default='./checkpoint/resnet18.pth')
|
||||
parser.add_argument('--freeze-layers', action='store_true')
|
||||
|
||||
# not change the following parameters, the system will automatically assignment
|
||||
parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0, 1 or cpu)')
|
||||
parser.add_argument('--world_size', default=4, type=int, help='number of distributed processes')
|
||||
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
Reference in New Issue
Block a user