358 lines
14 KiB
Python
358 lines
14 KiB
Python
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
|
||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
|
||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
|
||
from app.model.schemas.project_train_schemas import ProjectTrainIn
|
||
from app.model.crud import project_info_crud as pic
|
||
from app.model.crud import project_image_crud as pimc
|
||
from app.model.crud import project_label_crud as plc
|
||
from app.model.crud import project_train_crud as ptc
|
||
from app.model.crud import project_img_leafer_label_crud as pillc
|
||
from app.util import os_utils as os
|
||
from app.util import random_utils as ru
|
||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||
from app.websocket.web_socket_server import room_manager
|
||
from app.util.csv_utils import read_csv
|
||
from app.common.redis_cli import redis_conn
|
||
|
||
from sqlalchemy.orm import Session
|
||
from typing import List
|
||
from fastapi import UploadFile
|
||
import yaml
|
||
import subprocess
|
||
import torch
|
||
|
||
|
||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
    """
    Create a new project: fill in the server-side fields, create the project's
    folders under the dataset root and the training-runs root, then persist it.

    :param info: project data submitted by the client
    :param session: database session
    :param user_id: id of the user who owns the project
    :return: id of the newly stored project
    """
    new_project = ProjectInfo(**info.dict())
    new_project.user_id = user_id
    new_project.project_no = ru.random_str(6)
    new_project.project_status = "0"
    new_project.train_version = 0

    # one folder named after the project number under each root
    for root_dir in (datasets_url, runs_url):
        os.create_folder(root_dir, new_project.project_no)

    stored = pic.add_project(new_project, session)
    return stored.id
|
||
|
||
|
||
def check_image_name(project_id: int, img_type: str, files: List[UploadFile], session: Session):
    """
    Validate the names of uploaded files with pimc.check_img_name
    (presumably a duplicate-name check within the project/category; confirm).

    Files are checked in order, and checking stops at the first rejected file.

    :param project_id: project id
    :param img_type: image category the files are uploaded into
    :param files: uploaded files
    :param session: database session
    :return: (True, None) when every file passes, otherwise (False, filename)
             of the first file that fails
    """
    rejected = (
        upload for upload in files
        if not pimc.check_img_name(project_id, img_type, upload.filename, session)
    )
    first_rejected = next(rejected, None)
    if first_rejected is None:
        return True, None
    return False, first_rejected.filename
|
||
|
||
|
||
def upload_project_image(project_info: ProjectInfoOut, img_type: str, files: List[UploadFile], session: Session):
    """
    Store uploaded images for a project.

    For each file the original image is saved via os.save_images, and a thumbnail
    is generated at a path built from images_url, 'thumb', the project number and
    a random '.jpg' name. A ProjectImage row is prepared for each file, and all
    rows are then inserted in one batch.

    :param project_info: project the images belong to
    :param img_type: image category (e.g. 'train' / 'val')
    :param files: uploaded image files
    :param session: database session
    :return: None
    """
    def build_record(upload: UploadFile) -> ProjectImage:
        record = ProjectImage()
        record.project_id = project_info.id
        record.file_name = upload.filename
        record.img_type = img_type

        # keep the original image
        original_path = os.save_images(images_url, project_info.project_no, file=upload)
        record.image_url = original_path

        # and a thumbnail for list views
        thumb_path = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg")
        os.create_thumbnail(original_path, thumb_path)
        record.thumb_image_url = thumb_path
        return record

    pimc.add_image_batch([build_record(upload) for upload in files], session)
|
||
|
||
|
||
def del_img(image_id: int, session: Session):
    """
    Delete an image record and its files on disk.

    The database row is deleted and committed first. The original image and its
    thumbnail are removed only afterwards, so a failed commit can no longer leave
    a row that points at files which were already deleted. The worst case is now
    an orphaned file on disk.

    :param image_id: id of the ProjectImage to delete
    :param session: database session
    :return: 0 when no image with that id exists, otherwise None
    """
    image = session.query(ProjectImage).filter_by(id=image_id).first()
    if image is None:
        return 0

    # Capture the paths before the commit: attributes of a deleted, committed
    # instance may be expired/detached and no longer loadable.
    image_url = image.image_url
    thumb_image_url = image.thumb_image_url

    session.delete(image)
    session.commit()

    os.delete_file_if_exists(image_url, thumb_image_url)
