RODY/app/controller/AlgorithmController.py
2022-11-29 16:14:09 +08:00

487 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
@Time 2022/9/20 16:17
@Auth
@File AlgorithmController.py
@IDE PyCharm
@MottoABC(Always Be Coding)
@Desc算法接口
"""
import json
from functools import wraps
from threading import Thread
from multiprocessing import Process
from time import sleep
from flask import Blueprint, request
from app.schemas.TrainResult import Report, ProcessValueList
from app.utils.RedisMQTool import Task
from app.utils.StandardizedOutput import output_wrapped
from app.utils.redis_config import redis_client
from app.utils.websocket_tool import manager
from app.configs.global_var import set_value
import sys
from pathlib import Path
from pynvml import *
# FILE = Path(__file__).resolve()
# ROOT = FILE.parents[0] # YOLOv5 root directory
# if str(ROOT) not in sys.path:
# sys.path.append(str(ROOT)) # add ROOT to PATH
# sys.path.append("/mnt/sdc/algorithm/AICheck-MaskRCNN/app/maskrcnn_ppx")
# import ppx as pdx
# Flask blueprint on which every route of this module is registered.
bp = Blueprint(name='AlgorithmController', import_name=__name__)
# Module-level dict; it is not read or written anywhere else in this file.
# NOTE(review): looks unused — confirm no other module imports it before removing.
ifKillDict = dict()
def start_train_algorithm():
    """
    Decorator factory exposing a training function as ``GET /start_train_algorithm``.

    Query-string arguments:
        param: serialized training parameters, forwarded to the function as-is.
        id: task id; must be a key of ``manager.active_connections_dist``.

    The decorated ``func(param, id, connection)`` is launched in a separate
    ``multiprocessing.Process`` named after ``id``, and the route returns at
    once. If ``id`` is missing or has no registered connection, an error
    payload (``code 2`` / ``'fail'``, the error convention used elsewhere in
    this module) is returned instead of failing with ``KeyError`` (HTTP 500).
    """
    def wrapTheFunction(func):
        # @bp.route runs before @wraps, so Flask registers this endpoint as
        # 'wrapped_function'; keep that inner name unchanged.
        @wraps(func)
        @bp.route('/start_train_algorithm', methods=['get'])
        def wrapped_function():
            param = request.args.get('param')
            task_id = request.args.get('id')
            connections = manager.active_connections_dist
            if task_id not in connections:
                return output_wrapped(2, 'fail', f'no active connection for id {task_id}')
            worker = Process(target=func, args=(param, task_id, connections[task_id]), name=task_id)
            # NOTE(review): set_value updates this (parent) process's globals; the child
            # process only sees state copied at start-up (or none under 'spawn'), so later
            # changes made through /change_ifKillDIct may not reach it — confirm.
            set_value(key=task_id, value=False)
            worker.start()
            return output_wrapped(0, 'success', '成功')
        return wrapped_function
    return wrapTheFunction
def start_test_algorithm():
    """
    Decorator factory exposing a validation function as ``GET /start_test_algorithm``.

    The decorated ``func(param, id)`` is started on a background thread with
    the ``param`` and ``id`` query-string values; the route answers
    immediately with a success payload.
    """
    def register(func):
        # Endpoint name is taken from the inner function before @wraps renames it.
        @wraps(func)
        @bp.route('/start_test_algorithm', methods=['get'])
        def wrapped_function_test():
            query = request.args
            worker = Thread(target=func, args=(query.get('param'), query.get('id')))
            worker.start()
            return output_wrapped(0, 'success', '成功')
        return wrapped_function_test
    return register
def start_detect_algorithm():
    """
    Decorator factory exposing a detection function as ``GET /start_detect_algorithm``.

    The decorated ``func(param, id)`` runs on a background thread fed with
    the ``param`` and ``id`` query-string values; the route replies at once.
    """
    def register(func):
        # Endpoint name is taken from the inner function before @wraps renames it.
        @wraps(func)
        @bp.route('/start_detect_algorithm', methods=['get'])
        def wrapped_function_detect():
            query = request.args
            worker = Thread(target=func, args=(query.get('param'), query.get('id')))
            worker.start()
            return output_wrapped(0, 'success', '成功')
        return wrapped_function_detect
    return register
def start_download_pt():
    """
    Decorator factory exposing a model-export function as ``GET /start_download_pt``.

    Unlike the train/test/detect routes, the decorated ``func(param)`` runs
    synchronously inside the request; its return value is sent back wrapped
    by ``output_wrapped(0, 'success', ...)``.
    """
    def register(func):
        # Endpoint name is taken from the inner function before @wraps renames it.
        @wraps(func)
        @bp.route('/start_download_pt', methods=['get'])
        def wrapped_function_start_download_pt():
            result = func(request.args.get('param'))
            return output_wrapped(0, 'success', result)
        return wrapped_function_start_download_pt
    return register
def algorithm_process_value():
    """
    Decorator factory: publish every result of the wrapped function through Redis.

    The wrapped function's return value is printed, published on the Redis
    channel ``"ceshi"`` as ``{'code': 0, 'msg': 'success', 'data': <result>}``,
    and finally returned wrapped by ``output_wrapped(0, 'success', ...)``.
    """
    def register(func):
        @wraps(func)
        def publish_result(*args, **kwargs):
            result = func(*args, **kwargs)
            print(result)
            publisher = Task(redis_conn=redis_client.get_redis(), channel="ceshi")
            publisher.publish_task(data={'code': 0, 'msg': 'success', 'data': result})
            return output_wrapped(0, 'success', result)
        return publish_result
    return register
def algorithm_process_value_websocket():
    """
    Decorator factory: push every result of the wrapped function over the websocket.

    The wrapped function must return a mapping with an ``"id"`` key; that id
    selects the connection that receives
    ``{'code': 0, 'type': 'connected', 'msg': 'success', 'data': <result>}``.
    The result itself is returned unchanged.
    """
    def register(func):
        @wraps(func)
        def push_progress(*args, **kwargs):
            result = func(*args, **kwargs)
            target = result["id"]
            manager.send_message_proj_json(
                message={'code': 0, "type": 'connected', 'msg': 'success', 'data': result},
                id=target,
            )
            return result
        return push_progress
    return register
def algorithm_kill_value_websocket():
    """
    Decorator factory: announce a kill result of the wrapped function over the websocket.

    The wrapped function must return a mapping with an ``"id"`` key; that id
    selects the connection that receives
    ``{'code': 1, 'type': 'kill', 'msg': 'success', 'data': <result>}``.
    The result itself is returned unchanged.
    """
    def register(func):
        @wraps(func)
        def push_kill(*args, **kwargs):
            result = func(*args, **kwargs)
            target = result["id"]
            manager.send_message_proj_json(
                message={'code': 1, "type": 'kill', 'msg': 'success', 'data': result},
                id=target,
            )
            return result
        return push_kill
    return register
def algorithm_error_value_websocket():
    """
    Decorator factory: announce an error result of the wrapped function over the websocket.

    The wrapped function must return a mapping with an ``"id"`` key; that id
    selects the connection that receives
    ``{'code': 2, 'type': 'error', 'msg': 'fail', 'data': <result>}``.
    The result itself is returned unchanged.
    """
    def register(func):
        @wraps(func)
        def push_error(*args, **kwargs):
            result = func(*args, **kwargs)
            target = result["id"]
            manager.send_message_proj_json(
                message={'code': 2, "type": 'error', 'msg': 'fail', 'data': result},
                id=target,
            )
            return result
        return push_error
    return register
def obtain_train_param():
    """
    Decorator factory exposing the training-parameter description as ``GET /obtain_train_param``.

    The decorated function's return value is sent back wrapped by
    ``output_wrapped(0, 'success', ...)``.
    """
    def register(func):
        # Endpoint name is taken from the inner function before @wraps renames it.
        @wraps(func)
        @bp.route('/obtain_train_param', methods=['get'])
        def wrapped_function_train_param(*args, **kwargs):
            return output_wrapped(0, 'success', func(*args, **kwargs))
        return wrapped_function_train_param
    return register
def obtain_test_param():
    """
    Decorator factory exposing the validation-parameter description as ``GET /obtain_test_param``.

    The decorated function's return value is sent back wrapped by
    ``output_wrapped(0, 'success', ...)``.
    """
    def register(func):
        # Endpoint name is taken from the inner function before @wraps renames it.
        @wraps(func)
        @bp.route('/obtain_test_param', methods=['get'])
        def wrapped_function_test_param(*args, **kwargs):
            return output_wrapped(0, 'success', func(*args, **kwargs))
        return wrapped_function_test_param
    return register
def obtain_detect_param():
    """
    Decorator factory exposing the detection-parameter description as ``GET /obtain_detect_param``.

    The decorated function's return value is sent back wrapped by
    ``output_wrapped(0, 'success', ...)``.
    """
    def register(func):
        # Endpoint name is taken from the inner function before @wraps renames it.
        @wraps(func)
        @bp.route('/obtain_detect_param', methods=['get'])
        def wrapped_function_inf_param(*args, **kwargs):
            return output_wrapped(0, 'success', func(*args, **kwargs))
        return wrapped_function_inf_param
    return register
def obtain_download_pt_param():
    """
    Decorator factory exposing the model-export parameter description as
    ``GET /obtain_download_pt_param``.

    The decorated function's return value is sent back wrapped by
    ``output_wrapped(0, 'success', ...)``.
    """
    def register(func):
        # Endpoint name is taken from the inner function before @wraps renames it.
        @wraps(func)
        @bp.route('/obtain_download_pt_param', methods=['get'])
        def wrapped_function_obtain_download_pt_param(*args, **kwargs):
            return output_wrapped(0, 'success', func(*args, **kwargs))
        return wrapped_function_obtain_download_pt_param
    return register
@bp.route('/change_ifKillDIct', methods=['get'])
def change_ifKillDIct():
    """
    Update a global flag through ``set_value``.

    Query-string arguments:
        id: key handed to ``set_value``.
        type: value handed to ``set_value``; stored as the raw query string
            (or None when absent), not converted to bool.
    """
    query = request.args
    task_id, flag = query.get('id'), query.get('type')
    set_value(task_id, flag)
    return output_wrapped(0, 'success')
# @start_train_algorithm()
# def start(param: str):
# """
# 例子
# """
# print(param)
# process_value_list = ProcessValueList(name='1', value=[])
# report = Report(rate_of_progess=0, process_value=[process_value_list], id='1')
#
# @algorithm_process_value_websocket()
# def process(v: int):
# print(v)
# report.rate_of_progess = ((v + 1) / 10) * 100
# report.precision[0].value.append(v)
# return report.dict()
#
# for i in range(10):
# process(i)
# return report.dict()
from setparams import TrainParams
import os
from app.schemas.TrainResult import DetectProcessValueDice, DetectReport
from app import file_tool
def error_return(id: str, data):
    """
    Report an algorithm failure to the websocket client registered under ``id``.

    Sends ``{'code': 2, 'type': 'error', 'msg': 'fail', 'data': data}`` via
    ``manager.send_message_proj_json`` and returns None.
    """
    manager.send_message_proj_json(
        message={'code': 2, "type": 'error', 'msg': 'fail', 'data': data},
        id=id,
    )
