重构基本完成

This commit is contained in:
2025-02-21 11:35:53 +08:00
parent cdd05e95ba
commit 85d0a8fadc
16 changed files with 346 additions and 8 deletions

View File

@ -1,16 +1,20 @@
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_train_crud as ptc
from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml
import subprocess
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@ -48,7 +52,7 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile],
path = os.save_images(images_url, project_info.project_no, file=file)
image.image_url = path
# 生成缩略图
thumb_image_url = images_url + "\\thumb\\" + project_info.project_no + "\\" + ru.random_str(10) + ".jpg"
thumb_image_url = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg")
os.create_thumbnail(path, thumb_image_url)
image.thumb_image_url = thumb_image_url
images.append(image)
@ -62,7 +66,9 @@ def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
:param session:
:return:
"""
img_leafer = ProjectImgLeafer(**img_leafer_label)
img_leafer = ProjectImgLeafer()
img_leafer.image_id = img_leafer_label.image_id
img_leafer.leafer = img_leafer_label.leafer
pillc.save_img_leafer(img_leafer, session)
label_infos = img_leafer_label.label_infos
img_labels = []
@ -83,3 +89,146 @@ def get_img_leafer(image_id: int, session: Session):
img_leafer = pillc.get_img_leafer(image_id, session)
img_leafer_out = ProjectImgLeaferOut.from_orm(img_leafer).dict()
return img_leafer_out
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
"""
yolov5执行训练任务
:param project_info: 项目信息
:param session: 数据库session
:return:
"""
# 先获取项目的所有图片
project_images = pimc.get_images(project_info.id, session)
# 将图片根据,根据3:1的比例将图片分成trainval的两个数组
project_images_train, project_images_val = split_array(project_images)
# 得到训练版本
version_path = 'v' + str(project_info.train_version + 1)
# 创建训练的根目录
train_path = os.create_folder(datasets_url, project_info.project_no, version_path)
# 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)
# 在根目录创建classes.txt文件
classes_txt = os.file_path(train_path, 'classes.txt')
with open(classes_txt, 'w', encoding='utf-8') as file:
for label_name in label_name_list:
file.write(label_name + '\n')
# 创建图片的的两个文件夹
img_path_train = os.create_folder(train_path, 'images', 'train')
img_path_val = os.create_folder(train_path, 'images', 'val')
# 创建标签的两个文件夹
label_path_train = os.create_folder(train_path, 'labels', 'train')
label_path_val = os.create_folder(train_path, 'labels', 'val')
# 在根目录下创建yaml文件
yaml_file = os.file_path(train_path, project_info.project_no + '.yaml')
yaml_data = {
'path': train_path,
'train': 'images/train',
'val': 'images/val',
'test': None,
'names': {i: name for i, name in enumerate(label_name_list)}
}
with open(yaml_file, 'w', encoding='utf-8') as file:
yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False)
# 开始循环复制图片和生成label.txt
# 先操作train
operate_img_label(project_images_train, img_path_train, label_path_train, session, label_id_list)
# 再操作val
operate_img_label(project_images_val, img_path_val, label_path_val, session, label_id_list)
# 打包完成开始训练,训练前,更改项目的训练状态
pic.update_project_status(project_info.id, '1')
# 开始训练
data = yaml_file
project = os.file_path(runs_url, project_info.project_no, 'train')
name = version_path
epochs = 10
yolo_path = os.file_path(yolo_url, 'train.py')
# 启动子进程
with subprocess.Popen(
["python", yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, "--epochs=" + str(epochs)],
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
bufsize=1, # 行缓冲
) as process:
# 使用iter逐行读取stdout和stderr
for line in process.stdout:
yield f"stdout: {line.strip()} \n"
for line in process.stderr:
yield f"stderr: {line.strip()} \n"
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
pic.update_project_status(project_info.id, '-1', session)
else:
pic.update_project_status(project_info.id, '2', session)
# 然后保存版本训练信息
train = ProjectTrain()
train.project_id = project_info.id
train.train_version = version_path
bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt')
last_pt_path = os.file_path(project, name, 'weight', 'last.pt')
train.best_pt = bast_pt_path
train.last_pt = last_pt_path
ptc.add_train(train, session)
def operate_img_label(img_list: List[ProjectImgLabel],
img_path: str, label_path: str,
session: Session, label_id_list: []):
"""
生成图片和标签内容
:param label_id_list:
:param session:
:param img_list:
:param img_path:
:param label_path:
:return:
"""
for i in range(len(img_list)):
image = img_list[i]
# 先复制图片,并把图片改名,不改后缀
file_name = 'image' + str(i)
os.copy_and_rename_file(image.image_url, img_path, file_name)
# 查询这张图片的label信息然后生成这张照片的txt文件
img_label_list = pillc.get_img_label_list(image.id, session)
label_txt_path = os.file_path(label_path, file_name + '.txt')
with open(label_txt_path, 'w', encoding='utf-8') as file:
for image_label in img_label_list:
index = label_id_list.index(image_label.label_id)
file.write(str(index) + ' ' + image_label.mark_center_x + ' '
+ image_label.mark_center_y + ' '
+ image_label.mark_width + ' '
+ image_label.mark_height + '\n')
def split_array(data):
"""
将数组按照3:1的比例切分成两个数组。
:param data: 原始数组
:return: 按照3:1比例分割后的两个数组
"""
total_length = len(data)
if total_length < 4:
raise ValueError("数组长度至少需要为4才能进行3:1的分割")
# 计算分割点
split_index = (total_length * 3) // 4
# 使用切片分割数组
part1 = data[:split_index]
part2 = data[split_index:]
return part1, part2