|
||
|
||
|
||
|
||
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
    """
    Save an image's annotation data: the leafer payload and its label boxes.

    By design, each save replaces everything previously stored for the image.
    The delete-and-reinsert is done inside the pillc helpers and is not visible here.

    :param img_leafer_label: leafer payload plus label box list for one image
    :param session: database session
    :return: None
    """
    image_id = img_leafer_label.image_id

    leafer_record = ProjectImgLeafer()
    leafer_record.image_id = image_id
    leafer_record.leafer = img_leafer_label.leafer
    pillc.save_img_leafer(leafer_record, session)

    def to_label(label_info):
        label = ProjectImgLabel(**label_info.dict())
        label.image_id = image_id
        return label

    pillc.save_img_label_batch(image_id, [to_label(item) for item in img_leafer_label.label_infos], session)
|
||
|
||
|
||
def get_img_leafer(image_id: int, session: Session):
    """
    Look up the leafer data stored for an image.

    :param image_id: image id
    :param session: database session
    :return: the record serialized through ProjectImgLeaferOut as a dict,
             or None when nothing is stored for the image
    """
    record = pillc.get_img_leafer(image_id, session)
    return None if record is None else ProjectImgLeaferOut.from_orm(record).dict()
|
||
|
||
|
||
def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, session: Session):
    """
    Prepare a YOLOv5 dataset for the project's next training version and switch
    the project's training status to '1'.

    Creates a version folder v<train_version + 1> under datasets_url/<project_no>
    containing images/{train,val}, labels/{train,val} and a <project_no>.yaml
    dataset file. It then copies the project's train/val images there, together
    with one label txt file per image. The training process itself is not
    started here.

    :param project_info: project information
    :param train_in: training parameters (not used by this function)
    :param session: database session
    :return: (data, project, name): the dataset yaml path, the project's runs
             directory and the version name, i.e. the arguments run_commend expects
    """
    project_id = project_info.id
    project_no = project_info.project_no

    # fetch both image sets up front
    train_images = pimc.get_images(project_id, 'train', session)
    val_images = pimc.get_images(project_id, 'val', session)

    version_name = f"v{project_info.train_version + 1}"
    version_root = os.create_folder(datasets_url, project_no, version_name)

    # parallel lists: label_ids[i] is the id of the label named label_names[i]
    label_ids, label_names = plc.get_label_for_train(project_id, session)

    image_train_dir = os.create_folder(version_root, 'images', 'train')
    image_val_dir = os.create_folder(version_root, 'images', 'val')
    label_train_dir = os.create_folder(version_root, 'labels', 'train')
    label_val_dir = os.create_folder(version_root, 'labels', 'val')

    # dataset description; a class's index is its position in label_names
    dataset_yaml = os.file_path(version_root, project_no + '.yaml')
    dataset_config = {
        'path': version_root,
        'train': 'images/train',
        'val': 'images/val',
        'test': None,
        'names': dict(enumerate(label_names)),
    }
    with open(dataset_yaml, 'w', encoding='utf-8') as fh:
        yaml.dump(dataset_config, fh, allow_unicode=True, default_flow_style=False)

    # copy images and write label files: train first, then val
    for images, image_dir, label_dir in ((train_images, image_train_dir, label_train_dir),
                                         (val_images, image_val_dir, label_val_dir)):
        operate_img_label(images, image_dir, label_dir, session, label_ids)

    # dataset is ready: update the project's training status before training is launched
    pic.update_project_status(project_id, '1', session)

    return dataset_yaml, os.file_path(runs_url, project_no), version_name
