Complete the migration of the training module

2025-04-17 11:03:05 +08:00
parent 4439687870
commit 74e8f0d415
188 changed files with 32931 additions and 70 deletions

distributed_utils.py
@@ -0,0 +1,67 @@
import os

import torch
import torch.distributed as dist


def init_distributed_mode(args):
    """Initialize torch.distributed from environment variables.

    Handles both torchrun/torch.distributed.launch (RANK/WORLD_SIZE/LOCAL_RANK)
    and SLURM (SLURM_PROCID) launchers.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        # Note: this branch assumes args.world_size is supplied by the caller
        # (e.g. from SLURM_NTASKS); init_process_group below requires it.
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'  # NCCL is the standard backend for multi-GPU training
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url),
          flush=True)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()  # wait until every process has finished initializing


def cleanup():
    dist.destroy_process_group()


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def reduce_value(value, average=True):
    """Sum a tensor across all processes, optionally averaging the result."""
    world_size = get_world_size()
    if world_size < 2:  # single-process run: nothing to reduce
        return value
    with torch.no_grad():
        dist.all_reduce(value)  # in-place SUM across all processes
        if average:
            value /= world_size
    return value
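
As a usage note (not part of this commit): a minimal sketch of how these helpers are typically driven when launched with torchrun. The argparse flag and the overall driver are illustrative assumptions, not code from this repository.

# Hypothetical driver sketch (assumed, not from this commit). Launch with e.g.:
#   torchrun --nproc_per_node=4 train.py --dist-url env://
import argparse

import torch

from distributed_utils import init_distributed_mode, cleanup, is_main_process

parser = argparse.ArgumentParser()
parser.add_argument('--dist-url', default='env://',
                    help='read RANK/WORLD_SIZE/LOCAL_RANK from the environment')
args = parser.parse_args()

init_distributed_mode(args)  # sets args.rank, args.gpu, args.distributed
device = torch.device('cuda', args.gpu) if args.distributed else torch.device('cpu')

if is_main_process():
    print('running on', device)

# ... build model, data loaders, and train here ...

if args.distributed:
    cleanup()  # tear down the process group when training is done

With `--dist-url env://`, init_process_group reads the rendezvous information that torchrun places in the environment, which is why init_distributed_mode only needs the RANK/WORLD_SIZE/LOCAL_RANK variables.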

@@ -0,0 +1,90 @@
import sys

import torch
from tqdm import tqdm

from .distributed_utils import reduce_value, is_main_process


def load_model(state_dict, model_state_dict, model):
    """Load a checkpoint leniently, skipping parameters with mismatched shapes."""
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape {}, '
                      'loaded shape {}.'.format(
                          k, model_state_dict[k].shape, state_dict[k].shape))
                # Keep the model's own (randomly initialized) weights instead.
                state_dict[k] = model_state_dict[k]
        else:
            print('Drop parameter {}.'.format(k))
    for k in model_state_dict:
        if k not in state_dict:
            print('No param {}.'.format(k))
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)
    return model


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)  # running mean of the training loss
    sum_num = torch.zeros(1).to(device)    # number of correctly classified samples
    optimizer.zero_grad()

    # Show the progress bar only on the main process.
    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)

    for idx, (images, labels) in enumerate(data_loader):
        # forward
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        # backward
        loss.backward()

        # Average the loss across processes for logging only; under DDP the
        # gradients themselves are already synchronized during backward().
        loss = reduce_value(loss, average=True)
        mean_loss = (mean_loss * idx + loss.detach()) / (idx + 1)

        pred = torch.max(outputs, dim=1)[1]  # predicted class indices
        sum_num += torch.eq(pred, labels).sum()

        if is_main_process():
            data_loader.desc = '[epoch {}] mean loss {}'.format(epoch, mean_loss.item())

        if not torch.isfinite(loss):
            print('loss is not finite (NaN or Inf), ending training')
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    # Make sure all CUDA work has finished before reading the counters.
    if device != torch.device('cpu'):
        torch.cuda.synchronize(device)

    sum_num = reduce_value(sum_num, average=False)  # total correct across processes
    return sum_num.item(), mean_loss.item()


@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()
    test_loss = torch.zeros(1).to(device)  # running mean of the validation loss
    sum_num = torch.zeros(1).to(device)    # number of correctly classified samples

    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)

    for idx, (inputs, labels) in enumerate(data_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss = reduce_value(loss, average=True)
        test_loss = (test_loss * idx + loss.detach()) / (idx + 1)

        pred = torch.max(outputs, dim=1)[1]
        sum_num += torch.eq(pred, labels).sum()

    if device != torch.device('cpu'):
        torch.cuda.synchronize(device)

    sum_num = reduce_value(sum_num, average=False)
    return sum_num.item(), test_loss.item()
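
A sketch of an epoch-level driver for train_one_epoch and evaluate (assumed, not part of this commit): the dataset objects, hyperparameters, and the module name train_eval_utils are all illustrative, since the diff does not show this file's name.

# Illustrative epoch loop (assumptions: init_distributed_mode(args) has already
# run, args.gpu is set, and the second file above is importable as
# train_eval_utils -- a hypothetical name).
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from distributed_utils import is_main_process
from train_eval_utils import train_one_epoch, evaluate  # hypothetical module name


def run(args, model, train_set, val_set, epochs=10):
    device = torch.device('cuda', args.gpu)
    model = DDP(model.to(device), device_ids=[args.gpu])

    # DistributedSampler gives each process a disjoint shard of the data.
    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set, batch_size=32, sampler=train_sampler,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=32,
                            sampler=DistributedSampler(val_set, shuffle=False))

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    for epoch in range(epochs):
        train_sampler.set_epoch(epoch)  # different shuffle on every epoch
        _, mean_loss = train_one_epoch(model, optimizer, train_loader,
                                       device, epoch)
        val_correct, _ = evaluate(model, val_loader, device)
        if is_main_process():
            print('epoch {}: train loss {:.3f}, val acc {:.3f}'.format(
                epoch, mean_loss, val_correct / len(val_set)))

Because evaluate already all-reduces sum_num across processes, dividing by len(val_set) gives the global accuracy; note that DistributedSampler may pad the last shard with duplicates, so the figure is approximate unless the dataset size divides evenly.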