Files
aicheckv2/app/service/project_train_service.py

267 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_train_crud as ptc
from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml
import subprocess
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
    """
    Create a new project row together with its working folders.

    The incoming payload is converted into a ``ProjectInfo`` object, which is
    then stamped with the owning user, a project number obtained from
    ``ru.random_str(6)``, status ``"0"`` and training version ``0``. A folder
    named after the project number is created under ``datasets_url`` and
    under ``runs_url`` before the row is handed to the CRUD layer.

    :param info: project details supplied by the caller
    :param session: database session
    :param user_id: id of the user who owns the project
    :return: ``id`` of the object returned by ``pic.add_project``
    """
    new_project = ProjectInfo(**info.dict())
    initial_values = {
        'user_id': user_id,
        'project_no': ru.random_str(6),
        'project_status': "0",
        'train_version': 0,
    }
    for field_name, value in initial_values.items():
        setattr(new_project, field_name, value)

    # One folder per project under each storage root: datasets first, then runs.
    for root in (datasets_url, runs_url):
        os.create_folder(root, new_project.project_no)

    return pic.add_project(new_project, session).id
def check_image_name(project_id: int, files: List[UploadFile], session: Session):
    """
    Run ``pimc.check_img_name`` on the file name of every uploaded file.

    Checking stops at the first file whose check returns a falsy value.

    :param project_id: id of the project the files are uploaded to
    :param files: uploaded files to check
    :param session: database session
    :return: ``(False, filename)`` for the first file that fails the check,
             otherwise ``(True, None)``
    """
    for name in (upload.filename for upload in files):
        if pimc.check_img_name(project_id, name, session):
            continue
        return False, name
    return True, None
def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session):
    """
    Store uploaded images for a project and register them in the database.

    For every upload the original file is saved through ``os.save_images``,
    a thumbnail is generated next to it under a ``thumb`` sub-path with a
    name built from ``ru.random_str(10)`` plus ``.jpg``, and a
    ``ProjectImage`` object holding both paths is collected. All collected
    objects are passed to ``pimc.add_image_batch`` in a single call.

    :param project_info: project the images belong to
    :param files: uploaded image files
    :param session: database session
    :return: None
    """
    records = []
    project_no = project_info.project_no
    for upload in files:
        record = ProjectImage()
        record.project_id = project_info.id
        record.file_name = upload.filename

        # Save the original image and remember where it landed.
        saved_path = os.save_images(images_url, project_no, file=upload)
        record.image_url = saved_path

        # Build the thumbnail path, then render the thumbnail from the saved original.
        thumb_name = ru.random_str(10) + ".jpg"
        thumb_path = os.file_path(images_url, 'thumb', project_no, thumb_name)
        os.create_thumbnail(saved_path, thumb_path)
        record.thumb_image_url = thumb_path

        records.append(record)
    pimc.add_image_batch(records, session)
def del_img(image_id: int, session: Session):
    """
    Delete an image row and the files it points to.

    The image and thumbnail paths are passed to ``os.delete_file_if_exists``
    first; the row is then deleted and the session committed.

    :param image_id: id of the ``ProjectImage`` row to remove
    :param session: database session
    :return: ``0`` when no row with that id exists, otherwise ``None``
    """
    target = session.query(ProjectImage).filter_by(id=image_id).first()
    if target is not None:
        os.delete_file_if_exists(target.image_url, target.thumb_image_url)
        session.delete(target)
        session.commit()
        return None
    return 0
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
    """
    Persist the annotation canvas and the label boxes of one image.

    The leafer payload is saved through ``pillc.save_img_leafer`` and the
    label boxes through ``pillc.save_img_label_batch``. Per the original
    author's note, every save is meant to replace whatever was stored for the
    image before; that replacement happens inside the CRUD helpers, which are
    not visible here.

    :param img_leafer_label: image id, leafer data and label box list
    :param session: database session
    :return: None
    """
    image_id = img_leafer_label.image_id

    leafer_row = ProjectImgLeafer()
    leafer_row.image_id = image_id
    leafer_row.leafer = img_leafer_label.leafer
    pillc.save_img_leafer(leafer_row, session)

    label_rows = []
    for box in img_leafer_label.label_infos:
        row = ProjectImgLabel(**box.dict())
        # The owning image id always comes from the request envelope.
        row.image_id = image_id
        label_rows.append(row)
    pillc.save_img_label_batch(image_id, label_rows, session)
