重构基本完成

This commit is contained in:
2025-02-21 11:35:53 +08:00
parent cdd05e95ba
commit 85d0a8fadc
16 changed files with 346 additions and 8 deletions

View File

@ -13,6 +13,7 @@ from app.common import reponse_code as rc
from typing import List from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
"""项目管理API""" """项目管理API"""
@ -150,3 +151,12 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
img_leafer_out = ps.get_img_leafer(image_id, session) img_leafer_out = ps.get_img_leafer(image_id, session)
return rc.response_success(data=img_leafer_out) return rc.response_success(data=img_leafer_out)
@project.get("/run_train/{project_id}")
def run_train(project_id: int, session: Session = Depends(get_db)):
project_info = pic.get_project_by_id(project_id, session)
if project_info is None:
return rc.response_error("项目查询错误")
if project_info.project_status == '1':
return rc.response_error("项目当前存在训练进程,请稍后再试")
return StreamingResponse(ps.run_train_yolo(project_info, session), media_type="text/plain")

View File

@ -0,0 +1,42 @@
import asyncio
import subprocess
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
test = APIRouter()
async def generate_data():
for i in range(1, 10): # 生成 5 行数据
await asyncio.sleep(1) # 等待 1 秒
yield f"data: This is line {i}\n\n" # 返回 SSE 格式的数据
def run_command(command):
"""执行命令并实时打印每一行输出"""
# 启动子进程
with subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
bufsize=1, # 行缓冲
) as process:
# 使用iter逐行读取stdout和stderr
for line in process.stdout:
yield f"stdout: {line.strip()} \n"
for line in process.stderr:
yield f"stderr: {line.strip()} \n"
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
print(f"Process exited with non-zero code: {return_code}")
@test.get("/stream")
async def stream_response():
return StreamingResponse(run_command(["ping", "-n", "10", "127.0.0.1"]), media_type="text/plain")

View File

@ -9,6 +9,7 @@ from app.api.sys.login_api import login
from app.api.sys.sys_user_api import user from app.api.sys.sys_user_api import user
from app.api.business.project_api import project from app.api.business.project_api import project
from app.api.common.view_img import view from app.api.common.view_img import view
from app.api.common.test_api import test
my_app = FastAPI() my_app = FastAPI()
@ -35,5 +36,5 @@ my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
my_app.include_router(view, prefix="/view_img", tags=["查看图片"]) my_app.include_router(view, prefix="/view_img", tags=["查看图片"])
my_app.include_router(user, prefix="/user", tags=["用户管理API"]) my_app.include_router(user, prefix="/user", tags=["用户管理API"])
my_app.include_router(project, prefix="/proj", tags=["项目管理API"]) my_app.include_router(project, prefix="/proj", tags=["项目管理API"])
my_app.include_router(test, prefix="/test", tags=["测试用API"])

View File

@ -33,7 +33,7 @@ class TokenMiddleware(BaseHTTPMiddleware):
return rc.response_code_view(status.HTTP_401_UNAUTHORIZED, "Token错误或失效请重新验证") return rc.response_code_view(status.HTTP_401_UNAUTHORIZED, "Token错误或失效请重新验证")
green = ['/login', '/view_img'] green = ['/login', '/view_img', 'test']
def check_green(s: str): def check_green(s: str):

View File

@ -13,6 +13,7 @@ dir = D:\syg\workspace\logs
[yolo] [yolo]
datasets_url = D:\syg\yolov5\datasets datasets_url = D:\syg\yolov5\datasets
runs_url = D:\syg\yolov5\runs runs_url = D:\syg\yolov5\runs
yolo_url = D:\syg\workspace\aicheckv2\yolov5
[images] [images]
image_url = D:\syg\images image_url = D:\syg\images

View File

@ -13,6 +13,7 @@ dir = /home/aicheckv2/logs
[yolo] [yolo]
datasets_url = /home/aicheckv2/yolov5/datasets datasets_url = /home/aicheckv2/yolov5/datasets
runs_url = /home/aicheckv2/yolov5/runs runs_url = /home/aicheckv2/yolov5/runs
yolo_url = /home/aicheckv2/backend/yolov5
[images] [images]
image_url = /home/aicheckv2/images image_url = /home/aicheckv2/images

