完善相关问题,并增加推理的部分代码

This commit is contained in:
2025-02-28 16:30:11 +08:00
parent 0301e41e96
commit 4262d3e908
20 changed files with 564 additions and 32 deletions

View File

@ -0,0 +1,266 @@
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_train_crud as ptc
from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml
import subprocess
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
"""
新建项目,完善数据,并创建对应的文件夹
:param info: 项目信息
:param session: 数据库session
:param user_id: 用户id
:return:
"""
project_info = ProjectInfo(**info.dict())
project_info.user_id = user_id
project_info.project_no = ru.random_str(6)
project_info.project_status = "0"
project_info.train_version = 0
os.create_folder(datasets_url, project_info.project_no)
os.create_folder(runs_url, project_info.project_no)
project_info = pic.add_project(project_info, session)
return project_info.id
def check_image_name(project_id: int, files: List[UploadFile], session: Session):
for file in files:
if not pimc.check_img_name(project_id, file.filename, session):
return False, file.filename
return True, None
def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session):
"""
上传项目的图片
:param files: 上传的图片
:param project_info: 项目信息
:param session:
:return:
"""
images = []
for file in files:
image = ProjectImage()
image.project_id = project_info.id
image.file_name = file.filename
# 保存原图
path = os.save_images(images_url, project_info.project_no, file=file)
image.image_url = path
# 生成缩略图
thumb_image_url = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg")
os.create_thumbnail(path, thumb_image_url)
image.thumb_image_url = thumb_image_url
images.append(image)
pimc.add_image_batch(images, session)
def del_img(image_id: int, session: Session):
"""
删除图片,并删除文件
:param image_id:
:param session:
:return:
"""
image = session.query(ProjectImage).filter_by(id=image_id).first()
if image is None:
return 0
os.delete_file_if_exists(image.image_url, image.thumb_image_url)
session.delete(image)
session.commit()
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
"""
保存图片的标签框选信息,每次保存都会针对图片的信息全部删除,然后重新保存
:param img_leafer_label:
:param session:
:return:
"""
img_leafer = ProjectImgLeafer()
img_leafer.image_id = img_leafer_label.image_id
img_leafer.leafer = img_leafer_label.leafer
pillc.save_img_leafer(img_leafer, session)
label_infos = img_leafer_label.label_infos
img_labels = []
for label_info in label_infos:
img_label = ProjectImgLabel(**label_info.dict())
img_label.image_id = img_leafer_label.image_id
img_labels.append(img_label)
pillc.save_img_label_batch(img_leafer_label.image_id, img_labels, session)
def get_img_leafer(image_id: int, session: Session):
"""
根据图片id查询图片的leafer信息
:param image_id:
:param session:
:return:
"""
img_leafer = pillc.get_img_leafer(image_id, session)
if img_leafer is None:
return None
img_leafer_out = ProjectImgLeaferOut.from_orm(img_leafer).dict()
return img_leafer_out
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
"""
yolov5执行训练任务
:param project_info: 项目信息
:param session: 数据库session
:return:
"""
# 先获取项目的所有图片
project_images = pimc.get_images(project_info.id, session)
# 将图片根据,根据3:1的比例将图片分成trainval的两个数组
project_images_train, project_images_val = split_array(project_images)
# 得到训练版本
version_path = 'v' + str(project_info.train_version + 1)
# 创建训练的根目录
train_path = os.create_folder(datasets_url, project_info.project_no, version_path)
# 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)
# 创建图片的的两个文件夹
img_path_train = os.create_folder(train_path, 'images', 'train')
img_path_val = os.create_folder(train_path, 'images', 'val')
# 创建标签的两个文件夹
label_path_train = os.create_folder(train_path, 'labels', 'train')
label_path_val = os.create_folder(train_path, 'labels', 'val')
# 在根目录下创建yaml文件
yaml_file = os.file_path(train_path, project_info.project_no + '.yaml')
yaml_data = {
'path': train_path,
'train': 'images/train',
'val': 'images/val',
'test': None,
'names': {i: name for i, name in enumerate(label_name_list)}
}
with open(yaml_file, 'w', encoding='utf-8') as file:
yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False)
# 开始循环复制图片和生成label.txt
# 先操作train
operate_img_label(project_images_train, img_path_train, label_path_train, session, label_id_list)
# 再操作val
operate_img_label(project_images_val, img_path_val, label_path_val, session, label_id_list)
# 打包完成开始训练,训练前,更改项目的训练状态
pic.update_project_status(project_info.id, '1', session)
# 开始训练
data = yaml_file
project = os.file_path(runs_url, project_info.project_no, 'train')
name = version_path
return data, project, name
def run_commend(data: str, project: str,
name: str, epochs: int,
project_id: int, session: Session):
yolo_path = os.file_path(yolo_url, 'train.py')
yield f"stdout: 模型训练开始,请稍等。。。\n"
# 启动子进程
with subprocess.Popen(
["python", '-u', yolo_path,
"--data=" + data,
"--project=" + project,
"--name=" + name,
"--epochs=" + str(epochs)],
bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存
shell=False,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
text=True, # 缓存内容为文本,避免后续编码显示问题
encoding='utf-8',
) as process:
while process.poll() is None:
line = process.stdout.readline()
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
if line != '\n':
yield line
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
pic.update_project_status(project_id, '-1', session)
else:
pic.update_project_status(project_id, '2', session)
# 然后保存版本训练信息
train = ProjectTrain()
train.project_id = project_id
train.train_version = name
bast_pt_path = os.file_path(project, name, 'weights', 'bast.pt')
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
train.best_pt = bast_pt_path
train.last_pt = last_pt_path
ptc.add_train(train, session)
def operate_img_label(img_list: List[ProjectImgLabel],
img_path: str, label_path: str,
session: Session, label_id_list: []):
"""
生成图片和标签内容
:param label_id_list:
:param session:
:param img_list:
:param img_path:
:param label_path:
:return:
"""
for i in range(len(img_list)):
image = img_list[i]
# 先复制图片,并把图片改名,不改后缀
file_name = 'image' + str(i)
os.copy_and_rename_file(image.image_url, img_path, file_name)
# 查询这张图片的label信息然后生成这张照片的txt文件
img_label_list = pillc.get_img_label_list(image.id, session)
label_txt_path = os.file_path(label_path, file_name + '.txt')
with open(label_txt_path, 'w', encoding='utf-8') as file:
for image_label in img_label_list:
index = label_id_list.index(image_label.label_id)
file.write(str(index) + ' ' + image_label.mark_center_x + ' '
+ image_label.mark_center_y + ' '
+ image_label.mark_width + ' '
+ image_label.mark_height + '\n')
def split_array(data):
"""
将数组按照3:1的比例切分成两个数组。
:param data: 原始数组
:return: 按照3:1比例分割后的两个数组
"""
total_length = len(data)
if total_length < 4:
raise ValueError("数组长度至少需要为4才能进行3:1的分割")
# 计算分割点
split_index = (total_length * 3) // 4
# 使用切片分割数组
part1 = data[:split_index]
part2 = data[split_index:]
return part1, part2