def get_img_leafer(image_id: int, session: Session):
    """
    Look up the leafer record of an image and return it as a plain dict.

    :param image_id: id of the image
    :param session: database session
    :return: ``None`` when ``pillc.get_img_leafer`` finds nothing, otherwise
             the record converted via ``ProjectImgLeaferOut.from_orm(...).dict()``
    """
    record = pillc.get_img_leafer(image_id, session)
    return None if record is None else ProjectImgLeaferOut.from_orm(record).dict()
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
    """
    Prepare a YOLOv5 dataset for the project and return the training arguments.

    Steps, in order: load the project's images, split them into train/val
    with ``split_array``, create the ``v<train_version + 1>`` dataset folder
    with ``images`` and ``labels`` sub-folders for ``train`` and ``val``,
    write the dataset yaml, export images and label files for both subsets
    through ``operate_img_label``, and finally set the project status to
    ``'1'``. The training itself is not started here.

    :param project_info: project to train
    :param session: database session
    :return: tuple ``(data, project, name)`` - the yaml file path, the path
             built from ``runs_url``, the project number and ``'train'``, and
             the version folder name
    """
    project_no = project_info.project_no

    # Load every image of the project and split it into train / val parts.
    all_images = pimc.get_images(project_info.id, session)
    train_images, val_images = split_array(all_images)

    # Versioned root folder of this training run.
    version_name = f'v{project_info.train_version + 1}'
    dataset_root = os.create_folder(datasets_url, project_no, version_name)

    # Two parallel lists: label ids and label names, matched by position.
    label_ids, label_names = plc.get_label_for_train(project_info.id, session)

    # images/train, images/val, labels/train, labels/val - created in that order.
    folders = {}
    for kind in ('images', 'labels'):
        for subset in ('train', 'val'):
            folders[kind, subset] = os.create_folder(dataset_root, kind, subset)

    # Dataset description file placed in the root folder.
    yaml_path = os.file_path(dataset_root, project_no + '.yaml')
    dataset_config = {
        'path': dataset_root,
        'train': 'images/train',
        'val': 'images/val',
        'test': None,
        'names': dict(enumerate(label_names)),
    }
    with open(yaml_path, 'w', encoding='utf-8') as stream:
        yaml.dump(dataset_config, stream, allow_unicode=True, default_flow_style=False)

    # Copy the images and generate the label files: train first, then val.
    for subset, subset_images in (('train', train_images), ('val', val_images)):
        operate_img_label(subset_images, folders['images', subset],
                          folders['labels', subset], session, label_ids)

    # Packaging is done; mark the project before the training is launched.
    pic.update_project_status(project_info.id, '1', session)

    runs_path = os.file_path(runs_url, project_no, 'train')
    return yaml_path, runs_path, version_name
