Files
aicheckv2-api/apps/business/train/service.py
2025-06-09 15:27:45 +08:00

174 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from algo import YoloModel
from utils import os_utils as osu
from application.settings import *
from . import schemas, models, crud
from apps.business.project import models as proj_models, crud as proj_crud
import yaml
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession
async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):
    """
    Build the on-disk dataset for the next training version of a project.

    Steps, in order:
      1. load the project's ``train`` and ``val`` images;
      2. create the version folder ``v<train_version + 1>`` (via ``osu.create_folder``
         from ``datasets_url`` and ``project_no``) plus its ``images/train``,
         ``images/val``, ``labels/train`` and ``labels/val`` sub-folders;
      3. write ``<project_no>.yaml`` (dataset root, split folders, class names);
      4. copy every image and write its label file via ``operate_img_label``.

    Training itself is not started here; the function only returns the
    arguments the trainer needs.

    :param proj_info: project record (``id``, ``project_no`` and ``train_version`` are read)
    :param db: database session
    :return: ``(data, project, name)`` -- the dataset yaml path, the run output
             directory built from ``runs_url`` and ``project_no``, and the
             version folder name ``v<N>``
    """
    image_dal = proj_crud.ProjectImageDal(db)
    label_dal = proj_crud.ProjectLabelDal(db)
    image_model = proj_models.ProjectImage

    async def load_split(split):
        # Images of this project for one split, as ORM objects.
        # limit=0 presumably means "unpaged" -- confirm in the DAL.
        return await image_dal.get_datas(
            v_where=[image_model.project_id == proj_info.id, image_model.img_type == split],
            limit=0,
            v_return_count=False,
            v_return_objs=True,
        )

    train_images = await load_split('train')
    val_images = await load_split('val')

    version_name = f"v{proj_info.train_version + 1}"
    dataset_root = osu.create_folder(datasets_url, proj_info.project_no, version_name)

    # label_ids[i] and label_names[i] describe the same label; the list position
    # is the class index used both in the yaml and in the label .txt files.
    label_ids, label_names = await label_dal.get_label_for_train(proj_info.id)

    train_img_dir = osu.create_folder(dataset_root, 'images', 'train')
    val_img_dir = osu.create_folder(dataset_root, 'images', 'val')
    train_label_dir = osu.create_folder(dataset_root, 'labels', 'train')
    val_label_dir = osu.create_folder(dataset_root, 'labels', 'val')

    yaml_path = osu.file_path(dataset_root, proj_info.project_no + '.yaml')
    dataset_config = {
        'path': dataset_root,
        'train': 'images/train',
        'val': 'images/val',
        'test': None,
        'names': dict(enumerate(label_names)),
    }
    with open(yaml_path, 'w', encoding='utf-8') as handle:
        yaml.dump(dataset_config, handle, allow_unicode=True, default_flow_style=False)

    # Fill the train split first, then val.
    splits = (
        (train_images, train_img_dir, train_label_dir),
        (val_images, val_img_dir, val_label_dir),
    )
    for images, img_dir, label_dir in splits:
        await operate_img_label(images, img_dir, label_dir, db, label_ids)

    run_root = osu.file_path(runs_url, proj_info.project_no)
    return yaml_path, run_root, version_name
async def operate_img_label(
        img_list: list[proj_models.ProjectImage],
        img_path: str,
        label_path: str,
        db: AsyncSession,
        label_id_list: list):
    """
    Copy one split's images into the dataset and write a label file per image.

    Image ``i`` of ``img_list`` is copied into ``img_path`` under the base name
    ``image<i>`` (the original suffix is kept by ``osu.copy_and_rename_file``).
    ``label_path/image<i>.txt`` is then written with one line per stored
    annotation of that image:
    ``<class_index> <mark_center_x> <mark_center_y> <mark_width> <mark_height>``.

    :param img_list: project images of one split (train or val)
    :param img_path: destination folder for the copied images
    :param label_path: destination folder for the label ``.txt`` files
    :param db: database session used to load each image's annotations
    :param label_id_list: label ids; an id's position in this list is its class index
    :raises ValueError: if an annotation references a label id not in ``label_id_list``
    """
    # Map label id -> class index once instead of an O(n) list.index per annotation.
    # setdefault keeps the first occurrence, matching list.index semantics.
    class_index = {}
    for idx, label_id in enumerate(label_id_list):
        class_index.setdefault(label_id, idx)

    img_label_dal = proj_crud.ProjectImgLabelDal(db)

    for i, image in enumerate(img_list):
        file_name = 'image' + str(i)
        # Copy the image under a new base name; the extension is left unchanged.
        osu.copy_and_rename_file(image.image_url, img_path, file_name)

        img_label_list = await img_label_dal.get_img_label_list(image.id)
        lines = []
        for image_label in img_label_list:
            try:
                index = class_index[image_label.label_id]
            except KeyError:
                raise ValueError(
                    f'label id {image_label.label_id!r} used by image {image.id!r} '
                    f'is not among the project labels'
                ) from None
            # str() makes this work whether the coordinate columns are stored
            # as strings or as numbers.
            fields = (
                index,
                image_label.mark_center_x,
                image_label.mark_center_y,
                image_label.mark_width,
                image_label.mark_height,
            )
            lines.append(' '.join(str(value) for value in fields) + '\n')

        # Written only after every annotation resolved, so a bad label id does
        # not leave a half-written label file behind.
        label_txt_path = osu.file_path(label_path, file_name + '.txt')
        with open(label_txt_path, 'w', encoding='utf-8') as file:
            file.writelines(lines)
