中断训练
This commit is contained in:
@ -1,486 +1,486 @@
|
||||
"""
|
||||
@Time : 2022/9/20 16:17
|
||||
@Auth : 东
|
||||
@File :AlgorithmController.py
|
||||
@IDE :PyCharm
|
||||
@Motto:ABC(Always Be Coding)
|
||||
@Desc:算法接口
|
||||
|
||||
"""
|
||||
import json
|
||||
from functools import wraps
|
||||
from threading import Thread
|
||||
from multiprocessing import Process
|
||||
from time import sleep
|
||||
|
||||
from flask import Blueprint, request
|
||||
|
||||
from app.schemas.TrainResult import Report, ProcessValueList
|
||||
from app.utils.RedisMQTool import Task
|
||||
from app.utils.StandardizedOutput import output_wrapped
|
||||
from app.utils.redis_config import redis_client
|
||||
from app.utils.websocket_tool import manager
|
||||
from app.configs.global_var import set_value
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from pynvml import *
|
||||
# FILE = Path(__file__).resolve()
|
||||
# ROOT = FILE.parents[0] # YOLOv5 root directory
|
||||
# if str(ROOT) not in sys.path:
|
||||
# sys.path.append(str(ROOT)) # add ROOT to PATH
|
||||
# sys.path.append("/mnt/sdc/algorithm/AICheck-MaskRCNN/app/maskrcnn_ppx")
|
||||
# import ppx as pdx
|
||||
|
||||
bp = Blueprint('AlgorithmController', __name__)
|
||||
|
||||
ifKillDict = {}
|
||||
|
||||
def start_train_algorithm():
|
||||
"""
|
||||
调用训练算法
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/start_train_algorithm', methods=['get'])
|
||||
def wrapped_function():
|
||||
param = request.args.get('param')
|
||||
id = request.args.get('id')
|
||||
dict = manager.active_connections_dist
|
||||
# t = Thread(target=func, args=(param, id))
|
||||
t = Process(target=func, args=(param, id, dict[id]), name=id)
|
||||
set_value(key=id, value=False)
|
||||
t.start()
|
||||
return output_wrapped(0, 'success', '成功')
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def start_test_algorithm():
|
||||
"""
|
||||
调用验证算法
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/start_test_algorithm', methods=['get'])
|
||||
def wrapped_function_test():
|
||||
param = request.args.get('param')
|
||||
id = request.args.get('id')
|
||||
t = Thread(target=func, args=(param, id))
|
||||
t.start()
|
||||
return output_wrapped(0, 'success', '成功')
|
||||
|
||||
return wrapped_function_test
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def start_detect_algorithm():
|
||||
"""
|
||||
调用检测算法
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/start_detect_algorithm', methods=['get'])
|
||||
def wrapped_function_detect():
|
||||
param = request.args.get('param')
|
||||
id = request.args.get('id')
|
||||
t = Thread(target=func, args=(param, id))
|
||||
t.start()
|
||||
return output_wrapped(0, 'success', '成功')
|
||||
|
||||
return wrapped_function_detect
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def start_download_pt():
|
||||
"""
|
||||
下载模型
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/start_download_pt', methods=['get'])
|
||||
def wrapped_function_start_download_pt():
|
||||
param = request.args.get('param')
|
||||
data = func(param)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_start_download_pt
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def algorithm_process_value():
|
||||
"""
|
||||
获取中间值, redis订阅发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
print(data)
|
||||
Task(redis_conn=redis_client.get_redis(), channel="ceshi").publish_task(
|
||||
data={'code': 0, 'msg': 'success', 'data': data})
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def algorithm_process_value_websocket():
|
||||
"""
|
||||
获取中间值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 0, "type": 'connected', 'msg': 'success', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def algorithm_kill_value_websocket():
|
||||
"""
|
||||
获取kill值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 1, "type": 'kill', 'msg': 'success', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def algorithm_error_value_websocket():
|
||||
"""
|
||||
获取error值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def obtain_train_param():
|
||||
"""
|
||||
获取训练参数
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/obtain_train_param', methods=['get'])
|
||||
def wrapped_function_train_param(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_train_param
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def obtain_test_param():
|
||||
"""
|
||||
获取验证参数
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/obtain_test_param', methods=['get'])
|
||||
def wrapped_function_test_param(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_test_param
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def obtain_detect_param():
|
||||
"""
|
||||
获取测试参数
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/obtain_detect_param', methods=['get'])
|
||||
def wrapped_function_inf_param(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_inf_param
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def obtain_download_pt_param():
|
||||
"""
|
||||
获取下载模型参数
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/obtain_download_pt_param', methods=['get'])
|
||||
def wrapped_function_obtain_download_pt_param(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_obtain_download_pt_param
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
@bp.route('/change_ifKillDIct', methods=['get'])
|
||||
def change_ifKillDIct():
|
||||
"""
|
||||
修改全局变量
|
||||
"""
|
||||
id = request.args.get('id')
|
||||
type = request.args.get('type')
|
||||
set_value(id, type)
|
||||
return output_wrapped(0, 'success')
|
||||
|
||||
|
||||
# @start_train_algorithm()
|
||||
# def start(param: str):
|
||||
# """
|
||||
# 例子
|
||||
# """
|
||||
# print(param)
|
||||
# process_value_list = ProcessValueList(name='1', value=[])
|
||||
# report = Report(rate_of_progess=0, process_value=[process_value_list], id='1')
|
||||
#
|
||||
# @algorithm_process_value_websocket()
|
||||
# def process(v: int):
|
||||
# print(v)
|
||||
# report.rate_of_progess = ((v + 1) / 10) * 100
|
||||
# report.precision[0].value.append(v)
|
||||
# return report.dict()
|
||||
#
|
||||
# for i in range(10):
|
||||
# process(i)
|
||||
# return report.dict()
|
||||
from setparams import TrainParams
|
||||
import os
|
||||
from app.schemas.TrainResult import DetectProcessValueDice, DetectReport
|
||||
from app import file_tool
|
||||
|
||||
|
||||
def error_return(id: str, data):
|
||||
"""
|
||||
算法出错,返回
|
||||
"""
|
||||
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
|
||||
# 启动训练
|
||||
@start_train_algorithm()
|
||||
def train_R0DY(params_str, id, getsomething):
|
||||
print('**********************************')
|
||||
print(getsomething)
|
||||
print('**********************************')
|
||||
manager.active_connections_dist[id] = getsomething
|
||||
print('**********************************')
|
||||
print(manager.active_connections_dist)
|
||||
print('**********************************')
|
||||
print(params_str)
|
||||
print('**********************************')
|
||||
from app.yolov5.train_server import train_start
|
||||
params = TrainParams()
|
||||
params.read_from_str(params_str)
|
||||
print(params.get('device').value)
|
||||
data_list = file_tool.get_file(ori_path=params.get('DatasetDir').value, type_list=params.get('CLASS_NAMES').value)
|
||||
weights = params.get('resumeModPath').value # 初始化模型绝对路径
|
||||
img_size = params.get('img_size').value
|
||||
savemodel = os.path.splitext(params.get('saveModDir').value)[0] + '_' + str(img_size) + '.pt' # 模型命名加上图像参数
|
||||
epoches = params.get('epochnum').value
|
||||
batch_size = params.get('batch_size').value
|
||||
device = params.get('device').value
|
||||
#try:
|
||||
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id, getsomething)
|
||||
print("train down!")
|
||||
# except Exception as e:
|
||||
# print(repr(e))
|
||||
# error_return(id=id,data=repr(e))
|
||||
|
||||
|
||||
# 启动验证程序
|
||||
|
||||
@start_test_algorithm()
|
||||
def validate_RODY(params_str, id):
|
||||
from app.yolov5.validate_server import validate_start
|
||||
params = TrainParams()
|
||||
params.read_from_str(params_str)
|
||||
weights = params.get('modPath').value # 验证模型绝对路径
|
||||
(filename, extension) = os.path.splitext(weights) # 文件名与后缀名分开
|
||||
img_size = int(filename.split('ROD')[1].split('_')[2]) # 获取图像参数
|
||||
# v_num = int(filename.split('ROD')[1].split('_')[1]) #获取版本号
|
||||
output = params.get('outputPath').value
|
||||
batch_size = params.get('batch_size').default
|
||||
device = params.get('device').value
|
||||
|
||||
validate_start(weights, img_size, batch_size, device, output, id)
|
||||
|
||||
|
||||
@start_detect_algorithm()
|
||||
def detect_RODY(params_str, id):
|
||||
from app.yolov5.detect_server import detect_start
|
||||
params = TrainParams()
|
||||
params.read_from_str(params_str)
|
||||
weights = params.get('modPath').value # 检测模型绝对路径
|
||||
input = params.get('inputPath').value
|
||||
outpath = params.get('outputPath').value
|
||||
# (filename, extension) = os.path.splitext(weights) # 文件名与后缀名分开
|
||||
# img_size = int(filename.split('ROD')[1].split('_')[2]) #获取图像参数
|
||||
# v_num = int(filename.split('ROD')[1].split('_')[1]) #获取版本号
|
||||
# batch_size = params.get('batch_size').default
|
||||
device = params.get('device').value
|
||||
|
||||
detect_start(input, weights, outpath, device, id)
|
||||
|
||||
|
||||
@start_download_pt()
|
||||
def Export_model_RODY(params_str):
|
||||
from app.yolov5.export import Start_Model_Export
|
||||
import zipfile
|
||||
params = TrainParams()
|
||||
params.read_from_str(params_str)
|
||||
exp_inputPath = params.get('exp_inputPath').value # 模型路径
|
||||
print('输入模型:', exp_inputPath)
|
||||
exp_device = params.get('device').value
|
||||
imgsz = params.get('imgsz').value
|
||||
modellist = Start_Model_Export(exp_inputPath, exp_device, imgsz)
|
||||
exp_outputPath = exp_inputPath.replace('pt', 'zip') # 压缩文件
|
||||
print('模型路径:',exp_outputPath)
|
||||
zipf = zipfile.ZipFile(exp_outputPath, 'w')
|
||||
for file in modellist:
|
||||
zipf.write(file, arcname=Path(file).name) # 将torchscript和onnx模型压缩
|
||||
|
||||
return exp_outputPath
|
||||
|
||||
@obtain_train_param()
|
||||
def returnTrainParams():
|
||||
nvmlInit()
|
||||
gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
|
||||
_kernel = [f"cuda:{a}" for a in range(gpuDeviceCount)]
|
||||
params_list = [
|
||||
{"index": 0, "name": "epochnum", "value": 10, "description": '训练轮次', "default": 100, "type": "I", 'show': True},
|
||||
{"index": 1, "name": "batch_size", "value": 4, "description": '批次图像数量', "default": 1, "type": "I",
|
||||
'show': True},
|
||||
{"index": 2, "name": "img_size", "value": 640, "description": '训练图像大小', "default": 640, "type": "I",
|
||||
'show': True},
|
||||
{"index": 3, "name": "device", "value": f'{_kernel[0]}', "description": '训练核心', "default": f'{_kernel[0]}', "type": "E",
|
||||
"items": _kernel, 'show': False}, # _kernel
|
||||
{"index": 4, "name": "saveModDir", "value": "E:/alg_demo-master/alg_demo/app/yolov5/best.pt",
|
||||
"description": '保存模型路径',
|
||||
"default": "./app/maskrcnn/saved_model/test.pt", "type": "S", 'show': False},
|
||||
{"index": 5, "name": "resumeModPath", "value": '/yolov5s.pt',
|
||||
"description": '继续训练路径', "default": '', "type": "S",
|
||||
'show': False},
|
||||
{"index": 6, "name": "resumeMod", "value": '', "description": '继续训练模型', "default": '', "type": "E", "items": '',
|
||||
'show': True},
|
||||
{"index": 7, "name": "CLASS_NAMES", "value": ['hole', '456'], "description": '类别名称', "default": '', "type": "L",
|
||||
"items": '',
|
||||
'show': False},
|
||||
{"index": 8, "name": "DatasetDir", "value": "E:/aicheck/data_set/11442136178662604800/ori",
|
||||
"description": '数据集路径',
|
||||
"default": "./app/maskrcnn/datasets/test", "type": "S", 'show': False} # ORI_PATH
|
||||
]
|
||||
# {"index": 9, "name": "saveEpoch", "value": 2, "description": '保存模型轮次', "default": 2, "type": "I", 'show': True}]
|
||||
params_str = json.dumps(params_list)
|
||||
return params_str
|
||||
|
||||
|
||||
@obtain_test_param()
|
||||
def returnValidateParams():
|
||||
# nvmlInit()
|
||||
# gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
|
||||
# _kernel = [f"cuda:{a}" for a in range(gpuDeviceCount)]
|
||||
params_list = [
|
||||
{"index": 0, "name": "modPath", "value": "E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt",
|
||||
"description": '验证模型路径', "default": "./app/maskrcnn/saved_model/test.pt", "type": "S", 'show': False},
|
||||
{"index": 1, "name": "batch_size", "value": 1, "description": '批次图像数量', "default": 1, "type": "I",
|
||||
'show': False},
|
||||
{"index": 2, "name": "img_size", "value": 640, "description": '训练图像大小', "default": 640, "type": "I",
|
||||
'show': False},
|
||||
{"index": 3, "name": "outputPath", "value": 'E:/aicheck/data_set/11442136178662604800/val_results/',
|
||||
"description": '输出结果路径',
|
||||
"default": './app/maskrcnn/datasets/M006B_waibi/res', "type": "S", 'show': False},
|
||||
{"index": 4, "name": "device", "value": "0", "description": '训练核心', "default": "cuda", "type": "S",
|
||||
"items": '', 'show': False} # _kernel
|
||||
]
|
||||
# {"index": 9, "name": "saveEpoch", "value": 2, "description": '保存模型轮次', "default": 2, "type": "I", 'show': True}]
|
||||
params_str = json.dumps(params_list)
|
||||
return params_str
|
||||
|
||||
|
||||
@obtain_detect_param()
|
||||
def returnDetectParams():
|
||||
# nvmlInit()
|
||||
# gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
|
||||
# _kernel = [f"cuda:{a}" for a in range(gpuDeviceCount)]
|
||||
params_list = [
|
||||
{"index": 0, "name": "inputPath", "value": 'E:/aicheck/data_set/11442136178662604800/input/',
|
||||
"description": '输入图像路径', "default": './app/maskrcnn/datasets/M006B_waibi/JPEGImages', "type": "S",
|
||||
'show': False},
|
||||
{"index": 1, "name": "outputPath", "value": 'E:/aicheck/data_set/11442136178662604800/val_results/',
|
||||
"description": '输出结果路径',
|
||||
"default": './app/maskrcnn/datasets/M006B_waibi/res', "type": "S", 'show': False},
|
||||
{"index": 2, "name": "modPath", "value": "E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt",
|
||||
"description": '模型路径', "default": "./app/maskrcnn/saved_model/test.pt", "type": "S", 'show': False},
|
||||
{"index": 3, "name": "device", "value": "0", "description": '推理核', "default": "cpu", "type": "S",
|
||||
'show': False},
|
||||
]
|
||||
# {"index": 9, "name": "saveEpoch", "value": 2, "description": '保存模型轮次', "default": 2, "type": "I", 'show': True}]
|
||||
params_str = json.dumps(params_list)
|
||||
return params_str
|
||||
|
||||
|
||||
@obtain_download_pt_param()
|
||||
def returnDownloadParams():
|
||||
params_list = [
|
||||
{"index": 0, "name": "exp_inputPath", "value": 'E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt',
|
||||
"description": '转化模型输入路径',
|
||||
"default": 'E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt/',
|
||||
"type": "S", 'show': False},
|
||||
{"index": 1, "name": "device", "value": 'gpu', "description": 'CPU或GPU', "default": 'gpu', "type": "S",
|
||||
'show': False},
|
||||
{"index": 2, "name": "imgsz", "value": 640, "description": '图像大小', "default": 640, "type": "I",
|
||||
'show': True}
|
||||
]
|
||||
params_str = json.dumps(params_list)
|
||||
return params_str
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
par = returnTrainParams()
|
||||
print(par)
|
||||
id='1'
|
||||
train_R0DY(par,id)
|
||||
"""
|
||||
@Time : 2022/9/20 16:17
|
||||
@Auth : 东
|
||||
@File :AlgorithmController.py
|
||||
@IDE :PyCharm
|
||||
@Motto:ABC(Always Be Coding)
|
||||
@Desc:算法接口
|
||||
|
||||
"""
|
||||
import json
|
||||
from functools import wraps
|
||||
from threading import Thread
|
||||
from multiprocessing import Process
|
||||
from time import sleep
|
||||
|
||||
from flask import Blueprint, request
|
||||
|
||||
from app.schemas.TrainResult import Report, ProcessValueList
|
||||
from app.utils.RedisMQTool import Task
|
||||
from app.utils.StandardizedOutput import output_wrapped
|
||||
from app.utils.redis_config import redis_client
|
||||
from app.utils.websocket_tool import manager
|
||||
from app.configs.global_var import set_value
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from pynvml import *
|
||||
# FILE = Path(__file__).resolve()
|
||||
# ROOT = FILE.parents[0] # YOLOv5 root directory
|
||||
# if str(ROOT) not in sys.path:
|
||||
# sys.path.append(str(ROOT)) # add ROOT to PATH
|
||||
# sys.path.append("/mnt/sdc/algorithm/AICheck-MaskRCNN/app/maskrcnn_ppx")
|
||||
# import ppx as pdx
|
||||
|
||||
bp = Blueprint('AlgorithmController', __name__)
|
||||
|
||||
ifKillDict = {}
|
||||
|
||||
def start_train_algorithm():
|
||||
"""
|
||||
调用训练算法
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/start_train_algorithm', methods=['get'])
|
||||
def wrapped_function():
|
||||
param = request.args.get('param')
|
||||
id = request.args.get('id')
|
||||
dict = manager.active_connections_dist
|
||||
# t = Thread(target=func, args=(param, id))
|
||||
t = Process(target=func, args=(param, id, dict[id]), name=id)
|
||||
set_value(key=id, value=False)
|
||||
t.start()
|
||||
return output_wrapped(0, 'success', '成功')
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def start_test_algorithm():
|
||||
"""
|
||||
调用验证算法
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/start_test_algorithm', methods=['get'])
|
||||
def wrapped_function_test():
|
||||
param = request.args.get('param')
|
||||
id = request.args.get('id')
|
||||
t = Thread(target=func, args=(param, id))
|
||||
t.start()
|
||||
return output_wrapped(0, 'success', '成功')
|
||||
|
||||
return wrapped_function_test
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def start_detect_algorithm():
|
||||
"""
|
||||
调用检测算法
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/start_detect_algorithm', methods=['get'])
|
||||
def wrapped_function_detect():
|
||||
param = request.args.get('param')
|
||||
id = request.args.get('id')
|
||||
t = Thread(target=func, args=(param, id))
|
||||
t.start()
|
||||
return output_wrapped(0, 'success', '成功')
|
||||
|
||||
return wrapped_function_detect
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def start_download_pt():
|
||||
"""
|
||||
下载模型
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/start_download_pt', methods=['get'])
|
||||
def wrapped_function_start_download_pt():
|
||||
param = request.args.get('param')
|
||||
data = func(param)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_start_download_pt
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def algorithm_process_value():
|
||||
"""
|
||||
获取中间值, redis订阅发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
print(data)
|
||||
Task(redis_conn=redis_client.get_redis(), channel="ceshi").publish_task(
|
||||
data={'code': 0, 'msg': 'success', 'data': data})
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def algorithm_process_value_websocket():
|
||||
"""
|
||||
获取中间值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 0, "type": 'connected', 'msg': 'success', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def algorithm_kill_value_websocket():
|
||||
"""
|
||||
获取kill值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 1, "type": 'kill', 'msg': 'success', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def algorithm_error_value_websocket():
|
||||
"""
|
||||
获取error值, websocket发布
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
id = data["id"]
|
||||
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
return data
|
||||
|
||||
return wrapped_function
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def obtain_train_param():
|
||||
"""
|
||||
获取训练参数
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/obtain_train_param', methods=['get'])
|
||||
def wrapped_function_train_param(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_train_param
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
def obtain_test_param():
|
||||
"""
|
||||
获取验证参数
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/obtain_test_param', methods=['get'])
|
||||
def wrapped_function_test_param(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_test_param
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def obtain_detect_param():
|
||||
"""
|
||||
获取测试参数
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/obtain_detect_param', methods=['get'])
|
||||
def wrapped_function_inf_param(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_inf_param
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
|
||||
def obtain_download_pt_param():
|
||||
"""
|
||||
获取下载模型参数
|
||||
"""
|
||||
|
||||
def wrapTheFunction(func):
|
||||
@wraps(func)
|
||||
@bp.route('/obtain_download_pt_param', methods=['get'])
|
||||
def wrapped_function_obtain_download_pt_param(*args, **kwargs):
|
||||
data = func(*args, **kwargs)
|
||||
return output_wrapped(0, 'success', data)
|
||||
|
||||
return wrapped_function_obtain_download_pt_param
|
||||
|
||||
return wrapTheFunction
|
||||
|
||||
@bp.route('/change_ifKillDIct', methods=['get'])
|
||||
def change_ifKillDIct():
|
||||
"""
|
||||
修改全局变量
|
||||
"""
|
||||
id = request.args.get('id')
|
||||
type = request.args.get('type')
|
||||
set_value(id, type)
|
||||
return output_wrapped(0, 'success')
|
||||
|
||||
|
||||
# @start_train_algorithm()
|
||||
# def start(param: str):
|
||||
# """
|
||||
# 例子
|
||||
# """
|
||||
# print(param)
|
||||
# process_value_list = ProcessValueList(name='1', value=[])
|
||||
# report = Report(rate_of_progess=0, process_value=[process_value_list], id='1')
|
||||
#
|
||||
# @algorithm_process_value_websocket()
|
||||
# def process(v: int):
|
||||
# print(v)
|
||||
# report.rate_of_progess = ((v + 1) / 10) * 100
|
||||
# report.precision[0].value.append(v)
|
||||
# return report.dict()
|
||||
#
|
||||
# for i in range(10):
|
||||
# process(i)
|
||||
# return report.dict()
|
||||
from setparams import TrainParams
|
||||
import os
|
||||
from app.schemas.TrainResult import DetectProcessValueDice, DetectReport
|
||||
from app import file_tool
|
||||
|
||||
|
||||
def error_return(id: str, data):
|
||||
"""
|
||||
算法出错,返回
|
||||
"""
|
||||
data_res = {'code': 2, "type": 'error', 'msg': 'fail', 'data': data}
|
||||
manager.send_message_proj_json(message=data_res, id=id)
|
||||
|
||||
# 启动训练
|
||||
@start_train_algorithm()
|
||||
def train_R0DY(params_str, id, getsomething):
|
||||
print('**********************************')
|
||||
print(getsomething)
|
||||
print('**********************************')
|
||||
manager.active_connections_dist[id] = getsomething
|
||||
print('**********************************')
|
||||
print(manager.active_connections_dist)
|
||||
print('**********************************')
|
||||
print(params_str)
|
||||
print('**********************************')
|
||||
from app.yolov5.train_server import train_start
|
||||
params = TrainParams()
|
||||
params.read_from_str(params_str)
|
||||
print(params.get('device').value)
|
||||
data_list = file_tool.get_file(ori_path=params.get('DatasetDir').value, type_list=params.get('CLASS_NAMES').value)
|
||||
weights = params.get('resumeModPath').value # 初始化模型绝对路径
|
||||
img_size = params.get('img_size').value
|
||||
savemodel = os.path.splitext(params.get('saveModDir').value)[0] + '_' + str(img_size) + '.pt' # 模型命名加上图像参数
|
||||
epoches = params.get('epochnum').value
|
||||
batch_size = params.get('batch_size').value
|
||||
device = params.get('device').value
|
||||
#try:
|
||||
train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id, getsomething)
|
||||
print("train down!")
|
||||
# except Exception as e:
|
||||
# print(repr(e))
|
||||
# error_return(id=id,data=repr(e))
|
||||
|
||||
|
||||
# 启动验证程序
|
||||
|
||||
@start_test_algorithm()
|
||||
def validate_RODY(params_str, id):
|
||||
from app.yolov5.validate_server import validate_start
|
||||
params = TrainParams()
|
||||
params.read_from_str(params_str)
|
||||
weights = params.get('modPath').value # 验证模型绝对路径
|
||||
(filename, extension) = os.path.splitext(weights) # 文件名与后缀名分开
|
||||
img_size = int(filename.split('ROD')[1].split('_')[2]) # 获取图像参数
|
||||
# v_num = int(filename.split('ROD')[1].split('_')[1]) #获取版本号
|
||||
output = params.get('outputPath').value
|
||||
batch_size = params.get('batch_size').default
|
||||
device = params.get('device').value
|
||||
|
||||
validate_start(weights, img_size, batch_size, device, output, id)
|
||||
|
||||
|
||||
@start_detect_algorithm()
|
||||
def detect_RODY(params_str, id):
|
||||
from app.yolov5.detect_server import detect_start
|
||||
params = TrainParams()
|
||||
params.read_from_str(params_str)
|
||||
weights = params.get('modPath').value # 检测模型绝对路径
|
||||
input = params.get('inputPath').value
|
||||
outpath = params.get('outputPath').value
|
||||
# (filename, extension) = os.path.splitext(weights) # 文件名与后缀名分开
|
||||
# img_size = int(filename.split('ROD')[1].split('_')[2]) #获取图像参数
|
||||
# v_num = int(filename.split('ROD')[1].split('_')[1]) #获取版本号
|
||||
# batch_size = params.get('batch_size').default
|
||||
device = params.get('device').value
|
||||
|
||||
detect_start(input, weights, outpath, device, id)
|
||||
|
||||
|
||||
@start_download_pt()
|
||||
def Export_model_RODY(params_str):
|
||||
from app.yolov5.export import Start_Model_Export
|
||||
import zipfile
|
||||
params = TrainParams()
|
||||
params.read_from_str(params_str)
|
||||
exp_inputPath = params.get('exp_inputPath').value # 模型路径
|
||||
print('输入模型:', exp_inputPath)
|
||||
exp_device = params.get('device').value
|
||||
imgsz = params.get('imgsz').value
|
||||
modellist = Start_Model_Export(exp_inputPath, exp_device, imgsz)
|
||||
exp_outputPath = exp_inputPath.replace('pt', 'zip') # 压缩文件
|
||||
print('模型路径:',exp_outputPath)
|
||||
zipf = zipfile.ZipFile(exp_outputPath, 'w')
|
||||
for file in modellist:
|
||||
zipf.write(file, arcname=Path(file).name) # 将torchscript和onnx模型压缩
|
||||
|
||||
return exp_outputPath
|
||||
|
||||
@obtain_train_param()
|
||||
def returnTrainParams():
|
||||
nvmlInit()
|
||||
gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
|
||||
_kernel = [f"cuda:{a}" for a in range(gpuDeviceCount)]
|
||||
params_list = [
|
||||
{"index": 0, "name": "epochnum", "value": 10, "description": '训练轮次', "default": 100, "type": "I", 'show': True},
|
||||
{"index": 1, "name": "batch_size", "value": 4, "description": '批次图像数量', "default": 1, "type": "I",
|
||||
'show': True},
|
||||
{"index": 2, "name": "img_size", "value": 640, "description": '训练图像大小', "default": 640, "type": "I",
|
||||
'show': True},
|
||||
{"index": 3, "name": "device", "value": f'{_kernel[0]}', "description": '训练核心', "default": f'{_kernel[0]}', "type": "E",
|
||||
"items": _kernel, 'show': False}, # _kernel
|
||||
{"index": 4, "name": "saveModDir", "value": "E:/alg_demo-master/alg_demo/app/yolov5/best.pt",
|
||||
"description": '保存模型路径',
|
||||
"default": "./app/maskrcnn/saved_model/test.pt", "type": "S", 'show': False},
|
||||
{"index": 5, "name": "resumeModPath", "value": '/yolov5s.pt',
|
||||
"description": '继续训练路径', "default": '', "type": "S",
|
||||
'show': False},
|
||||
{"index": 6, "name": "resumeMod", "value": '', "description": '继续训练模型', "default": '', "type": "E", "items": '',
|
||||
'show': True},
|
||||
{"index": 7, "name": "CLASS_NAMES", "value": ['hole', '456'], "description": '类别名称', "default": '', "type": "L",
|
||||
"items": '',
|
||||
'show': False},
|
||||
{"index": 8, "name": "DatasetDir", "value": "E:/aicheck/data_set/11442136178662604800/ori",
|
||||
"description": '数据集路径',
|
||||
"default": "./app/maskrcnn/datasets/test", "type": "S", 'show': False} # ORI_PATH
|
||||
]
|
||||
# {"index": 9, "name": "saveEpoch", "value": 2, "description": '保存模型轮次', "default": 2, "type": "I", 'show': True}]
|
||||
params_str = json.dumps(params_list)
|
||||
return params_str
|
||||
|
||||
|
||||
@obtain_test_param()
|
||||
def returnValidateParams():
|
||||
# nvmlInit()
|
||||
# gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
|
||||
# _kernel = [f"cuda:{a}" for a in range(gpuDeviceCount)]
|
||||
params_list = [
|
||||
{"index": 0, "name": "modPath", "value": "E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt",
|
||||
"description": '验证模型路径', "default": "./app/maskrcnn/saved_model/test.pt", "type": "S", 'show': False},
|
||||
{"index": 1, "name": "batch_size", "value": 1, "description": '批次图像数量', "default": 1, "type": "I",
|
||||
'show': False},
|
||||
{"index": 2, "name": "img_size", "value": 640, "description": '训练图像大小', "default": 640, "type": "I",
|
||||
'show': False},
|
||||
{"index": 3, "name": "outputPath", "value": 'E:/aicheck/data_set/11442136178662604800/val_results/',
|
||||
"description": '输出结果路径',
|
||||
"default": './app/maskrcnn/datasets/M006B_waibi/res', "type": "S", 'show': False},
|
||||
{"index": 4, "name": "device", "value": "0", "description": '训练核心', "default": "cuda", "type": "S",
|
||||
"items": '', 'show': False} # _kernel
|
||||
]
|
||||
# {"index": 9, "name": "saveEpoch", "value": 2, "description": '保存模型轮次', "default": 2, "type": "I", 'show': True}]
|
||||
params_str = json.dumps(params_list)
|
||||
return params_str
|
||||
|
||||
|
||||
@obtain_detect_param()
|
||||
def returnDetectParams():
|
||||
# nvmlInit()
|
||||
# gpuDeviceCount = nvmlDeviceGetCount() # 获取Nvidia GPU块数
|
||||
# _kernel = [f"cuda:{a}" for a in range(gpuDeviceCount)]
|
||||
params_list = [
|
||||
{"index": 0, "name": "inputPath", "value": 'E:/aicheck/data_set/11442136178662604800/input/',
|
||||
"description": '输入图像路径', "default": './app/maskrcnn/datasets/M006B_waibi/JPEGImages', "type": "S",
|
||||
'show': False},
|
||||
{"index": 1, "name": "outputPath", "value": 'E:/aicheck/data_set/11442136178662604800/val_results/',
|
||||
"description": '输出结果路径',
|
||||
"default": './app/maskrcnn/datasets/M006B_waibi/res', "type": "S", 'show': False},
|
||||
{"index": 2, "name": "modPath", "value": "E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt",
|
||||
"description": '模型路径', "default": "./app/maskrcnn/saved_model/test.pt", "type": "S", 'show': False},
|
||||
{"index": 3, "name": "device", "value": "0", "description": '推理核', "default": "cpu", "type": "S",
|
||||
'show': False},
|
||||
]
|
||||
# {"index": 9, "name": "saveEpoch", "value": 2, "description": '保存模型轮次', "default": 2, "type": "I", 'show': True}]
|
||||
params_str = json.dumps(params_list)
|
||||
return params_str
|
||||
|
||||
|
||||
@obtain_download_pt_param()
|
||||
def returnDownloadParams():
|
||||
params_list = [
|
||||
{"index": 0, "name": "exp_inputPath", "value": 'E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt',
|
||||
"description": '转化模型输入路径',
|
||||
"default": 'E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt/',
|
||||
"type": "S", 'show': False},
|
||||
{"index": 1, "name": "device", "value": 'gpu', "description": 'CPU或GPU', "default": 'gpu', "type": "S",
|
||||
'show': False},
|
||||
{"index": 2, "name": "imgsz", "value": 640, "description": '图像大小', "default": 640, "type": "I",
|
||||
'show': True}
|
||||
]
|
||||
params_str = json.dumps(params_list)
|
||||
return params_str
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
par = returnTrainParams()
|
||||
print(par)
|
||||
id='1'
|
||||
train_R0DY(par,id)
|
||||
|
@ -1,33 +1,33 @@
|
||||
import logging
|
||||
|
||||
from flask import Blueprint, app
|
||||
|
||||
from app.exts import redisClient
|
||||
from app.utils.StandardizedOutput import output_wrapped
|
||||
|
||||
bp = Blueprint('WebStatus', __name__)
|
||||
|
||||
|
||||
@bp.route('/ping', methods=['GET'])
|
||||
def ping():
|
||||
""" For health check.
|
||||
"""
|
||||
res = output_wrapped(0, 'pong', '')
|
||||
return res
|
||||
|
||||
|
||||
@bp.route('/redis/set', methods=['post'])
|
||||
def redis_set():
|
||||
redisClient.set('foo', 'bar', ex=60*60*6)
|
||||
res = output_wrapped(0, 'set foo', '')
|
||||
return res
|
||||
|
||||
|
||||
@bp.route('/redis/get', methods=['get'])
|
||||
def redis_get():
|
||||
""" For health check.
|
||||
"""
|
||||
the_food = redisClient.get('foo')
|
||||
if not the_food:
|
||||
return output_wrapped(5006, 'foo', "")
|
||||
return output_wrapped(0, 'foo', the_food.decode("utf-8"))
|
||||
import logging
|
||||
|
||||
from flask import Blueprint, app
|
||||
|
||||
from app.exts import redisClient
|
||||
from app.utils.StandardizedOutput import output_wrapped
|
||||
|
||||
bp = Blueprint('WebStatus', __name__)
|
||||
|
||||
|
||||
@bp.route('/ping', methods=['GET'])
|
||||
def ping():
|
||||
""" For health check.
|
||||
"""
|
||||
res = output_wrapped(0, 'pong', '')
|
||||
return res
|
||||
|
||||
|
||||
@bp.route('/redis/set', methods=['post'])
|
||||
def redis_set():
|
||||
redisClient.set('foo', 'bar', ex=60*60*6)
|
||||
res = output_wrapped(0, 'set foo', '')
|
||||
return res
|
||||
|
||||
|
||||
@bp.route('/redis/get', methods=['get'])
|
||||
def redis_get():
|
||||
""" For health check.
|
||||
"""
|
||||
the_food = redisClient.get('foo')
|
||||
if not the_food:
|
||||
return output_wrapped(5006, 'foo', "")
|
||||
return output_wrapped(0, 'foo', the_food.decode("utf-8"))
|
||||
|
@ -1,4 +1,4 @@
|
||||
from app.core.common_utils import import_subs
|
||||
|
||||
|
||||
__all__ = import_subs(locals(), modules_only=True)
|
||||
from app.core.common_utils import import_subs
|
||||
|
||||
|
||||
__all__ = import_subs(locals(), modules_only=True)
|
||||
|
Binary file not shown.
Reference in New Issue
Block a user