View File

@ -24,5 +24,6 @@ log_dir = config.get('log', 'dir')
datasets_url = config.get('yolo', 'datasets_url') datasets_url = config.get('yolo', 'datasets_url')
runs_url = config.get('yolo', 'runs_url') runs_url = config.get('yolo', 'runs_url')
yolo_url = config.get('yolo', 'yolo_url')
images_url = config.get('images', 'image_url') images_url = config.get('images', 'image_url')

View File

@ -59,3 +59,12 @@ class ProjectImgLabel(DbCommon):
mark_center_y: Mapped[str] = mapped_column(String(64), nullable=False) mark_center_y: Mapped[str] = mapped_column(String(64), nullable=False)
mark_width: Mapped[str] = mapped_column(String(64), nullable=False) mark_width: Mapped[str] = mapped_column(String(64), nullable=False)
mark_height: Mapped[str] = mapped_column(String(64), nullable=False) mark_height: Mapped[str] = mapped_column(String(64), nullable=False)
class ProjectTrain(DbCommon):
"""项目训练版本信息表"""
__tablename__ = "project_train"
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
last_pt: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -20,6 +20,11 @@ def get_image_list(project_id: int, session: Session):
return image_list return image_list
def get_images(project_id: int, session: Session):
query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
return query.all()
def add_image(image: ProjectImage, session: Session): def add_image(image: ProjectImage, session: Session):
session.add(image) session.add(image)
session.commit() session.commit()

View File

@ -16,6 +16,17 @@ def get_img_leafer(image_id: int, session: Session):
return img_leafer return img_leafer
def get_img_label_list(image_id, session: Session):
"""
根据图片id获取图片标签信息
:param image_id:
:param session:
:return:
"""
img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).all()
return img_label_list
def save_img_leafer(leafer: ProjectImgLeafer, session: Session): def save_img_leafer(leafer: ProjectImgLeafer, session: Session):
leafer_saved = session.query(ProjectImgLeafer).filter_by(image_id=leafer.image_id).first() leafer_saved = session.query(ProjectImgLeafer).filter_by(image_id=leafer.image_id).first()
if leafer_saved is not None: if leafer_saved is not None:

View File

@ -1,5 +1,5 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import desc from sqlalchemy import desc, update
from app.model.bussiness_model import ProjectInfo from app.model.bussiness_model import ProjectInfo
from app.model.schemas.project_info_schemas import ProjectInfoOut from app.model.schemas.project_info_schemas import ProjectInfoOut
@ -41,3 +41,24 @@ def check_project_name(project_name: str, session: Session):
else: else:
return False return False
def update_project_status(project_id: int, project_status: str, session: Session):
"""
更新项目训练状态,如果是已完成的话train_version自动+1
:param project_id:
:param project_status: 0-未运行1-运行中2-已完成,-1-执行失败
:param session:
:return:
"""
if project_status == '2':
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
'train_status': project_status,
'train_version': ProjectInfo.train_version + 1
})
session.execute(stmt)
else:
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
'train_status': project_status
})
session.execute(stmt)
session.commit()

View File

@ -17,6 +17,16 @@ def get_label_list(project_id: int, session: Session):
return label_list return label_list
def get_label_for_train(project_id: int, session: Session):
id_list = []
name_list = []
label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
for label in label_list:
id_list.append(label.id)
name_list.append(label.label_name)
return id_list, name_list
def add_label(label: plModel, session: Session): def add_label(label: plModel, session: Session):
""" """
新增标签 新增标签

View File

@ -0,0 +1,28 @@
from sqlalchemy.orm import Session
from sqlalchemy import asc
from app.model.bussiness_model import ProjectTrain
from app.model.schemas.project_train_schemas import ProjectTrainOut
def add_train(train: ProjectTrain, session: Session):
"""
新增训练结果
:param train:
:param session:
:return:
"""
session.add(train)
session.commit()
def get_train_list(project_id: int, session: Session):
"""
根据项目id查询训练列表
:param project_id:
:param session:
:return:
"""
query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id))
train_list = [ProjectTrainOut.from_orm(train) for train in query.all()]
return train_list

View File

