import select  # NOTE(review): not used in this module — candidate for removal
import subprocess
from typing import List

import yaml
from fastapi import UploadFile
from sqlalchemy.orm import Session

from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_train_crud as ptc
from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os  # project path/file helpers; intentionally shadows the stdlib name
from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url


def add_project(info: ProjectInfoIn, session: Session, user_id: int):
    """
    Create a new project: fill in server-side fields, create its folders and persist it.

    :param info: project data submitted by the client
    :param session: database session
    :param user_id: id of the user who owns the project
    :return: the newly created ``ProjectInfo`` ORM object
    """
    project_info = ProjectInfo(**info.dict())
    project_info.user_id = user_id
    # Random 6-character identifier; also used as the project's folder name.
    project_info.project_no = ru.random_str(6)
    # Initial status. This module later sets '1' when training starts,
    # '2' when training succeeds and '-1' when it fails.
    project_info.project_status = "0"
    project_info.train_version = 0
    os.create_folder(datasets_url, project_info.project_no)
    os.create_folder(runs_url, project_info.project_no)
    pic.add_project(project_info, session)
    return project_info


def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session):
    """
    Save uploaded images for a project, create a thumbnail for each, and store all records in one batch.

    :param project_info: project the images belong to
    :param files: uploaded image files
    :param session: database session
    :return: None
    """
    images = []
    for file in files:
        image = ProjectImage()
        image.project_id = project_info.id
        # Save the original image.
        path = os.save_images(images_url, project_info.project_no, file=file)
        image.image_url = path
        # Create a thumbnail; its path is built from images_url, 'thumb', the project
        # number and a random 10-character file name.
        thumb_image_url = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg")
        os.create_thumbnail(path, thumb_image_url)
        image.thumb_image_url = thumb_image_url
        images.append(image)
    pimc.add_image_batch(images, session)


def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
    """
    Save an image's leafer (canvas) data and its labelled boxes.

    By design every save replaces all of the image's existing label data instead of merging
    with it; the delete-then-insert is performed inside the ``pillc`` helpers (not visible here).

    :param img_leafer_label: image id, leafer data and the list of label boxes
    :param session: database session
    :return: None
    """
    img_leafer = ProjectImgLeafer()
    img_leafer.image_id = img_leafer_label.image_id
    img_leafer.leafer = img_leafer_label.leafer
    pillc.save_img_leafer(img_leafer, session)

    img_labels = []
    for label_info in img_leafer_label.label_infos:
        img_label = ProjectImgLabel(**label_info.dict())
        img_label.image_id = img_leafer_label.image_id
        img_labels.append(img_label)
    pillc.save_img_label_batch(img_leafer_label.image_id, img_labels, session)


def get_img_leafer(image_id: int, session: Session):
    """
    Look up an image's leafer data by image id.

    :param image_id: id of the image
    :param session: database session
    :return: the leafer record serialized as a dict (``ProjectImgLeaferOut``)
    """
    img_leafer = pillc.get_img_leafer(image_id, session)
    # NOTE(review): if the lookup can return None (image never labelled), from_orm
    # will fail validation here — confirm the crud helper's contract.
    return ProjectImgLeaferOut.from_orm(img_leafer).dict()


def run_train_yolo(project_info: ProjectInfoOut, session: Session):
    """
    Build a YOLOv5 dataset for the project and mark the project as training.

    Despite the name, this function does not start training. It prepares the dataset
    (images, label txt files and the dataset yaml), sets the project status to '1', and
    returns the arguments ``run_commend`` needs to launch ``train.py``.

    :param project_info: project information
    :param session: database session
    :return: tuple ``(data, project, name)``: path of the dataset yaml file, the output
             directory for this project's training runs, and the version name (``'v<n>'``)
    :raises ValueError: from ``split_array`` when the project has fewer than 4 images
    """
    # All images of the project, split 3:1 into train and val sets.
    project_images = pimc.get_images(project_info.id, session)
    project_images_train, project_images_val = split_array(project_images)
    # Version folder name for this run.
    version_path = 'v' + str(project_info.train_version + 1)
    # Root folder of this version's dataset.
    train_path = os.create_folder(datasets_url, project_info.project_no, version_path)
    # Project labels as two parallel lists (ids and names in the same order). A label's
    # position in these lists is its YOLO class index.
    label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)
    # images/{train,val} and labels/{train,val} folders.
    img_path_train = os.create_folder(train_path, 'images', 'train')
    img_path_val = os.create_folder(train_path, 'images', 'val')
    label_path_train = os.create_folder(train_path, 'labels', 'train')
    label_path_val = os.create_folder(train_path, 'labels', 'val')

    # Dataset yaml in the version root.
    yaml_file = os.file_path(train_path, project_info.project_no + '.yaml')
    yaml_data = {
        'path': train_path,
        'train': 'images/train',
        'val': 'images/val',
        'test': None,
        'names': dict(enumerate(label_name_list)),
    }
    with open(yaml_file, 'w', encoding='utf-8') as file:
        yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False)

    # Copy images and write label txt files: train split first, then val.
    operate_img_label(project_images_train, img_path_train, label_path_train, session, label_id_list)
    operate_img_label(project_images_val, img_path_val, label_path_val, session, label_id_list)

    # Dataset is ready: mark the project as training before the run is launched.
    pic.update_project_status(project_info.id, '1', session)

    data = yaml_file
    project = os.file_path(runs_url, project_info.project_no, 'train')
    name = version_path
    return data, project, name