# Start training
@start_train_algorithm()
def train_R0DY(params_str, id, getsomething):
    """
    Train a YOLOv5 model for task ``id``.

    Launched in a child process by the ``/start_train_algorithm`` route.

    Args:
        params_str: serialized parameter list parsed by ``TrainParams.read_from_str``
            (presumably the list produced by ``returnTrainParams``, edited by the
            client — confirm).
        id: task id, forwarded to ``train_start``.
        getsomething: the object the route found in
            ``manager.active_connections_dist[id]``; it is stored back under ``id``
            and forwarded to ``train_start``.
    """
    divider = '**********************************'
    print(divider)
    print(getsomething)
    print(divider)
    # Re-register the connection in this process's manager, presumably because the
    # child process does not share the parent's manager state.
    manager.active_connections_dist[id] = getsomething
    for shown in (manager.active_connections_dist, params_str):
        print(divider)
        print(shown)
    print(divider)

    from app.yolov5.train_server import train_start

    params = TrainParams()
    params.read_from_str(params_str)

    def setting(name):
        """Return the configured ``value`` of parameter ``name``."""
        return params.get(name).value

    print(setting('device'))
    data_list = file_tool.get_file(ori_path=setting('DatasetDir'), type_list=setting('CLASS_NAMES'))
    weights = setting('resumeModPath')  # absolute path of the initial model
    img_size = setting('img_size')
    # The image size is appended to the saved model's name: <saveModDir stem>_<img_size>.pt
    save_stem = os.path.splitext(setting('saveModDir'))[0]
    savemodel = f'{save_stem}_{img_size}.pt'
    epoches, batch_size, device = (setting(key) for key in ('epochnum', 'batch_size', 'device'))
    # NOTE(review): the try/except that reported failures through error_return is
    # disabled, so a training error only shows up on the child process's stderr.
    train_start(weights, savemodel, epoches, img_size, batch_size, device, data_list, id, getsomething)
    print("train down!")
