Files
aicheckv2/app/service/project_train_service.py
2025-03-20 16:41:51 +08:00

357 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.schemas.project_train_schemas import ProjectTrainIn
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_train_crud as ptc
from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from app.websocket.web_socket_server import room_manager
from app.util.csv_utils import read_csv
from app.common.redis_cli import redis_conn
import yaml
import subprocess
from typing import List
from fastapi import UploadFile
from sqlalchemy.orm import Session
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
    """
    Create a new project record together with its working folders.

    The record is built from ``info``, assigned to ``user_id``, given a
    project number from ``ru.random_str(6)``, status ``"0"`` and training
    version ``0``. A folder named after the project number is created under
    ``datasets_url`` and then under ``runs_url`` before the record is handed
    to ``pic.add_project``.

    :param info: project fields supplied by the caller
    :param session: database session
    :param user_id: id of the user the project belongs to
    :return: ``id`` of the object returned by ``pic.add_project``
    """
    new_project = ProjectInfo(**info.dict())
    new_project.user_id = user_id
    new_project.project_no = ru.random_str(6)
    new_project.project_status = "0"
    new_project.train_version = 0

    # One folder per project under each storage root (datasets first, then runs).
    for storage_root in (datasets_url, runs_url):
        os.create_folder(storage_root, new_project.project_no)

    saved_project = pic.add_project(new_project, session)
    return saved_project.id
def check_image_name(project_id: int, img_type: str, files: List[UploadFile], session: Session):
    """
    Run ``pimc.check_img_name`` on the file name of every upload.

    Checking stops at the first file for which that call returns a falsy
    value.

    :param project_id: id of the project the images belong to
    :param img_type: image category passed through to the check
    :param files: uploaded files whose names are checked
    :param session: database session
    :return: ``(True, None)`` when every name passes, otherwise
        ``(False, filename)`` for the first name that did not pass
    """
    for upload in files:
        candidate_name = upload.filename
        name_accepted = pimc.check_img_name(project_id, img_type, candidate_name, session)
        if name_accepted:
            continue
        return False, candidate_name
    return True, None
def upload_project_image(project_info: ProjectInfoOut, img_type: str, files: List[UploadFile], session: Session):
    """
    Store uploaded images for a project and register them in one batch.

    For each upload the original is saved through ``os.save_images``, a
    thumbnail is written under the ``thumb`` folder with a random ``.jpg``
    name, and a ``ProjectImage`` record pointing at both paths is built.
    All records are then passed to ``pimc.add_image_batch`` in a single call.

    :param project_info: project the images belong to (``id`` and ``project_no`` are used)
    :param img_type: image category stored on every record
    :param files: uploaded image files
    :param session: database session
    :return: None
    """

    def _store_upload(upload):
        # Builds one record; the original must be on disk before the thumbnail is made from it.
        record = ProjectImage()
        record.project_id = project_info.id
        record.file_name = upload.filename
        record.img_type = img_type

        original_path = os.save_images(images_url, project_info.project_no, file=upload)
        record.image_url = original_path

        thumb_path = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg")
        os.create_thumbnail(original_path, thumb_path)
        record.thumb_image_url = thumb_path
        return record

    stored_records = [_store_upload(upload) for upload in files]
    pimc.add_image_batch(stored_records, session)
def del_img(image_id: int, session: Session):
    """
    Delete an image record and the files it points to.

    The image file and its thumbnail are removed through
    ``os.delete_file_if_exists`` before the record is deleted and the
    session committed.

    :param image_id: id of the image to delete
    :param session: database session
    :return: ``0`` when no image with that id exists, otherwise ``None``
    """
    record = session.query(ProjectImage).filter_by(id=image_id).first()
    if record is not None:
        os.delete_file_if_exists(record.image_url, record.thumb_image_url)
        session.delete(record)
        session.commit()
        return None
    return 0
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
    """
    Save the annotation canvas data and the label boxes of one image.

    The leafer payload is saved first via ``pillc.save_img_leafer``; the
    label rows are then saved in one batch via
    ``pillc.save_img_label_batch``.
    NOTE(review): the original docstring states that each save replaces all
    previously stored label data for the image; that happens inside the CRUD
    layer and is not visible here.

    :param img_leafer_label: image id, leafer payload and label box list
    :param session: database session
    :return: None
    """
    target_image_id = img_leafer_label.image_id

    leafer_record = ProjectImgLeafer()
    leafer_record.image_id = target_image_id
    leafer_record.leafer = img_leafer_label.leafer
    pillc.save_img_leafer(leafer_record, session)

    def _as_label_row(info):
        # Every label row is tied to the image being saved.
        row = ProjectImgLabel(**info.dict())
        row.image_id = target_image_id
        return row

    label_rows = [_as_label_row(info) for info in img_leafer_label.label_infos]
    pillc.save_img_label_batch(target_image_id, label_rows, session)