def run_event_loop(
        data: str,
        project: str,
        name: str,
        train_in: schemas.ProjectTrainIn,
        project_id: int,
        train_info: models.ProjectTrain,
        is_gup: str):
    """
    Run a training job to completion on the calling thread.

    A fresh event loop is created and installed for this thread while training
    runs. It is always closed and uninstalled afterwards, even if training raises.

    :param data: dataset yaml path
    :param project: run output directory
    :param name: experiment (version) name
    :param train_in: training request; ``epochs`` and ``patience`` are read
    :param project_id: not used by this function; kept for interface compatibility
    :param train_info: previous training record passed through to ``run_commend``, or None
    :param is_gup: not used by this function; kept for interface compatibility
    """
    loop_run = asyncio.new_event_loop()
    asyncio.set_event_loop(loop_run)
    try:
        result = run_commend(data, project, name, train_in.epochs, train_in.patience, train_info)
        # run_commend is synchronous, so training has already finished here.
        # Only drive the result through the loop if it is really a coroutine:
        # run_until_complete(None) raises TypeError.
        if asyncio.iscoroutine(result):
            loop_run.run_until_complete(result)
    finally:
        asyncio.set_event_loop(None)
        loop_run.close()
def run_commend(
        data: str,
        project: str,
        name: str,
        epochs: int,
        patience: int,
        train_info: models.ProjectTrain):
    """
    Train a YOLO model synchronously.

    With no previous training record a default ``YoloModel()`` is built;
    otherwise the model is constructed from ``train_info.best_pt``.

    :param data: dataset yaml path
    :param project: directory that receives the run results
    :param name: experiment name (sub-folder of ``project``)
    :param epochs: number of training epochs
    :param patience: early-stopping patience
    :param train_info: previous training record, or None
    :return: None
    """
    yolo = YoloModel() if train_info is None else YoloModel(train_info.best_pt)
    yolo.train(
        data=data,
        epochs=epochs,
        patience=patience,
        project=project,
        name=name,
    )
async def add_train(
        db,
        project_id,
        name,
        project,
        data,
        train_in,
        user_id):
    """
    Record a new training run for a project.

    First calls ``ProjectInfoDal.update_version`` for the project. It then
    inserts a ``ProjectTrain`` row describing the run:
      * ``train_url`` -- ``project`` joined with ``name``;
      * ``best_pt`` / ``last_pt`` -- ``weights/best.pt`` and ``weights/last.pt`` under it.
    The weights and hyper-parameter fields are copied from ``train_in`` when it
    is provided.

    :param db: database session
    :param project_id: id of the project being trained
    :param name: version name (stored as ``train_version``)
    :param project: run output directory
    :param data: dataset yaml path (stored as ``train_data``)
    :param train_in: training request or None
    :param user_id: id of the user who started the run
    :return: None
    """
    await proj_crud.ProjectInfoDal(db).update_version(data_id=project_id)

    record = models.ProjectTrain()
    record.project_id = project_id
    record.train_version = name

    run_dir = osu.file_path(project, name)
    record.train_url = run_dir
    record.train_data = data
    record.user_id = user_id
    record.best_pt = osu.file_path(run_dir, 'weights', 'best.pt')
    record.last_pt = osu.file_path(run_dir, 'weights', 'last.pt')

    if train_in is not None:
        # Copy request fields whose names match the ProjectTrain columns.
        for field in ('weights_id', 'weights_name', 'patience', 'epochs'):
            setattr(record, field, getattr(train_in, field))

    await crud.ProjectTrainDal(db).create_model(data=record)