This commit is contained in:
wudong 2022-11-18 15:44:27 +08:00
commit f6135a55a4
2 changed files with 11 additions and 6 deletions

View File

@ -297,6 +297,7 @@ def error_return(id: str, data):
# 启动训练 # 启动训练
@start_train_algorithm() @start_train_algorithm()
def train_R0DY(params_str, id): def train_R0DY(params_str, id):
print(params_str)
from app.yolov5.train_server import train_start from app.yolov5.train_server import train_start
params = TrainParams() params = TrainParams()
params.read_from_str(params_str) params.read_from_str(params_str)
@ -308,12 +309,12 @@ def train_R0DY(params_str, id):
epoches = params.get('epochnum').value epoches = params.get('epochnum').value
batch_size = params.get('batch_size').value batch_size = params.get('batch_size').value
device = params.get('device').value device = params.get('device').value
try: #try:
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id) train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id)
print("train down!") print("train down!")
except Exception as e: # except Exception as e:
print(repr(e)) # print(repr(e))
error_return(id=id,data=repr(e)) # error_return(id=id,data=repr(e))
# 启动验证程序 # 启动验证程序

View File

@ -93,6 +93,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
#将数据路径写到yaml文件中 #将数据路径写到yaml文件中
#data_list = file_tool.get_file(proj_no=pro) #data_list = file_tool.get_file(proj_no=pro)
# print(data_list) # print(data_list)
print("get in train()")
yaml_rewrite(file=opt.data, data_list=data_list) yaml_rewrite(file=opt.data, data_list=data_list)
save_dir, epochs,batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \ save_dir, epochs,batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
@ -209,6 +210,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
ema = ModelEMA(model) if RANK in {-1, 0} else None ema = ModelEMA(model) if RANK in {-1, 0} else None
# Resume # Resume
print("Resume")
best_fitness, start_epoch = 0.0, 0 best_fitness, start_epoch = 0.0, 0
if pretrained: if pretrained:
if resume: if resume:
@ -226,6 +228,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
LOGGER.info('Using SyncBatchNorm()') LOGGER.info('Using SyncBatchNorm()')
print("Trainloader")
# Trainloader # Trainloader
train_loader, dataset = create_dataloader(train_path, train_loader, dataset = create_dataloader(train_path,
imgsz, imgsz,
@ -283,6 +286,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
model.names = names model.names = names
print("Start training")
# Start training # Start training
t0 = time.time() t0 = time.time()
nb = len(train_loader) # number of batches nb = len(train_loader) # number of batches