def get_img_leafer(image_id: int, session: Session):
    """
    Look up the stored leafer data of an image.

    :param image_id: id of the image
    :param session: database session
    :return: the record converted to a dict through ``ProjectImgLeaferOut``,
        or ``None`` when nothing is stored for the image
    """
    leafer_row = pillc.get_img_leafer(image_id, session)
    return None if leafer_row is None else ProjectImgLeaferOut.from_orm(leafer_row).dict()
def run_train_yolo(project_info: ProjectInfoOut, train_in: ProjectTrainIn, session: Session):
    """
    Prepare a versioned YOLOv5 dataset for a project and mark it as training.

    Steps, in order: load the ``train`` and ``val`` image records, create
    the ``v<train_version + 1>`` dataset folder with ``images/`` and
    ``labels/`` sub-folders for both splits, write the dataset YAML file,
    copy images and generate label files through ``operate_img_label``,
    and finally set the project status to ``'1'``.

    :param project_info: project whose data is packaged
    :param train_in: training parameters (not read by this function)
    :param session: database session
    :return: tuple ``(data, project, name)`` - path of the dataset YAML
        file, the project's folder under ``runs_url``, and the version name
    """
    splits = ('train', 'val')

    # Image records are loaded up front, train before val.
    images_by_split = {
        split: pimc.get_images(project_info.id, split, session)
        for split in splits
    }

    version_name = 'v' + str(project_info.train_version + 1)
    dataset_root = os.create_folder(datasets_url, project_info.project_no, version_name)

    # Two parallel lists: label ids and label names at matching positions.
    label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)

    # Created in the order images/train, images/val, labels/train, labels/val.
    split_dirs = {
        (kind, split): os.create_folder(dataset_root, kind, split)
        for kind in ('images', 'labels')
        for split in splits
    }

    yaml_path = os.file_path(dataset_root, project_info.project_no + '.yaml')
    dataset_config = {
        'path': dataset_root,
        'train': 'images/train',
        'val': 'images/val',
        'test': None,
        'names': dict(enumerate(label_name_list)),
    }
    with open(yaml_path, 'w', encoding='utf-8') as yaml_out:
        yaml.dump(dataset_config, yaml_out, allow_unicode=True, default_flow_style=False)

    # Copy images and write one label file per image, train split first.
    for split in splits:
        operate_img_label(
            images_by_split[split],
            split_dirs['images', split],
            split_dirs['labels', split],
            session,
            label_id_list,
        )

    # Packaging is done; flag the project before the caller starts training.
    pic.update_project_status(project_info.id, '1', session)

    runs_dir = os.file_path(runs_url, project_info.project_no)
    return yaml_path, runs_dir, version_name
async def run_commend(data: str, project: str, name: str, epochs: int, patience: int, weights: str, project_id: int, session: Session):
    """
    Run YOLOv5 ``train.py`` in a child process and stream its output to the
    project's websocket room.

    Every non-blank output line is forwarded to room ``train_<project_id>``.
    On a non-zero exit code the project status is set to ``'-1'``. On success
    ``'success'`` is sent to the room, the status is set to ``'2'`` and a
    ``ProjectTrain`` record describing the run is saved.

    :param data: path of the dataset YAML file
    :param project: directory that receives the training results
    :param name: experiment (version) name
    :param epochs: number of training epochs
    :param patience: early-stopping patience
    :param weights: id of an earlier training run whose ``best_pt`` is used
        as starting weights; ``''`` or ``None`` for no starting weights
    :param project_id: id of the project
    :param session: database session
    :return: None
    """
    import asyncio  # function-scope import keeps this block self-contained

    yolo_path = os.file_path(yolo_url, 'train.py')
    room = 'train_' + str(project_id)
    await room_manager.send_to_room(room, "AiCheck: 模型训练开始,请稍等。。。\n")
    commend = [
        "python", '-u', yolo_path,
        "--data=" + data,
        "--project=" + project,
        "--name=" + name,
        "--epochs=" + str(epochs),
        "--batch-size=8",
        "--exist-ok",
        "--patience=" + str(patience),
    ]

    # Optionally continue from the best weights of an earlier run.
    # base_train stays None when no weights were requested OR the referenced
    # run does not exist, so the record-saving step below can tell whether
    # starting weights were really used.
    base_train = None
    if weights is not None and weights != '':
        base_train = ptc.get_train(int(weights), session)
        if base_train is not None:
            commend.append("--weights=" + base_train.best_pt)

    # Use the first CUDA device when the cached flag says a GPU is present.
    # NOTE(review): assumes redis_conn returns str (decoded responses) - confirm.
    is_gpu = redis_conn.get('is_gpu')
    if is_gpu == 'True':
        commend.append("--device=0")

    loop = asyncio.get_running_loop()
    with subprocess.Popen(
        commend,
        bufsize=1,  # line-buffered
        shell=False,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # progress bars are written to stderr; merge them in
        text=True,
        encoding='latin1',
    ) as process:
        # readline() blocks, so it runs in the default executor to keep the
        # event loop responsive. Reading until EOF (empty string) instead of
        # polling the process also delivers output that is still buffered
        # when the child exits.
        while True:
            line = await loop.run_in_executor(None, process.stdout.readline)
            if not line:
                break
            if line != '\n':
                await room_manager.send_to_room(room, line + '\n')
        return_code = await loop.run_in_executor(None, process.wait)

    if return_code != 0:
        pic.update_project_status(project_id, '-1', session)
        return

    await room_manager.send_to_room(room, 'success')
    pic.update_project_status(project_id, '2', session)

    # Persist the details of this training version.
    train = ProjectTrain()
    train.project_id = project_id
    train.train_version = name
    train_url = os.file_path(project, name)
    train.train_url = train_url
    train.train_data = data
    train.best_pt = os.file_path(train_url, 'weights', 'best.pt')
    train.last_pt = os.file_path(train_url, 'weights', 'last.pt')
    if base_train is not None:
        train.weights_id = weights
        train.weights_name = base_train.train_version
    train.patience = patience
    train.epochs = epochs
    ptc.add_train(train, session)