@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import Optional
class ProjectTrainOut(BaseModel):
"""项目训练版本信息表"""
id: Optional[int] = Field(None, description="训练id")
train_version: Optional[str] = Field(None, description="训练版本号")
best_pt: Optional[str] = Field(None, description="最好")
last_pt: Optional[str] = Field(None, description="最后")
class Config:
orm_mode = True

View File

@ -1,16 +1,20 @@
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.crud import project_info_crud as pic from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_train_crud as ptc
from app.model.crud import project_img_leafer_label_crud as pillc from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os from app.util import os_utils as os
from app.util import random_utils as ru from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from fastapi import UploadFile from fastapi import UploadFile
import yaml
import subprocess
def add_project(info: ProjectInfoIn, session: Session, user_id: int): def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@ -48,7 +52,7 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile],
path = os.save_images(images_url, project_info.project_no, file=file) path = os.save_images(images_url, project_info.project_no, file=file)
image.image_url = path image.image_url = path
# 生成缩略图 # 生成缩略图
thumb_image_url = images_url + "\\thumb\\" + project_info.project_no + "\\" + ru.random_str(10) + ".jpg" thumb_image_url = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg")
os.create_thumbnail(path, thumb_image_url) os.create_thumbnail(path, thumb_image_url)
image.thumb_image_url = thumb_image_url image.thumb_image_url = thumb_image_url
images.append(image) images.append(image)
@ -62,7 +66,9 @@ def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
:param session: :param session:
:return: :return:
""" """
img_leafer = ProjectImgLeafer(**img_leafer_label) img_leafer = ProjectImgLeafer()
img_leafer.image_id = img_leafer_label.image_id
img_leafer.leafer = img_leafer_label.leafer
pillc.save_img_leafer(img_leafer, session) pillc.save_img_leafer(img_leafer, session)
label_infos = img_leafer_label.label_infos label_infos = img_leafer_label.label_infos
img_labels = [] img_labels = []
@ -83,3 +89,146 @@ def get_img_leafer(image_id: int, session: Session):
img_leafer = pillc.get_img_leafer(image_id, session) img_leafer = pillc.get_img_leafer(image_id, session)
img_leafer_out = ProjectImgLeaferOut.from_orm(img_leafer).dict() img_leafer_out = ProjectImgLeaferOut.from_orm(img_leafer).dict()
return img_leafer_out return img_leafer_out
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
"""
yolov5执行训练任务
:param project_info: 项目信息
:param session: 数据库session
:return:
"""
# 先获取项目的所有图片
project_images = pimc.get_images(project_info.id, session)
# 将图片根据,根据3:1的比例将图片分成trainval的两个数组
project_images_train, project_images_val = split_array(project_images)
# 得到训练版本
version_path = 'v' + str(project_info.train_version + 1)
# 创建训练的根目录
train_path = os.create_folder(datasets_url, project_info.project_no, version_path)
# 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)
# 在根目录创建classes.txt文件
classes_txt = os.file_path(train_path, 'classes.txt')
with open(classes_txt, 'w', encoding='utf-8') as file:
for label_name in label_name_list:
file.write(label_name + '\n')
# 创建图片的的两个文件夹
img_path_train = os.create_folder(train_path, 'images', 'train')
img_path_val = os.create_folder(train_path, 'images', 'val')
# 创建标签的两个文件夹
label_path_train = os.create_folder(train_path, 'labels', 'train')
label_path_val = os.create_folder(train_path, 'labels', 'val')
# 在根目录下创建yaml文件
yaml_file = os.file_path(train_path, project_info.project_no + '.yaml')
yaml_data = {
'path': train_path,
'train': 'images/train',
'val': 'images/val',
'test': None,
'names': {i: name for i, name in enumerate(label_name_list)}
}
with open(yaml_file, 'w', encoding='utf-8') as file:
yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False)
# 开始循环复制图片和生成label.txt
# 先操作train
operate_img_label(project_images_train, img_path_train, label_path_train, session, label_id_list)
# 再操作val
operate_img_label(project_images_val, img_path_val, label_path_val, session, label_id_list)
# 打包完成开始训练,训练前,更改项目的训练状态
pic.update_project_status(project_info.id, '1')
# 开始训练
data = yaml_file
project = os.file_path(runs_url, project_info.project_no, 'train')
name = version_path
epochs = 10
yolo_path = os.file_path(yolo_url, 'train.py')
# 启动子进程
with subprocess.Popen(
["python", yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, "--epochs=" + str(epochs)],
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
bufsize=1, # 行缓冲
) as process:
# 使用iter逐行读取stdout和stderr
for line in process.stdout:
yield f"stdout: {line.strip()} \n"
for line in process.stderr:
yield f"stderr: {line.strip()} \n"
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
pic.update_project_status(project_info.id, '-1', session)
else:
pic.update_project_status(project_info.id, '2', session)
# 然后保存版本训练信息
train = ProjectTrain()
train.project_id = project_info.id
train.train_version = version_path
bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt')
last_pt_path = os.file_path(project, name, 'weight', 'last.pt')
train.best_pt = bast_pt_path
train.last_pt = last_pt_path
ptc.add_train(train, session)
def operate_img_label(img_list: List[ProjectImgLabel],
img_path: str, label_path: str,
session: Session, label_id_list: []):
"""
生成图片和标签内容
:param label_id_list:
:param session:
:param img_list:
:param img_path:
:param label_path:
:return:
"""
for i in range(len(img_list)):
image = img_list[i]
# 先复制图片,并把图片改名,不改后缀
file_name = 'image' + str(i)
os.copy_and_rename_file(image.image_url, img_path, file_name)
# 查询这张图片的label信息然后生成这张照片的txt文件
img_label_list = pillc.get_img_label_list(image.id, session)
label_txt_path = os.file_path(label_path, file_name + '.txt')
with open(label_txt_path, 'w', encoding='utf-8') as file:
for image_label in img_label_list:
index = label_id_list.index(image_label.label_id)
file.write(str(index) + ' ' + image_label.mark_center_x + ' '
+ image_label.mark_center_y + ' '
+ image_label.mark_width + ' '
+ image_label.mark_height + '\n')
def split_array(data):
"""
将数组按照3:1的比例切分成两个数组。
:param data: 原始数组
:return: 按照3:1比例分割后的两个数组
"""
total_length = len(data)
if total_length < 4:
raise ValueError("数组长度至少需要为4才能进行3:1的分割")
# 计算分割点
split_index = (total_length * 3) // 4
# 使用切片分割数组
part1 = data[:split_index]
part2 = data[split_index:]
return part1, part2

