完成训练模块的转移
This commit is contained in:
67
deep_sort/deep/multi_train_utils/distributed_utils.py
Normal file
67
deep_sort/deep/multi_train_utils/distributed_utils.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def init_distributed_mode(args):
    """Configure *args* for distributed training from the environment.

    Recognizes two launchers: torchrun-style (``RANK``/``WORLD_SIZE``/
    ``LOCAL_RANK``) and SLURM (``SLURM_PROCID``). When neither is present,
    marks the run as non-distributed and returns early.

    Mutates *args* in place: sets ``rank``, ``world_size``, ``gpu``,
    ``distributed`` and ``dist_backend``, then initializes the NCCL process
    group using ``args.dist_url`` (assumed to be set by the caller — TODO
    confirm) and blocks on a barrier until every process has joined.
    """
    env = os.environ
    if 'RANK' in env and 'WORLD_SIZE' in env:
        # Launched via torchrun / torch.distributed.launch.
        args.rank = int(env['RANK'])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        # Launched under SLURM: derive the local GPU index from the
        # global rank and the number of GPUs on this node.
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        # No launcher detected: fall back to single-process mode.
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()
|
||||
|
||||
|
||||
def cleanup():
    """Tear down the default process group once distributed work is done."""
    dist.destroy_process_group()
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both compiled in and a
    process group has actually been initialized."""
    return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def get_world_size():
    """Number of processes in the job; 1 when running non-distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
||||
|
||||
|
||||
def get_rank():
    """Global rank of this process; 0 when running non-distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
||||
|
||||
|
||||
def is_main_process():
    """Whether this is the rank-0 process (the one that logs/saves)."""
    return not get_rank()
|
||||
|
||||
|
||||
def reduce_value(value, average=True):
    """Sum a tensor across all processes with an all-reduce.

    When *average* is True the reduced tensor is additionally divided
    (in place) by the world size, yielding the mean. With fewer than two
    processes there is nothing to reduce and *value* is returned untouched.
    """
    world_size = get_world_size()
    if world_size < 2:  # single process: no communication needed
        return value

    # The reduction is bookkeeping, not part of the autograd graph.
    with torch.no_grad():
        dist.all_reduce(value)
        if average:
            value /= world_size

    return value
|
90
deep_sort/deep/multi_train_utils/train_eval_utils.py
Normal file
90
deep_sort/deep/multi_train_utils/train_eval_utils.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from .distributed_utils import reduce_value, is_main_process
|
||||
|
||||
|
||||
def load_model(state_dict, model_state_dict, model):
    """Load a possibly mismatched checkpoint *state_dict* into *model*.

    Checkpoint entries whose shapes differ from the model's are replaced by
    the model's own weights, entries unknown to the model are ignored (via
    ``strict=False``), and parameters missing from the checkpoint are filled
    in from the model. Mutates *state_dict* in place and returns *model*.
    """
    for name in state_dict:
        if name not in model_state_dict:
            # Checkpoint entry with no counterpart in the model; strict=False
            # below makes load_state_dict skip it.
            print('Drop parameter {}.'.format(name))
            continue
        expected = model_state_dict[name]
        if state_dict[name].shape != expected.shape:
            # Shape conflict: keep the model's own weights for this entry.
            print('Skip loading parameter {}, required shape {}, '
                  'loaded shape {}.'.format(name, expected.shape, state_dict[name].shape))
            state_dict[name] = expected

    for name in model_state_dict:
        if name not in state_dict:
            # Parameter absent from the checkpoint: back-fill from the model.
            print('No param {}.'.format(name))
            state_dict[name] = model_state_dict[name]

    model.load_state_dict(state_dict, strict=False)
    return model
|
||||
|
||||
|
||||
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one training epoch; return (num_correct, mean_loss) as floats.

    The per-batch loss is averaged across processes with ``reduce_value``
    before being folded into the running mean, and the correct-prediction
    count is sum-reduced once at the end of the epoch. Only the main
    process renders a tqdm progress bar. Exits the process if the loss
    ever becomes non-finite.
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    correct = torch.zeros(1).to(device)
    optimizer.zero_grad()

    # Only rank 0 shows a progress bar.
    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)

    for step, (inputs, targets) in enumerate(data_loader):
        # forward
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits = model(inputs)
        loss = criterion(logits, targets)

        # backward
        loss.backward()

        # Average the loss across processes before logging it.
        loss = reduce_value(loss, average=True)
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)

        # Running count of correct top-1 predictions.
        prediction = logits.max(dim=1)[1]
        correct += torch.eq(prediction, targets).sum()

        if is_main_process():
            data_loader.desc = '[epoch {}] mean loss {}'.format(epoch, mean_loss.item())

        # Abort the whole job on divergence.
        if not torch.isfinite(loss):
            print('loss is infinite, ending training')
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    if device != torch.device('cpu'):
        # Wait for queued CUDA kernels before the final reduction.
        torch.cuda.synchronize(device)
    correct = reduce_value(correct, average=False)

    return correct.item(), mean_loss.item()
|
||||
|
||||
|
||||
@torch.no_grad()
def evaluate(model, data_loader, device):
    """Evaluate *model* over *data_loader*; return (num_correct, mean_loss).

    Runs with gradients disabled. The per-batch loss is averaged across
    processes as it is accumulated; the correct-prediction count is
    sum-reduced once after the loop. Only the main process shows a
    tqdm progress bar.
    """
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()
    test_loss = torch.zeros(1).to(device)
    correct = torch.zeros(1).to(device)

    # Only rank 0 shows a progress bar.
    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)

    for step, (batch, targets) in enumerate(data_loader):
        batch = batch.to(device)
        targets = targets.to(device)
        logits = model(batch)

        # Average the per-batch loss across processes.
        loss = reduce_value(criterion(logits, targets), average=True)
        test_loss = (test_loss * step + loss.detach()) / (step + 1)

        # Count correct top-1 predictions.
        prediction = logits.max(dim=1)[1]
        correct += torch.eq(prediction, targets).sum()

    if device != torch.device('cpu'):
        # Wait for queued CUDA kernels before the final reduction.
        torch.cuda.synchronize(device)

    correct = reduce_value(correct, average=False)

    return correct.item(), test_loss.item()
|
Reference in New Issue
Block a user