|
||
|
||
|
||
async def run_commend(data: str, project: str, name: str, epochs: int, patience: int, weights: str, project_id: int, session: Session):
    """
    Run a YOLOv5 training job as a subprocess and stream its output to the
    websocket room 'train_<project_id>'.

    The child's combined stdout/stderr is read line by line on a worker thread,
    so the event loop is not blocked during training. Output is read until EOF,
    so lines still buffered in the pipe when the child exits are forwarded too.

    Outcomes:
    - Success: sends 'success', sets the project status to '2' and stores a
      ProjectTrain record.
    - Non-zero exit code: sets the project status to '-1'.
    - Launch failure (OSError): also sets the status to '-1', then re-raises.

    :param data: path of the dataset yaml file
    :param project: directory for training results (yolov5 --project)
    :param name: experiment name (yolov5 --name), stored as the train version
    :param epochs: number of training epochs
    :param patience: early-stopping patience
    :param weights: id of an earlier ProjectTrain whose best.pt is used as the
                    starting weights; None or '' to train without it
    :param project_id: project id
    :param session: database session
    :return: None
    :raises OSError: when the training process cannot be started
    """
    import asyncio
    import os as std_os  # the module-level name "os" is app.util.os_utils

    yolo_path = os.file_path(yolo_url, 'train.py')
    room = 'train_' + str(project_id)
    await room_manager.send_to_room(room, f"AiCheck: 模型训练开始,请稍等。。。\n")

    commend = ["python", '-u', yolo_path, "--data=" + data, "--project=" + project, "--name=" + name,
               "--epochs=" + str(epochs), "--batch-size=8", "--exist-ok", "--patience=" + str(patience)]

    # continue from the best weights of an earlier training run, when one was selected
    train_info = None
    if weights is not None and weights != '':
        train_info = ptc.get_train(int(weights), session)
        if train_info is not None:
            commend.append("--weights=" + train_info.best_pt)

    # Use CUDA device 0 when the 'is_gpu' flag is 'True'. Accept bytes too:
    # a redis client created without decode_responses=True returns b'True'.
    if redis_conn.get('is_gpu') in ('True', b'True'):
        commend.append("--device=0")

    # make the child emit UTF-8 whatever the platform's locale encoding is
    child_env = dict(std_os.environ, PYTHONIOENCODING='utf-8')
    loop = asyncio.get_running_loop()

    try:
        process = subprocess.Popen(
            commend,
            bufsize=1,
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # also forward what yolov5 writes to stderr (progress bars etc.)
            text=True,
            # yolov5 prints UTF-8; decoding it as latin1 produced mojibake.
            # errors='replace' keeps decoding from ever raising.
            encoding='utf-8',
            errors='replace',
            env=child_env,
        )
    except OSError:
        # the interpreter/script could not be launched: don't leave the project marked as training
        pic.update_project_status(project_id, '-1', session)
        raise

    with process:
        while True:
            # readline() blocks, so run it on a worker thread to keep the event loop responsive
            line = await loop.run_in_executor(None, process.stdout.readline)
            if line == '':  # EOF: the child closed its output
                break
            if line != '\n':
                await room_manager.send_to_room(room, line + '\n')
        return_code = await loop.run_in_executor(None, process.wait)

    if return_code != 0:
        pic.update_project_status(project_id, '-1', session)
        return

    await room_manager.send_to_room(room, 'success')
    pic.update_project_status(project_id, '2', session)

    # record the finished training version
    train_url = os.file_path(project, name)
    train = ProjectTrain()
    train.project_id = project_id
    train.train_version = name
    train.train_url = train_url
    train.train_data = data
    train.best_pt = os.file_path(train_url, 'weights', 'best.pt')
    train.last_pt = os.file_path(train_url, 'weights', 'last.pt')
    # Only record the starting weights when they were found and actually used.
    # Previously a missing record crashed here with AttributeError on None.
    if train_info is not None:
        train.weights_id = weights
        train.weights_name = train_info.train_version
    train.patience = patience
    train.epochs = epochs
    ptc.add_train(train, session)
