重构基本完成
This commit is contained in:
@ -1,16 +1,20 @@
|
||||
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel
|
||||
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
|
||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
|
||||
from app.model.crud import project_info_crud as pic
|
||||
from app.model.crud import project_image_crud as pimc
|
||||
from app.model.crud import project_label_crud as plc
|
||||
from app.model.crud import project_train_crud as ptc
|
||||
from app.model.crud import project_img_leafer_label_crud as pillc
|
||||
from app.util import os_utils as os
|
||||
from app.util import random_utils as ru
|
||||
from app.config.config_reader import datasets_url, runs_url, images_url
|
||||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
import yaml
|
||||
import subprocess
|
||||
|
||||
|
||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||
@ -48,7 +52,7 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile],
|
||||
path = os.save_images(images_url, project_info.project_no, file=file)
|
||||
image.image_url = path
|
||||
# 生成缩略图
|
||||
thumb_image_url = images_url + "\\thumb\\" + project_info.project_no + "\\" + ru.random_str(10) + ".jpg"
|
||||
thumb_image_url = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg")
|
||||
os.create_thumbnail(path, thumb_image_url)
|
||||
image.thumb_image_url = thumb_image_url
|
||||
images.append(image)
|
||||
@ -62,7 +66,9 @@ def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
img_leafer = ProjectImgLeafer(**img_leafer_label)
|
||||
img_leafer = ProjectImgLeafer()
|
||||
img_leafer.image_id = img_leafer_label.image_id
|
||||
img_leafer.leafer = img_leafer_label.leafer
|
||||
pillc.save_img_leafer(img_leafer, session)
|
||||
label_infos = img_leafer_label.label_infos
|
||||
img_labels = []
|
||||
@ -83,3 +89,146 @@ def get_img_leafer(image_id: int, session: Session):
|
||||
img_leafer = pillc.get_img_leafer(image_id, session)
|
||||
img_leafer_out = ProjectImgLeaferOut.from_orm(img_leafer).dict()
|
||||
return img_leafer_out
|
||||
|
||||
|
||||
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
"""
|
||||
yolov5执行训练任务
|
||||
:param project_info: 项目信息
|
||||
:param session: 数据库session
|
||||
:return:
|
||||
"""
|
||||
# 先获取项目的所有图片
|
||||
project_images = pimc.get_images(project_info.id, session)
|
||||
|
||||
# 将图片根据,根据3:1的比例将图片分成train:val的两个数组
|
||||
project_images_train, project_images_val = split_array(project_images)
|
||||
|
||||
# 得到训练版本
|
||||
version_path = 'v' + str(project_info.train_version + 1)
|
||||
|
||||
# 创建训练的根目录
|
||||
train_path = os.create_folder(datasets_url, project_info.project_no, version_path)
|
||||
|
||||
# 查询项目所属标签,返回两个 id,name一一对应的数组
|
||||
label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)
|
||||
|
||||
# 在根目录创建classes.txt文件
|
||||
classes_txt = os.file_path(train_path, 'classes.txt')
|
||||
with open(classes_txt, 'w', encoding='utf-8') as file:
|
||||
for label_name in label_name_list:
|
||||
file.write(label_name + '\n')
|
||||
|
||||
# 创建图片的的两个文件夹
|
||||
img_path_train = os.create_folder(train_path, 'images', 'train')
|
||||
img_path_val = os.create_folder(train_path, 'images', 'val')
|
||||
|
||||
# 创建标签的两个文件夹
|
||||
label_path_train = os.create_folder(train_path, 'labels', 'train')
|
||||
label_path_val = os.create_folder(train_path, 'labels', 'val')
|
||||
|
||||
# 在根目录下创建yaml文件
|
||||
yaml_file = os.file_path(train_path, project_info.project_no + '.yaml')
|
||||
yaml_data = {
|
||||
'path': train_path,
|
||||
'train': 'images/train',
|
||||
'val': 'images/val',
|
||||
'test': None,
|
||||
'names': {i: name for i, name in enumerate(label_name_list)}
|
||||
}
|
||||
with open(yaml_file, 'w', encoding='utf-8') as file:
|
||||
yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False)
|
||||
|
||||
# 开始循环复制图片和生成label.txt
|
||||
# 先操作train
|
||||
operate_img_label(project_images_train, img_path_train, label_path_train, session, label_id_list)
|
||||
# 再操作val
|
||||
operate_img_label(project_images_val, img_path_val, label_path_val, session, label_id_list)
|
||||
|
||||
# 打包完成开始训练,训练前,更改项目的训练状态
|
||||
pic.update_project_status(project_info.id, '1')
|
||||
# 开始训练
|
||||
data = yaml_file
|
||||
project = os.file_path(runs_url, project_info.project_no, 'train')
|
||||
name = version_path
|
||||
epochs = 10
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
["python", yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, "--epochs=" + str(epochs)],
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
|
||||
bufsize=1, # 行缓冲
|
||||
) as process:
|
||||
# 使用iter逐行读取stdout和stderr
|
||||
for line in process.stdout:
|
||||
yield f"stdout: {line.strip()} \n"
|
||||
|
||||
for line in process.stderr:
|
||||
yield f"stderr: {line.strip()} \n"
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
pic.update_project_status(project_info.id, '-1', session)
|
||||
else:
|
||||
pic.update_project_status(project_info.id, '2', session)
|
||||
# 然后保存版本训练信息
|
||||
train = ProjectTrain()
|
||||
train.project_id = project_info.id
|
||||
train.train_version = version_path
|
||||
bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt')
|
||||
last_pt_path = os.file_path(project, name, 'weight', 'last.pt')
|
||||
train.best_pt = bast_pt_path
|
||||
train.last_pt = last_pt_path
|
||||
ptc.add_train(train, session)
|
||||
|
||||
|
||||
def operate_img_label(img_list: List[ProjectImgLabel],
|
||||
img_path: str, label_path: str,
|
||||
session: Session, label_id_list: []):
|
||||
"""
|
||||
生成图片和标签内容
|
||||
:param label_id_list:
|
||||
:param session:
|
||||
:param img_list:
|
||||
:param img_path:
|
||||
:param label_path:
|
||||
:return:
|
||||
"""
|
||||
for i in range(len(img_list)):
|
||||
image = img_list[i]
|
||||
# 先复制图片,并把图片改名,不改后缀
|
||||
file_name = 'image' + str(i)
|
||||
os.copy_and_rename_file(image.image_url, img_path, file_name)
|
||||
# 查询这张图片的label信息然后生成这张照片的txt文件
|
||||
img_label_list = pillc.get_img_label_list(image.id, session)
|
||||
label_txt_path = os.file_path(label_path, file_name + '.txt')
|
||||
with open(label_txt_path, 'w', encoding='utf-8') as file:
|
||||
for image_label in img_label_list:
|
||||
index = label_id_list.index(image_label.label_id)
|
||||
file.write(str(index) + ' ' + image_label.mark_center_x + ' '
|
||||
+ image_label.mark_center_y + ' '
|
||||
+ image_label.mark_width + ' '
|
||||
+ image_label.mark_height + '\n')
|
||||
|
||||
|
||||
def split_array(data):
|
||||
"""
|
||||
将数组按照3:1的比例切分成两个数组。
|
||||
:param data: 原始数组
|
||||
:return: 按照3:1比例分割后的两个数组
|
||||
"""
|
||||
total_length = len(data)
|
||||
if total_length < 4:
|
||||
raise ValueError("数组长度至少需要为4才能进行3:1的分割")
|
||||
# 计算分割点
|
||||
split_index = (total_length * 3) // 4
|
||||
# 使用切片分割数组
|
||||
part1 = data[:split_index]
|
||||
part2 = data[split_index:]
|
||||
return part1, part2
|
||||
|
||||
|
Reference in New Issue
Block a user