From 85d0a8fadcea472e68c4780ffea7c40bd76f9530 Mon Sep 17 00:00:00 2001 From: sunyugang Date: Fri, 21 Feb 2025 11:35:53 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E5=9F=BA=E6=9C=AC=E5=AE=8C?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/business/project_api.py | 10 ++ app/api/common/test_api.py | 42 +++++ app/application/app.py | 3 +- app/application/token_middleware.py | 2 +- app/config/application_config_dev.ini | 1 + app/config/application_config_prod.ini | 1 + app/config/config_reader.py | 3 +- app/model/bussiness_model.py | 9 + app/model/crud/project_image_crud.py | 5 + .../crud/project_img_leafer_label_crud.py | 11 ++ app/model/crud/project_info_crud.py | 23 ++- app/model/crud/project_label_crud.py | 10 ++ app/model/crud/project_train_crud.py | 28 ++++ app/model/schemas/project_train_schemas.py | 13 ++ app/service/project_service.py | 157 +++++++++++++++++- app/util/os_utils.py | 36 ++++ 16 files changed, 346 insertions(+), 8 deletions(-) create mode 100644 app/api/common/test_api.py create mode 100644 app/model/crud/project_train_crud.py create mode 100644 app/model/schemas/project_train_schemas.py diff --git a/app/api/business/project_api.py b/app/api/business/project_api.py index e2f1af2..1238876 100644 --- a/app/api/business/project_api.py +++ b/app/api/business/project_api.py @@ -13,6 +13,7 @@ from app.common import reponse_code as rc from typing import List from fastapi import APIRouter, Depends, Request, UploadFile, File, Form +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session """项目管理API""" @@ -150,3 +151,12 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)): img_leafer_out = ps.get_img_leafer(image_id, session) return rc.response_success(data=img_leafer_out) + +@project.get("/run_train/{project_id}") +def run_train(project_id: int, session: Session = Depends(get_db)): + project_info = pic.get_project_by_id(project_id, session) + if project_info is None: + return rc.response_error("项目查询错误") + if project_info.project_status == '1': + return rc.response_error("项目当前存在训练进程,请稍后再试") + return StreamingResponse(ps.run_train_yolo(project_info, session), media_type="text/plain") diff --git a/app/api/common/test_api.py b/app/api/common/test_api.py new file mode 100644 index 0000000..ccebfc3 --- /dev/null +++ b/app/api/common/test_api.py @@ -0,0 +1,42 @@ +import asyncio +import subprocess +from fastapi import APIRouter +from fastapi.responses import StreamingResponse + + +test = APIRouter() + + +async def generate_data(): + for i in range(1, 10): # 生成 5 行数据 + await asyncio.sleep(1) # 等待 1 秒 + yield f"data: This is line {i}\n\n" # 返回 SSE 格式的数据 + + +def run_command(command): + """执行命令并实时打印每一行输出""" + # 启动子进程 + with subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, # 确保输出以字符串形式返回而不是字节 + bufsize=1, # 行缓冲 + ) as process: + # 使用iter逐行读取stdout和stderr + for line in process.stdout: + yield f"stdout: {line.strip()} \n" + + for line in process.stderr: + yield f"stderr: {line.strip()} \n" + + # 等待进程结束并获取返回码 + return_code = process.wait() + if return_code != 0: + print(f"Process exited with non-zero code: {return_code}") + + +@test.get("/stream") +async def stream_response(): + return StreamingResponse(run_command(["ping", "-n", "10", "127.0.0.1"]), media_type="text/plain") diff --git a/app/application/app.py b/app/application/app.py index 2922e5c..86f3910 100644 --- a/app/application/app.py +++ b/app/application/app.py @@ -9,6 +9,7 @@ from app.api.sys.login_api import login from app.api.sys.sys_user_api import user from app.api.business.project_api import project from app.api.common.view_img import view +from app.api.common.test_api import test my_app = FastAPI() @@ -35,5 +36,5 @@ my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"]) my_app.include_router(view, prefix="/view_img", tags=["查看图片"]) my_app.include_router(user, prefix="/user", tags=["用户管理API"]) my_app.include_router(project, prefix="/proj", tags=["项目管理API"]) - +my_app.include_router(test, prefix="/test", tags=["测试用API"]) diff --git a/app/application/token_middleware.py b/app/application/token_middleware.py index 92f2cf9..5d92d71 100644 --- a/app/application/token_middleware.py +++ b/app/application/token_middleware.py @@ -33,7 +33,7 @@ class TokenMiddleware(BaseHTTPMiddleware): return rc.response_code_view(status.HTTP_401_UNAUTHORIZED, "Token错误或失效,请重新验证") -green = ['/login', '/view_img'] +green = ['/login', '/view_img', 'test'] def check_green(s: str): diff --git a/app/config/application_config_dev.ini b/app/config/application_config_dev.ini index 2104705..3857c15 100644 --- a/app/config/application_config_dev.ini +++ b/app/config/application_config_dev.ini @@ -13,6 +13,7 @@ dir = D:\syg\workspace\logs [yolo] datasets_url = D:\syg\yolov5\datasets runs_url = D:\syg\yolov5\runs +yolo_url = D:\syg\workspace\aicheckv2\yolov5 [images] image_url = D:\syg\images \ No newline at end of file diff --git a/app/config/application_config_prod.ini b/app/config/application_config_prod.ini index 3de556d..1e876f5 100644 --- a/app/config/application_config_prod.ini +++ b/app/config/application_config_prod.ini @@ -13,6 +13,7 @@ dir = /home/aicheckv2/logs [yolo] datasets_url = /home/aicheckv2/yolov5/datasets runs_url = /home/aicheckv2/yolov5/runs +yolo_url = /home/aicheckv2/backend/yolov5 [images] image_url = /home/aicheckv2/images \ No newline at end of file diff --git a/app/config/config_reader.py b/app/config/config_reader.py index ea6c15c..5f14c8a 100644 --- a/app/config/config_reader.py +++ b/app/config/config_reader.py @@ -24,5 +24,6 @@ log_dir = config.get('log', 'dir') datasets_url = config.get('yolo', 'datasets_url') runs_url = config.get('yolo', 'runs_url') +yolo_url = config.get('yolo', 'yolo_url') -images_url = config.get('images', 'image_url') \ No newline at end of file +images_url = config.get('images', 'image_url') diff --git a/app/model/bussiness_model.py b/app/model/bussiness_model.py index ded46a4..3ceb027 100644 --- a/app/model/bussiness_model.py +++ b/app/model/bussiness_model.py @@ -59,3 +59,12 @@ class ProjectImgLabel(DbCommon): mark_center_y: Mapped[str] = mapped_column(String(64), nullable=False) mark_width: Mapped[str] = mapped_column(String(64), nullable=False) mark_height: Mapped[str] = mapped_column(String(64), nullable=False) + + +class ProjectTrain(DbCommon): + """项目训练版本信息表""" + __tablename__ = "project_train" + project_id: Mapped[int] = mapped_column(Integer, nullable=False) + train_version: Mapped[str] = mapped_column(String(32), nullable=False) + best_pt: Mapped[str] = mapped_column(String(255), nullable=False) + last_pt: Mapped[str] = mapped_column(String(255), nullable=False) diff --git a/app/model/crud/project_image_crud.py b/app/model/crud/project_image_crud.py index 5b30f0e..8f51ce0 100644 --- a/app/model/crud/project_image_crud.py +++ b/app/model/crud/project_image_crud.py @@ -20,6 +20,11 @@ def get_image_list(project_id: int, session: Session): return image_list +def get_images(project_id: int, session: Session): + query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id)) + return query.all() + + def add_image(image: ProjectImage, session: Session): session.add(image) session.commit() diff --git a/app/model/crud/project_img_leafer_label_crud.py b/app/model/crud/project_img_leafer_label_crud.py index 66b9efb..bd1ce28 100644 --- a/app/model/crud/project_img_leafer_label_crud.py +++ b/app/model/crud/project_img_leafer_label_crud.py @@ -16,6 +16,17 @@ def get_img_leafer(image_id: int, session: Session): return img_leafer +def get_img_label_list(image_id, session: Session): + """ + 根据图片id获取图片标签信息 + :param image_id: + :param session: + :return: + """ + img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).all() + return img_label_list + + def save_img_leafer(leafer: ProjectImgLeafer, session: Session): leafer_saved = session.query(ProjectImgLeafer).filter_by(image_id=leafer.image_id).first() if leafer_saved is not None: diff --git a/app/model/crud/project_info_crud.py b/app/model/crud/project_info_crud.py index c18a9f3..0f5e72e 100644 --- a/app/model/crud/project_info_crud.py +++ b/app/model/crud/project_info_crud.py @@ -1,5 +1,5 @@ from sqlalchemy.orm import Session -from sqlalchemy import desc +from sqlalchemy import desc, update from app.model.bussiness_model import ProjectInfo from app.model.schemas.project_info_schemas import ProjectInfoOut @@ -41,3 +41,24 @@ def check_project_name(project_name: str, session: Session): else: return False + +def update_project_status(project_id: int, project_status: str, session: Session): + """ + 更新项目训练状态,如果是已完成的话,train_version自动+1 + :param project_id: + :param project_status: 0-未运行,1-运行中,2-已完成,-1-执行失败 + :param session: + :return: + """ + if project_status == '2': + stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({ + 'train_status': project_status, + 'train_version': ProjectInfo.train_version + 1 + }) + session.execute(stmt) + else: + stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({ + 'train_status': project_status + }) + session.execute(stmt) + session.commit() diff --git a/app/model/crud/project_label_crud.py b/app/model/crud/project_label_crud.py index 08bcad8..45e4970 100644 --- a/app/model/crud/project_label_crud.py +++ b/app/model/crud/project_label_crud.py @@ -17,6 +17,16 @@ def get_label_list(project_id: int, session: Session): return label_list +def get_label_for_train(project_id: int, session: Session): + id_list = [] + name_list = [] + label_list = session.query(plModel).filter(plModel.project_id == project_id).all() + for label in label_list: + id_list.append(label.id) + name_list.append(label.label_name) + return id_list, name_list + + def add_label(label: plModel, session: Session): """ 新增标签 diff --git a/app/model/crud/project_train_crud.py b/app/model/crud/project_train_crud.py new file mode 100644 index 0000000..41ee35b --- /dev/null +++ b/app/model/crud/project_train_crud.py @@ -0,0 +1,28 @@ +from sqlalchemy.orm import Session +from sqlalchemy import asc + +from app.model.bussiness_model import ProjectTrain +from app.model.schemas.project_train_schemas import ProjectTrainOut + + +def add_train(train: ProjectTrain, session: Session): + """ + 新增训练结果 + :param train: + :param session: + :return: + """ + session.add(train) + session.commit() + + +def get_train_list(project_id: int, session: Session): + """ + 根据项目id查询训练列表 + :param project_id: + :param session: + :return: + """ + query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id)) + train_list = [ProjectTrainOut.from_orm(train) for train in query.all()] + return train_list diff --git a/app/model/schemas/project_train_schemas.py b/app/model/schemas/project_train_schemas.py new file mode 100644 index 0000000..1ca73b6 --- /dev/null +++ b/app/model/schemas/project_train_schemas.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, Field +from typing import Optional + + +class ProjectTrainOut(BaseModel): + """项目训练版本信息表""" + id: Optional[int] = Field(None, description="训练id") + train_version: Optional[str] = Field(None, description="训练版本号") + best_pt: Optional[str] = Field(None, description="最好") + last_pt: Optional[str] = Field(None, description="最后") + + class Config: + orm_mode = True diff --git a/app/service/project_service.py b/app/service/project_service.py index 44f6478..41b1f4a 100644 --- a/app/service/project_service.py +++ b/app/service/project_service.py @@ -1,16 +1,20 @@ -from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel +from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut from app.model.crud import project_info_crud as pic from app.model.crud import project_image_crud as pimc +from app.model.crud import project_label_crud as plc +from app.model.crud import project_train_crud as ptc from app.model.crud import project_img_leafer_label_crud as pillc from app.util import os_utils as os from app.util import random_utils as ru -from app.config.config_reader import datasets_url, runs_url, images_url +from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url from sqlalchemy.orm import Session from typing import List from fastapi import UploadFile +import yaml +import subprocess def add_project(info: ProjectInfoIn, session: Session, user_id: int): @@ -48,7 +52,7 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], path = os.save_images(images_url, project_info.project_no, file=file) image.image_url = path # 生成缩略图 - thumb_image_url = images_url + "\\thumb\\" + project_info.project_no + "\\" + ru.random_str(10) + ".jpg" + thumb_image_url = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg") os.create_thumbnail(path, thumb_image_url) image.thumb_image_url = thumb_image_url images.append(image) @@ -62,7 +66,9 @@ def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session): :param session: :return: """ - img_leafer = ProjectImgLeafer(**img_leafer_label) + img_leafer = ProjectImgLeafer() + img_leafer.image_id = img_leafer_label.image_id + img_leafer.leafer = img_leafer_label.leafer pillc.save_img_leafer(img_leafer, session) label_infos = img_leafer_label.label_infos img_labels = [] @@ -83,3 +89,146 @@ def get_img_leafer(image_id: int, session: Session): img_leafer = pillc.get_img_leafer(image_id, session) img_leafer_out = ProjectImgLeaferOut.from_orm(img_leafer).dict() return img_leafer_out + + +def run_train_yolo(project_info: ProjectInfoOut, session: Session): + """ + yolov5执行训练任务 + :param project_info: 项目信息 + :param session: 数据库session + :return: + """ + # 先获取项目的所有图片 + project_images = pimc.get_images(project_info.id, session) + + # 将图片根据,根据3:1的比例将图片分成train:val的两个数组 + project_images_train, project_images_val = split_array(project_images) + + # 得到训练版本 + version_path = 'v' + str(project_info.train_version + 1) + + # 创建训练的根目录 + train_path = os.create_folder(datasets_url, project_info.project_no, version_path) + + # 查询项目所属标签,返回两个 id,name一一对应的数组 + label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session) + + # 在根目录创建classes.txt文件 + classes_txt = os.file_path(train_path, 'classes.txt') + with open(classes_txt, 'w', encoding='utf-8') as file: + for label_name in label_name_list: + file.write(label_name + '\n') + + # 创建图片的的两个文件夹 + img_path_train = os.create_folder(train_path, 'images', 'train') + img_path_val = os.create_folder(train_path, 'images', 'val') + + # 创建标签的两个文件夹 + label_path_train = os.create_folder(train_path, 'labels', 'train') + label_path_val = os.create_folder(train_path, 'labels', 'val') + + # 在根目录下创建yaml文件 + yaml_file = os.file_path(train_path, project_info.project_no + '.yaml') + yaml_data = { + 'path': train_path, + 'train': 'images/train', + 'val': 'images/val', + 'test': None, + 'names': {i: name for i, name in enumerate(label_name_list)} + } + with open(yaml_file, 'w', encoding='utf-8') as file: + yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False) + + # 开始循环复制图片和生成label.txt + # 先操作train + operate_img_label(project_images_train, img_path_train, label_path_train, session, label_id_list) + # 再操作val + operate_img_label(project_images_val, img_path_val, label_path_val, session, label_id_list) + + # 打包完成开始训练,训练前,更改项目的训练状态 + pic.update_project_status(project_info.id, '1') + # 开始训练 + data = yaml_file + project = os.file_path(runs_url, project_info.project_no, 'train') + name = version_path + epochs = 10 + yolo_path = os.file_path(yolo_url, 'train.py') + + # 启动子进程 + with subprocess.Popen( + ["python", yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, "--epochs=" + str(epochs)], + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, # 确保输出以字符串形式返回而不是字节 + bufsize=1, # 行缓冲 + ) as process: + # 使用iter逐行读取stdout和stderr + for line in process.stdout: + yield f"stdout: {line.strip()} \n" + + for line in process.stderr: + yield f"stderr: {line.strip()} \n" + + # 等待进程结束并获取返回码 + return_code = process.wait() + if return_code != 0: + pic.update_project_status(project_info.id, '-1', session) + else: + pic.update_project_status(project_info.id, '2', session) + # 然后保存版本训练信息 + train = ProjectTrain() + train.project_id = project_info.id + train.train_version = version_path + bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt') + last_pt_path = os.file_path(project, name, 'weight', 'last.pt') + train.best_pt = bast_pt_path + train.last_pt = last_pt_path + ptc.add_train(train, session) + + +def operate_img_label(img_list: List[ProjectImgLabel], + img_path: str, label_path: str, + session: Session, label_id_list: []): + """ + 生成图片和标签内容 + :param label_id_list: + :param session: + :param img_list: + :param img_path: + :param label_path: + :return: + """ + for i in range(len(img_list)): + image = img_list[i] + # 先复制图片,并把图片改名,不改后缀 + file_name = 'image' + str(i) + os.copy_and_rename_file(image.image_url, img_path, file_name) + # 查询这张图片的label信息然后生成这张照片的txt文件 + img_label_list = pillc.get_img_label_list(image.id, session) + label_txt_path = os.file_path(label_path, file_name + '.txt') + with open(label_txt_path, 'w', encoding='utf-8') as file: + for image_label in img_label_list: + index = label_id_list.index(image_label.label_id) + file.write(str(index) + ' ' + image_label.mark_center_x + ' ' + + image_label.mark_center_y + ' ' + + image_label.mark_width + ' ' + + image_label.mark_height + '\n') + + +def split_array(data): + """ + 将数组按照3:1的比例切分成两个数组。 + :param data: 原始数组 + :return: 按照3:1比例分割后的两个数组 + """ + total_length = len(data) + if total_length < 4: + raise ValueError("数组长度至少需要为4才能进行3:1的分割") + # 计算分割点 + split_index = (total_length * 3) // 4 + # 使用切片分割数组 + part1 = data[:split_index] + part2 = data[split_index:] + return part1, part2 + diff --git a/app/util/os_utils.py b/app/util/os_utils.py index 3d5956c..0788c33 100644 --- a/app/util/os_utils.py +++ b/app/util/os_utils.py @@ -1,6 +1,18 @@ import os +import shutil from fastapi import UploadFile from PIL import Image +"""文件处理相关的util""" + + +def file_path(*path): + """ + 拼接返回文件路径 + :param path: + :return: + """ + return_path = os.path.join(*path) + return return_path def create_folder(*path): @@ -10,6 +22,7 @@ def create_folder(*path): os.makedirs(folder_path, exist_ok=True) except Exception as e: print(f"创建文件夹时错误: {e}") + return folder_path def save_images(*path, file: UploadFile): @@ -43,3 +56,26 @@ def create_thumbnail(input_image_path, out_image_path, size=(116, 70)): os.makedirs(os.path.dirname(out_image_path), exist_ok=True) # 保存生成的缩略图 image.save(out_image_path, 'JPEG') + + +def copy_and_rename_file(src_file_path, dst_dir, new_name): + """ + 复制文件到指定目录并重命名,保持文件的后缀不变。 + :param src_file_path: 源文件路径 + :param dst_dir: 目标目录 + :param new_name: 新文件名(不含扩展名) + """ + # 获取文件的完整文件名(包括扩展名) + base_name = os.path.basename(src_file_path) + # 分离文件名和扩展名 + file_name, file_extension = os.path.splitext(base_name) + + # 构建新的文件名和目标路径 + new_file_name = f"{new_name}{file_extension}" + dst_file_path = os.path.join(dst_dir, new_file_name) + + # 确保目标目录存在 + os.makedirs(dst_dir, exist_ok=True) + + # 复制文件到目标位置并重命名 + shutil.copy(src_file_path, dst_file_path)