# Start the validation routine
@start_test_algorithm()
def validate_RODY(params_str, id):
    """
    Run YOLOv5 validation for task ``id``.

    Started on a background thread by the ``/start_test_algorithm`` route.
    The image size is parsed from the model *file name*, which is expected to
    look like ``<name>_RODY_<version>_<img_size>.pt`` (e.g.
    ``圆孔_123_RODY_1_640.pt``).

    Args:
        params_str: serialized parameter list parsed by ``TrainParams.read_from_str``.
        id: task id, forwarded to ``validate_start``.
    """
    from app.yolov5.validate_server import validate_start
    params = TrainParams()
    params.read_from_str(params_str)
    weights = params.get('modPath').value  # absolute path of the model to validate
    # Parse the base name only: splitting the whole path on 'ROD' broke as soon
    # as a directory contained 'ROD' (e.g. a checkout under '.../RODY/...').
    filename, _extension = os.path.splitext(os.path.basename(weights))
    img_size = int(filename.split('ROD')[1].split('_')[2])  # image size from the file name
    output = params.get('outputPath').value
    # NOTE(review): reads the declared default rather than the configured value,
    # unlike every other parameter here — confirm this is intentional.
    batch_size = params.get('batch_size').default
    device = params.get('device').value
    validate_start(weights, img_size, batch_size, device, output, id)