View File

@ -1,6 +1,18 @@
import os import os
import shutil
from fastapi import UploadFile from fastapi import UploadFile
from PIL import Image from PIL import Image
"""文件处理相关的util"""
def file_path(*path):
"""
拼接返回文件路径
:param path:
:return:
"""
return_path = os.path.join(*path)
return return_path
def create_folder(*path): def create_folder(*path):
@ -10,6 +22,7 @@ def create_folder(*path):
os.makedirs(folder_path, exist_ok=True) os.makedirs(folder_path, exist_ok=True)
except Exception as e: except Exception as e:
print(f"创建文件夹时错误: {e}") print(f"创建文件夹时错误: {e}")
return folder_path
def save_images(*path, file: UploadFile): def save_images(*path, file: UploadFile):
@ -43,3 +56,26 @@ def create_thumbnail(input_image_path, out_image_path, size=(116, 70)):
os.makedirs(os.path.dirname(out_image_path), exist_ok=True) os.makedirs(os.path.dirname(out_image_path), exist_ok=True)
# 保存生成的缩略图 # 保存生成的缩略图
image.save(out_image_path, 'JPEG') image.save(out_image_path, 'JPEG')
def copy_and_rename_file(src_file_path, dst_dir, new_name):
"""
复制文件到指定目录并重命名,保持文件的后缀不变。
:param src_file_path: 源文件路径
:param dst_dir: 目标目录
:param new_name: 新文件名(不含扩展名)
"""
# 获取文件的完整文件名(包括扩展名)
base_name = os.path.basename(src_file_path)
# 分离文件名和扩展名
file_name, file_extension = os.path.splitext(base_name)
# 构建新的文件名和目标路径
new_file_name = f"{new_name}{file_extension}"
dst_file_path = os.path.join(dst_dir, new_file_name)
# 确保目标目录存在
os.makedirs(dst_dir, exist_ok=True)
# 复制文件到目标位置并重命名
shutil.copy(src_file_path, dst_file_path)