|
||
|
||
|
||
def operate_img_label(img_list: List[ProjectImage],
                      img_path: str, label_path: str,
                      session: Session, label_id_list: list):
    """
    Copy images into a YOLO dataset folder and write one label txt file per image.

    The i-th image is copied into img_path under the base name 'image<i>'; the
    file suffix is kept by os.copy_and_rename_file. Its boxes are written to
    label_path/image<i>.txt, one line per box:
    "<class index> <center_x> <center_y> <width> <height>".

    :param img_list: image records to export (their image_url and id are used)
    :param img_path: destination folder for the images
    :param label_path: destination folder for the label txt files
    :param session: database session
    :param label_id_list: label ids in class order; a label id's position in
                          this list is its class index
    :return: None
    :raises ValueError: if an image has a box whose label_id is not in label_id_list
    """
    # Map label id -> class index once, instead of a list.index() scan per box.
    # setdefault keeps the first position, matching list.index() for duplicated ids.
    class_index = {}
    for position, label_id in enumerate(label_id_list):
        class_index.setdefault(label_id, position)

    for i, image in enumerate(img_list):
        file_name = 'image' + str(i)
        os.copy_and_rename_file(image.image_url, img_path, file_name)

        # Build every line first, so an unknown label doesn't leave a half-written file
        lines = []
        for image_label in pillc.get_img_label_list(image.id, session):
            try:
                index = class_index[image_label.label_id]
            except KeyError:
                raise ValueError(
                    f"label_id {image_label.label_id} of image {image.id} is not in label_id_list"
                ) from None
            # str() so numeric coordinate columns work as well as string ones
            fields = (index, image_label.mark_center_x, image_label.mark_center_y,
                      image_label.mark_width, image_label.mark_height)
            lines.append(' '.join(str(value) for value in fields) + '\n')

        label_txt_path = os.file_path(label_path, file_name + '.txt')
        with open(label_txt_path, 'w', encoding='utf-8') as file:
            file.writelines(lines)
|
||
|
||
|
||
def get_train_result(train_id: int, session: Session):
    """
    Build a training report from the results.csv in a training run's directory.

    :param train_id: id of the ProjectTrain record
    :param session: database session
    :return: None when the training record does not exist. Otherwise a dict that
             maps each report key to one results.csv column: a list of the
             stripped string values, one per row returned by read_csv.
    """
    train_info = ptc.get_train(train_id, session)
    if train_info is None:
        return None

    # report key -> column index in results.csv. Presumably the YOLOv5 layout is:
    # epoch, train/box_loss, train/obj_loss, train/cls_loss, metrics/precision,
    # metrics/recall, metrics/mAP_0.5, metrics/mAP_0.5:0.95, val/box_loss,
    # val/obj_loss, val/cls_loss, x/lr0, x/lr1, x/lr2.
    # Confirm against the yolov5 version under yolo_url.
    columns = (
        ('epoch_data', 0),
        # box regression loss: distance of predicted boxes (center, size) from ground truth; lower is better
        ('train_box_loss', 1),
        # objectness loss: error in "is there an object here"; lower is better
        ('train_obj_loss', 2),
        # classification loss: error in the predicted class; lower is better
        ('train_cls_loss', 3),
        # the same three losses on the validation set (data not used for training)
        ('val_box_loss', 8),
        ('val_obj_loss', 9),
        ('val_cls_loss', 10),
        # precision: share of detections that are correct (higher = fewer false positives)
        ('m_p', 4),
        # recall: share of ground-truth objects that were detected (higher = fewer misses)
        ('m_r', 5),
        # Learning rates of optimizer parameter groups 0 and 1. These are NOT
        # backbone/head rates: YOLOv5 groups parameters by kind (biases /
        # weights with decay / norm weights; the order depends on the yolov5 version).
        ('x_lr0', 11),
        ('x_lr1', 12),
    )

    report_data = {key: [] for key, _ in columns}
    for row in read_csv(os.file_path(train_info.train_url, 'results.csv')):
        for key, index in columns:
            report_data[key].append(row[index].strip())
    return report_data
|