@start_detect_algorithm()
def detect_RODY(params_str, id):
    """
    Run YOLOv5 detection for task ``id``.

    Started on a background thread by the ``/start_detect_algorithm`` route.
    Reads ``modPath``, ``inputPath``, ``outputPath`` and ``device`` from the
    serialized parameters and forwards them to ``detect_start``.
    """
    from app.yolov5.detect_server import detect_start
    params = TrainParams()
    params.read_from_str(params_str)
    model_path, input_dir, output_dir = (
        params.get(key).value for key in ('modPath', 'inputPath', 'outputPath')
    )
    device = params.get('device').value
    detect_start(input_dir, model_path, output_dir, device, id)
@start_download_pt()
def Export_model_RODY(params_str):
    """
    Export a trained model and bundle the exported files into a zip archive.

    ``Start_Model_Export`` returns the list of exported file paths; each one
    is stored in the archive under its bare file name.

    Args:
        params_str: serialized parameters providing ``exp_inputPath`` (model
            path), ``device`` and ``imgsz``.

    Returns:
        Path of the created archive: the input path with its extension
        replaced by ``.zip``.
    """
    from app.yolov5.export import Start_Model_Export
    import zipfile
    params = TrainParams()
    params.read_from_str(params_str)
    exp_inputPath = params.get('exp_inputPath').value  # model path
    print('输入模型:', exp_inputPath)
    exp_device = params.get('device').value
    imgsz = params.get('imgsz').value
    modellist = Start_Model_Export(exp_inputPath, exp_device, imgsz)
    # Swap only the extension. str.replace('pt', 'zip') also rewrote every other
    # 'pt' in the path (e.g. '/opt/' -> '/ozip/'), and when no 'pt' was present the
    # archive path equalled the input path, overwriting the model itself.
    exp_outputPath = os.path.splitext(exp_inputPath)[0] + '.zip'
    print('模型路径:', exp_outputPath)
    # The context manager writes the central directory and closes the file even
    # if one of the writes fails.
    with zipfile.ZipFile(exp_outputPath, 'w') as zipf:
        for file in modellist:
            zipf.write(file, arcname=Path(file).name)  # TorchScript / ONNX exports
    return exp_outputPath