def run_commend(data: str, project: str,
                name: str, epochs: int,
                project_id: int, session: Session):
    """
    Run the YOLOv5 ``train.py`` script in a child process and stream its output.

    This is a generator. It first yields a start notice and then every line
    the child writes (stderr is merged into stdout), skipping lines that
    consist of a bare newline. After the child has exited, the project status
    is set to ``'-1'`` for a non-zero exit code or ``'2'`` for success, and on
    success a ``ProjectTrain`` row pointing at the produced ``best.pt`` and
    ``last.pt`` weight files is stored.

    :param data: path of the dataset yaml file (``--data``)
    :param project: output root directory of the run (``--project``)
    :param name: run name, i.e. the version folder (``--name``)
    :param epochs: number of training epochs (``--epochs``)
    :param project_id: id of the project being trained
    :param session: database session
    :return: generator of output lines (``str``)
    """
    yolo_path = os.file_path(yolo_url, 'train.py')
    yield "stdout: 模型训练开始,请稍等。。。\n"

    command = [
        "python", '-u', yolo_path,
        "--data=" + data,
        "--project=" + project,
        "--name=" + name,
        "--epochs=" + str(epochs),
    ]
    with subprocess.Popen(
            command,
            bufsize=1,  # line-buffered; only meaningful together with text mode
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # progress bars are written to stderr
            text=True,
            encoding='utf-8',
            errors='replace',  # one undecodable byte must not abort the stream
    ) as process:
        # Iterating the pipe reads until EOF. A poll()-driven loop can stop
        # while output is still buffered and hands out '' once the pipe closes.
        for line in process.stdout:
            if line != '\n':
                yield line
        # The pipe is drained, so this only collects the exit code.
        return_code = process.wait()

    if return_code != 0:
        pic.update_project_status(project_id, '-1', session)
    else:
        pic.update_project_status(project_id, '2', session)
        # NOTE(review): indentation was lost in the reviewed copy of this file;
        # the version record is assumed to be written only for successful runs,
        # since the weight files exist only then - confirm.
        train = ProjectTrain()
        train.project_id = project_id
        train.train_version = name
        # YOLOv5 stores its checkpoints as weights/best.pt and weights/last.pt.
        train.best_pt = os.file_path(project, name, 'weights', 'best.pt')
        train.last_pt = os.file_path(project, name, 'weights', 'last.pt')
        ptc.add_train(train, session)
def operate_img_label(img_list: List[ProjectImage],
                      img_path: str, label_path: str,
                      session: Session, label_id_list: List):
    """
    Export images and their label files for one dataset subset.

    Every image is copied into ``img_path`` under the name ``image<i>``
    (``i`` is its position in ``img_list``) via ``os.copy_and_rename_file``,
    and a matching ``image<i>.txt`` is written into ``label_path``. Each line
    of that file has the form ``<class index> <center x> <center y> <width>
    <height>``, where the class index is the position of the label's id in
    ``label_id_list``.

    :param img_list: image rows to export (``image_url`` and ``id`` are read)
    :param img_path: destination folder for the image copies
    :param label_path: destination folder for the label txt files
    :param session: database session
    :param label_id_list: label ids ordered by class index
    :return: None
    :raises ValueError: if a label references an id missing from
                        ``label_id_list``
    """
    # Resolve ids to class indexes once instead of scanning the list for every
    # label; setdefault keeps the first position of a repeated id, exactly as
    # list.index() would.
    class_index = {}
    for position, label_id in enumerate(label_id_list):
        class_index.setdefault(label_id, position)

    for i, image in enumerate(img_list):
        # Copy the image under its new base name.
        file_name = f'image{i}'
        os.copy_and_rename_file(image.image_url, img_path, file_name)

        # Fetch the labels of this image and write its txt file.
        img_label_list = pillc.get_img_label_list(image.id, session)
        label_txt_path = os.file_path(label_path, file_name + '.txt')
        with open(label_txt_path, 'w', encoding='utf-8') as file:
            for image_label in img_label_list:
                label_id = image_label.label_id
                if label_id not in class_index:
                    raise ValueError(
                        f'label id {label_id!r} of image {image.id!r} '
                        f'is not in label_id_list')
                # An f-string formats the coordinates whether they are stored
                # as str or as numbers.
                file.write(f'{class_index[label_id]} '
                           f'{image_label.mark_center_x} '
                           f'{image_label.mark_center_y} '
                           f'{image_label.mark_width} '
                           f'{image_label.mark_height}\n')
def split_array(data):
    """
    Cut a sequence into two consecutive parts with a 3:1 length ratio.

    The cut position is ``len(data) * 3 // 4``; the first part holds the
    elements before it and the second part the rest.

    :param data: sequence to split; must support ``len`` and slicing
    :return: tuple ``(first_part, second_part)``
    :raises ValueError: if ``data`` has fewer than 4 elements
    """
    size = len(data)
    if not size >= 4:
        raise ValueError("数组长度至少需要为4才能进行3:1的分割")
    cut = size * 3 // 4
    return data[:cut], data[cut:]