This commit is contained in:
552068321@qq.com 2022-11-08 09:59:54 +08:00
parent 317a71ed5f
commit ee5d09adfe
4 changed files with 9 additions and 9 deletions

View File

@ -242,7 +242,7 @@ from app import file_tool
# 启动训练 # 启动训练
@start_train_algorithm() #@start_train_algorithm()
def train_R0DY(params_str, id): def train_R0DY(params_str, id):
from app.yolov5.train_server import train_start from app.yolov5.train_server import train_start
params = TrainParams() params = TrainParams()
@ -314,7 +314,7 @@ def train_R0DY(params_str, id):
# zip_outputPath = os.path.join(exp_outputPath, "inference_model.zip") # zip_outputPath = os.path.join(exp_outputPath, "inference_model.zip")
@obtain_train_param() #@obtain_train_param()
def returnTrainParams(): def returnTrainParams():
# nvmlInit() # nvmlInit()
# gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数 # gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
@ -338,7 +338,7 @@ def returnTrainParams():
{"index": 7, "name": "CLASS_NAMES", "value": ['hole', '456'], "description": '类别名称', "default": '', "type": "L", {"index": 7, "name": "CLASS_NAMES", "value": ['hole', '456'], "description": '类别名称', "default": '', "type": "L",
"items": '', "items": '',
'show': False}, 'show': False},
{"index": 8, "name": "DatasetDir", "value": "E:/aicheck/data_set/11442136178662604800/ori/", {"index": 8, "name": "DatasetDir", "value": "E:/aicheck/data_set/11442136178662604800/ori",
"description": '数据集路径', "description": '数据集路径',
"default": "./app/maskrcnn/datasets/test", "type": "S", 'show': False} # ORI_PATH "default": "./app/maskrcnn/datasets/test", "type": "S", 'show': False} # ORI_PATH
] ]

View File

@ -179,7 +179,7 @@ def get_file(ori_path: str, type_list: Union[object,str]):
test_files = [] test_files = []
# 训练、测试比例强制91 # 训练、测试比例强制91
for img in imgs[0:1]: for img in imgs[0:1]:
path = ori_path + '/images/' +img path = ori_path + '/images/' +img #'/images/'
# print(os.path.exists(path)) # print(os.path.exists(path))
print('图像路径',path) print('图像路径',path)
if os.path.exists(path): if os.path.exists(path):
@ -187,7 +187,7 @@ def get_file(ori_path: str, type_list: Union[object,str]):
print('1111') print('1111')
#label = ori_path + 'labels/' + os.path.split(path)[1] #label = ori_path + 'labels/' + os.path.split(path)[1]
(filename1, extension) = os.path.splitext(img) # 文件名与后缀名分开 (filename1, extension) = os.path.splitext(img) # 文件名与后缀名分开
label = ori_path + '/labels/' + filename1 + '.json' label = ori_path + '/labels/' + filename1 + '.json' #'/labels/'
print('标签',label) print('标签',label)
if label is not None: if label is not None:
#train_files.append(label) #train_files.append(label)

View File

@ -3,5 +3,5 @@ train: E:/aicheck/data_set/11442136178662604800/trained/images/train/
val: E:/aicheck/data_set/11442136178662604800/trained/images/val/ val: E:/aicheck/data_set/11442136178662604800/trained/images/val/
test: null test: null
names: names:
0: logo 0: hole
1: 3C 1: '456'

View File

@ -304,7 +304,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
num_train_img=train_num, num_train_img=train_num,
train_mod_savepath=best) train_mod_savepath=best)
# @algorithm_process_value_websocket() @algorithm_process_value_websocket()
def report_cellback(i, num_epochs, reportAccu): def report_cellback(i, num_epochs, reportAccu):
report.rate_of_progess = ((i + 1) / num_epochs) * 100 report.rate_of_progess = ((i + 1) / num_epochs) * 100
report.progress = (i + 1) report.progress = (i + 1)
@ -470,7 +470,7 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
print('##############',best) print('##############',best)
for f in best: for f in best:
print('##################',f) print('##################',f)
if os.path.exists(f): if os.path.exists(best):
strip_optimizer(f) # strip optimizers strip_optimizer(f) # strip optimizers
if f is best: if f is best:
LOGGER.info(f'\nValidating {f}...') LOGGER.info(f'\nValidating {f}...')