@obtain_train_param()
def returnTrainParams():
    """
    Describe the training parameters as a JSON string.

    Each entry has ``index``, ``name``, ``value``, ``description``,
    ``default``, ``type`` and ``show`` keys (plus ``items`` for choice
    parameters). The ``device`` choices are ``cuda:0`` ... ``cuda:N-1`` for
    the NVIDIA GPUs NVML reports. When NVML cannot be initialised (no driver
    or library) or reports no GPU, the only choice is ``'cpu'``.
    NOTE(review): assumes train_start accepts 'cpu' as a device — confirm.
    """
    try:
        nvmlInit()
    except NVMLError:
        gpuDeviceCount = 0  # no NVIDIA driver / NVML library on this host
    else:
        try:
            gpuDeviceCount = nvmlDeviceGetCount()  # number of NVIDIA GPUs
        finally:
            nvmlShutdown()  # balance nvmlInit (NVML keeps an init reference count)
    _kernel = [f"cuda:{a}" for a in range(gpuDeviceCount)] or ['cpu']
    params_list = [
        {"index": 0, "name": "epochnum", "value": 10, "description": '训练轮次', "default": 100, "type": "I", 'show': True},
        {"index": 1, "name": "batch_size", "value": 4, "description": '批次图像数量', "default": 1, "type": "I",
         'show': True},
        {"index": 2, "name": "img_size", "value": 640, "description": '训练图像大小', "default": 640, "type": "I",
         'show': True},
        {"index": 3, "name": "device", "value": f'{_kernel[0]}', "description": '训练核心', "default": f'{_kernel[0]}', "type": "E",
         "items": _kernel, 'show': False},
        {"index": 4, "name": "saveModDir", "value": "E:/alg_demo-master/alg_demo/app/yolov5/best.pt",
         "description": '保存模型路径',
         "default": "./app/maskrcnn/saved_model/test.pt", "type": "S", 'show': False},
        {"index": 5, "name": "resumeModPath", "value": '/yolov5s.pt',
         "description": '继续训练路径', "default": '', "type": "S",
         'show': False},
        {"index": 6, "name": "resumeMod", "value": '', "description": '继续训练模型', "default": '', "type": "E", "items": '',
         'show': True},
        {"index": 7, "name": "CLASS_NAMES", "value": ['hole', '456'], "description": '类别名称', "default": '', "type": "L",
         "items": '',
         'show': False},
        {"index": 8, "name": "DatasetDir", "value": "E:/aicheck/data_set/11442136178662604800/ori",
         "description": '数据集路径',
         "default": "./app/maskrcnn/datasets/test", "type": "S", 'show': False},  # ORI_PATH
    ]
    params_str = json.dumps(params_list)
    return params_str