def run_commend(data: str, project: str, name: str, epochs: int, project_id: int, session: Session):
    """
    Run YOLOv5 ``train.py`` in a subprocess and stream its output line by line.

    This is a generator. It first yields a start message, then every non-blank line that
    the training process writes to stdout/stderr. When the process exits, the project
    status is set to '-1' (non-zero exit code) or '2' (success). On success a
    ``ProjectTrain`` record pointing at the produced weight files is also saved.

    :param data: path of the dataset yaml file (``--data``)
    :param project: output directory for training runs (``--project``)
    :param name: run/version name (``--name``); also stored as the train version
    :param epochs: number of training epochs (``--epochs``)
    :param project_id: id of the project being trained
    :param session: database session
    :return: generator of output lines (str)
    """
    yolo_path = os.file_path(yolo_url, 'train.py')
    yield "stdout: 模型训练开始,请稍等。。。"
    with subprocess.Popen(
            ["python", '-u', yolo_path,
             "--data=" + data,
             "--project=" + project,
             "--name=" + name,
             "--epochs=" + str(epochs)],
            bufsize=1,  # line-buffered (valid in text mode)
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so progress bars etc. are streamed as well
            text=True,  # decode output as text to avoid encoding issues downstream
            encoding='utf-8',
    ) as process:
        # Read the pipe until EOF rather than polling the process: a
        # `while process.poll() is None: readline()` loop drops whatever output is
        # still buffered when the child exits, and yields '' once EOF is reached.
        for line in process.stdout:
            if line != '\n':
                yield line
        return_code = process.wait()

    if return_code != 0:
        pic.update_project_status(project_id, '-1', session)
        return

    pic.update_project_status(project_id, '2', session)
    # Record the trained version. YOLOv5's train.py saves its checkpoints as
    # <project>/<name>/weights/best.pt and <project>/<name>/weights/last.pt.
    # NOTE(review): train.py auto-increments <name> when that directory already exists
    # (unless --exist-ok is passed); in that case these recorded paths would not match.
    train = ProjectTrain()
    train.project_id = project_id
    train.train_version = name
    train.best_pt = os.file_path(project, name, 'weights', 'best.pt')
    train.last_pt = os.file_path(project, name, 'weights', 'last.pt')
    ptc.add_train(train, session)


def operate_img_label(img_list: List[ProjectImage], img_path: str, label_path: str, session: Session,
                      label_id_list: list):
    """
    Copy images into a dataset folder and write one YOLO label txt file per image.

    Image ``i`` is copied into ``img_path`` renamed to ``image<i>`` (keeping its extension),
    and its labels are written to ``<label_path>/image<i>.txt``. Each line has the form
    ``<class index> <center x> <center y> <width> <height>``, where the class index is the
    label id's position in ``label_id_list``.

    :param img_list: images to export
    :param img_path: destination folder for the image files
    :param label_path: destination folder for the label txt files
    :param session: database session
    :param label_id_list: label ids ordered by class index
    :return: None
    :raises ValueError: if an image label's ``label_id`` is not in ``label_id_list``
    """
    for i, image in enumerate(img_list):
        file_name = 'image' + str(i)
        os.copy_and_rename_file(image.image_url, img_path, file_name)
        # Look up this image's labels and write them to its txt file.
        img_label_list = pillc.get_img_label_list(image.id, session)
        label_txt_path = os.file_path(label_path, file_name + '.txt')
        with open(label_txt_path, 'w', encoding='utf-8') as file:
            for image_label in img_label_list:
                index = label_id_list.index(image_label.label_id)
                # The f-string gives the same text as the old `+` concatenation for string
                # columns, and also works if the mark_* values are numeric (`+` raised TypeError).
                file.write(f"{index} {image_label.mark_center_x} {image_label.mark_center_y} "
                           f"{image_label.mark_width} {image_label.mark_height}\n")


def split_array(data):
    """
    Split a sequence into two parts at a 3:1 ratio, preserving order (no shuffling).

    :param data: original sequence (must support ``len()`` and slicing)
    :return: tuple ``(part1, part2)``; ``part1`` holds the first ``len(data) * 3 // 4`` items
             and ``part2`` the rest
    :raises ValueError: if ``data`` has fewer than 4 items
    """
    total_length = len(data)
    if total_length < 4:
        raise ValueError("数组长度至少需要为4才能进行3:1的分割")
    split_index = (total_length * 3) // 4
    return data[:split_index], data[split_index:]