def operate_img_label(img_list: List[ProjectImgLabel],
                      img_path: str, label_path: str,
                      session: Session, label_id_list: list):
    """
    Copy images into a dataset folder and write one label file per image.

    Image number ``i`` is copied to ``img_path`` under the name ``image<i>``
    and gets a matching ``image<i>.txt`` in ``label_path``. Each line of that
    file is ``<class index> <center x> <center y> <width> <height>``, where
    the class index is the position of the annotation's ``label_id`` in
    ``label_id_list``.

    :param img_list: image records; each item must expose ``image_url`` and ``id``
    :param img_path: destination folder for the copied images
    :param label_path: destination folder for the label files
    :param session: database session
    :param label_id_list: label ids ordered by class index
    :raises ValueError: when an annotation refers to a label id that is not
        in ``label_id_list``
    :return: None
    """
    # Build the id -> class index lookup once instead of scanning the list
    # for every annotation. setdefault keeps the first position of a
    # repeated id, which is what list.index() would return.
    class_index_by_label_id = {}
    for class_index, label_id in enumerate(label_id_list):
        class_index_by_label_id.setdefault(label_id, class_index)

    for i, image in enumerate(img_list):
        # Copy the image under a new base name.
        file_name = 'image' + str(i)
        os.copy_and_rename_file(image.image_url, img_path, file_name)

        # Write this image's annotations to its label file.
        img_label_list = pillc.get_img_label_list(image.id, session)
        label_txt_path = os.file_path(label_path, file_name + '.txt')
        with open(label_txt_path, 'w', encoding='utf-8') as file:
            for image_label in img_label_list:
                try:
                    index = class_index_by_label_id[image_label.label_id]
                except KeyError:
                    raise ValueError(
                        f"label id {image_label.label_id!r} of image {image.id!r} "
                        "is not in label_id_list"
                    ) from None
                # f-string formatting accepts both str and numeric coordinates.
                file.write(
                    f"{index} {image_label.mark_center_x} {image_label.mark_center_y} "
                    f"{image_label.mark_width} {image_label.mark_height}\n"
                )
def get_train_result(train_id: int, session: Session):
    """
    Build the training report of a run from its ``results.csv``.

    The file is read from the run's ``train_url`` folder through
    ``read_csv``. Selected columns of every row are stripped of surrounding
    whitespace and collected into one list per metric.

    Report keys and the column each is read from:
    ``epoch_data`` (0), ``train_box_loss`` (1), ``train_obj_loss`` (2),
    ``train_cls_loss`` (3), ``val_box_loss`` (8), ``val_obj_loss`` (9),
    ``val_cls_loss`` (10), ``m_p`` precision (4), ``m_r`` recall (5),
    ``x_lr0`` (11) and ``x_lr1`` (12) learning rates.

    :param train_id: id of the training run
    :param session: database session
    :return: dict mapping each report key to its list of values, or
        ``None`` when the training run does not exist
    """
    train_info = ptc.get_train(train_id, session)
    if train_info is None:
        return None

    # (report key, column index). The order fixes both the key order of the
    # report and the order in which columns are read from each row.
    series_columns = (
        ('epoch_data', 0),
        ('train_box_loss', 1),
        ('train_obj_loss', 2),
        ('train_cls_loss', 3),
        ('val_box_loss', 8),
        ('val_obj_loss', 9),
        ('val_cls_loss', 10),
        ('m_p', 4),
        ('m_r', 5),
        ('x_lr0', 11),
        ('x_lr1', 12),
    )

    csv_path = os.file_path(train_info.train_url, 'results.csv')
    csv_rows = read_csv(csv_path)

    report = {key: [] for key, _ in series_columns}
    for row in csv_rows:
        for key, column in series_columns:
            report[key].append(row[column].strip())
    return report