@obtain_test_param()
def returnValidateParams():
    """
    Describe the validation parameters as a JSON string.

    Each entry is emitted as ``{"index": <position>, ...}`` followed by its
    ``name``, ``value``, ``description``, ``default``, ``type`` (and
    ``items`` for ``device``) and ``show`` keys, in that order.
    """
    entries = [
        {"name": "modPath",
         "value": "E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt",
         "description": '验证模型路径',
         "default": "./app/maskrcnn/saved_model/test.pt",
         "type": "S", 'show': False},
        {"name": "batch_size", "value": 1, "description": '批次图像数量',
         "default": 1, "type": "I", 'show': False},
        {"name": "img_size", "value": 640, "description": '训练图像大小',
         "default": 640, "type": "I", 'show': False},
        {"name": "outputPath",
         "value": 'E:/aicheck/data_set/11442136178662604800/val_results/',
         "description": '输出结果路径',
         "default": './app/maskrcnn/datasets/M006B_waibi/res',
         "type": "S", 'show': False},
        {"name": "device", "value": "0", "description": '训练核心',
         "default": "cuda", "type": "S", "items": '', 'show': False},
    ]
    # "index" is placed first so the key order matches the other parameter lists.
    return json.dumps([{"index": position, **entry} for position, entry in enumerate(entries)])
@obtain_detect_param()
def returnDetectParams():
    """
    Describe the detection parameters as a JSON string.

    Each entry is emitted as ``{"index": <position>, ...}`` followed by its
    ``name``, ``value``, ``description``, ``default``, ``type`` and ``show``
    keys, in that order.
    """
    entries = [
        {"name": "inputPath",
         "value": 'E:/aicheck/data_set/11442136178662604800/input/',
         "description": '输入图像路径',
         "default": './app/maskrcnn/datasets/M006B_waibi/JPEGImages',
         "type": "S", 'show': False},
        {"name": "outputPath",
         "value": 'E:/aicheck/data_set/11442136178662604800/val_results/',
         "description": '输出结果路径',
         "default": './app/maskrcnn/datasets/M006B_waibi/res',
         "type": "S", 'show': False},
        {"name": "modPath",
         "value": "E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt",
         "description": '模型路径',
         "default": "./app/maskrcnn/saved_model/test.pt",
         "type": "S", 'show': False},
        {"name": "device", "value": "0", "description": '推理核',
         "default": "cpu", "type": "S", 'show': False},
    ]
    # "index" is placed first so the key order matches the other parameter lists.
    return json.dumps([{"index": position, **entry} for position, entry in enumerate(entries)])
@obtain_download_pt_param()
def returnDownloadParams():
    """
    Describe the model-export parameters as a JSON string.

    Each entry is emitted as ``{"index": <position>, ...}`` followed by its
    ``name``, ``value``, ``description``, ``default``, ``type`` and ``show``
    keys, in that order.
    NOTE(review): the exp_inputPath default ends with '/', unlike its value —
    looks like a typo; confirm before changing the payload.
    """
    entries = [
        {"name": "exp_inputPath",
         "value": 'E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt',
         "description": '转化模型输入路径',
         "default": 'E:/alg_demo-master/alg_demo/app/yolov5/圆孔_123_RODY_1_640.pt/',
         "type": "S", 'show': False},
        {"name": "device", "value": 'gpu', "description": 'CPU或GPU',
         "default": 'gpu', "type": "S", 'show': False},
        {"name": "imgsz", "value": 640, "description": '图像大小',
         "default": 640, "type": "I", 'show': True},
    ]
    # "index" is placed first so the key order matches the other parameter lists.
    return json.dumps([{"index": position, **entry} for position, entry in enumerate(entries)])
if __name__ == '__main__':
    # Manual smoke run. Both functions below are wrapped by route decorators:
    # the train wrapper takes no arguments (calling it with (par, id) raised
    # TypeError) and the params wrapper returns an output_wrapped(...) payload
    # rather than the JSON string. functools.wraps exposes the undecorated
    # originals as __wrapped__, so call those directly.
    par = returnTrainParams.__wrapped__()
    print(par)
    task_id = '1'
    # No websocket connection exists when run as a script, so None stands in for it.
    # NOTE(review): confirm train_start tolerates a None connection.
    train_R0DY.__wrapped__(par, task_id, None)