重构基本完成

This commit is contained in:
2025-02-21 11:35:53 +08:00
parent cdd05e95ba
commit 85d0a8fadc
16 changed files with 346 additions and 8 deletions

View File

@ -13,6 +13,7 @@ from app.common import reponse_code as rc
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
"""项目管理API"""
@ -150,3 +151,12 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
img_leafer_out = ps.get_img_leafer(image_id, session)
return rc.response_success(data=img_leafer_out)
@project.get("/run_train/{project_id}")
def run_train(project_id: int, session: Session = Depends(get_db)):
project_info = pic.get_project_by_id(project_id, session)
if project_info is None:
return rc.response_error("项目查询错误")
if project_info.project_status == '1':
return rc.response_error("项目当前存在训练进程,请稍后再试")
return StreamingResponse(ps.run_train_yolo(project_info, session), media_type="text/plain")

View File

@ -0,0 +1,42 @@
import asyncio
import subprocess
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
test = APIRouter()
async def generate_data():
for i in range(1, 10): # 生成 5 行数据
await asyncio.sleep(1) # 等待 1 秒
yield f"data: This is line {i}\n\n" # 返回 SSE 格式的数据
def run_command(command):
"""执行命令并实时打印每一行输出"""
# 启动子进程
with subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
bufsize=1, # 行缓冲
) as process:
# 使用iter逐行读取stdout和stderr
for line in process.stdout:
yield f"stdout: {line.strip()} \n"
for line in process.stderr:
yield f"stderr: {line.strip()} \n"
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
print(f"Process exited with non-zero code: {return_code}")
@test.get("/stream")
async def stream_response():
return StreamingResponse(run_command(["ping", "-n", "10", "127.0.0.1"]), media_type="text/plain")

View File

@ -9,6 +9,7 @@ from app.api.sys.login_api import login
from app.api.sys.sys_user_api import user
from app.api.business.project_api import project
from app.api.common.view_img import view
from app.api.common.test_api import test
my_app = FastAPI()
@ -35,5 +36,5 @@ my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
my_app.include_router(view, prefix="/view_img", tags=["查看图片"])
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
my_app.include_router(project, prefix="/proj", tags=["项目管理API"])
my_app.include_router(test, prefix="/test", tags=["测试用API"])

View File

@ -33,7 +33,7 @@ class TokenMiddleware(BaseHTTPMiddleware):
return rc.response_code_view(status.HTTP_401_UNAUTHORIZED, "Token错误或失效请重新验证")
green = ['/login', '/view_img']
green = ['/login', '/view_img', 'test']
def check_green(s: str):

View File

@ -13,6 +13,7 @@ dir = D:\syg\workspace\logs
[yolo]
datasets_url = D:\syg\yolov5\datasets
runs_url = D:\syg\yolov5\runs
yolo_url = D:\syg\workspace\aicheckv2\yolov5
[images]
image_url = D:\syg\images

View File

@ -13,6 +13,7 @@ dir = /home/aicheckv2/logs
[yolo]
datasets_url = /home/aicheckv2/yolov5/datasets
runs_url = /home/aicheckv2/yolov5/runs
yolo_url = /home/aicheckv2/backend/yolov5
[images]
image_url = /home/aicheckv2/images

View File

@ -24,5 +24,6 @@ log_dir = config.get('log', 'dir')
datasets_url = config.get('yolo', 'datasets_url')
runs_url = config.get('yolo', 'runs_url')
yolo_url = config.get('yolo', 'yolo_url')
images_url = config.get('images', 'image_url')
images_url = config.get('images', 'image_url')

View File

@ -59,3 +59,12 @@ class ProjectImgLabel(DbCommon):
mark_center_y: Mapped[str] = mapped_column(String(64), nullable=False)
mark_width: Mapped[str] = mapped_column(String(64), nullable=False)
mark_height: Mapped[str] = mapped_column(String(64), nullable=False)
class ProjectTrain(DbCommon):
"""项目训练版本信息表"""
__tablename__ = "project_train"
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
last_pt: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -20,6 +20,11 @@ def get_image_list(project_id: int, session: Session):
return image_list
def get_images(project_id: int, session: Session):
query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
return query.all()
def add_image(image: ProjectImage, session: Session):
session.add(image)
session.commit()

View File

@ -16,6 +16,17 @@ def get_img_leafer(image_id: int, session: Session):
return img_leafer
def get_img_label_list(image_id, session: Session):
"""
根据图片id获取图片标签信息
:param image_id:
:param session:
:return:
"""
img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).all()
return img_label_list
def save_img_leafer(leafer: ProjectImgLeafer, session: Session):
leafer_saved = session.query(ProjectImgLeafer).filter_by(image_id=leafer.image_id).first()
if leafer_saved is not None:

View File

@ -1,5 +1,5 @@
from sqlalchemy.orm import Session
from sqlalchemy import desc
from sqlalchemy import desc, update
from app.model.bussiness_model import ProjectInfo
from app.model.schemas.project_info_schemas import ProjectInfoOut
@ -41,3 +41,24 @@ def check_project_name(project_name: str, session: Session):
else:
return False
def update_project_status(project_id: int, project_status: str, session: Session):
"""
更新项目训练状态,如果是已完成的话train_version自动+1
:param project_id:
:param project_status: 0-未运行1-运行中2-已完成,-1-执行失败
:param session:
:return:
"""
if project_status == '2':
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
'train_status': project_status,
'train_version': ProjectInfo.train_version + 1
})
session.execute(stmt)
else:
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
'train_status': project_status
})
session.execute(stmt)
session.commit()

View File

@ -17,6 +17,16 @@ def get_label_list(project_id: int, session: Session):
return label_list
def get_label_for_train(project_id: int, session: Session):
id_list = []
name_list = []
label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
for label in label_list:
id_list.append(label.id)
name_list.append(label.label_name)
return id_list, name_list
def add_label(label: plModel, session: Session):
"""
新增标签

View File

@ -0,0 +1,28 @@
from sqlalchemy.orm import Session
from sqlalchemy import asc
from app.model.bussiness_model import ProjectTrain
from app.model.schemas.project_train_schemas import ProjectTrainOut
def add_train(train: ProjectTrain, session: Session):
"""
新增训练结果
:param train:
:param session:
:return:
"""
session.add(train)
session.commit()
def get_train_list(project_id: int, session: Session):
"""
根据项目id查询训练列表
:param project_id:
:param session:
:return:
"""
query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id))
train_list = [ProjectTrainOut.from_orm(train) for train in query.all()]
return train_list

View File

@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import Optional
class ProjectTrainOut(BaseModel):
"""项目训练版本信息表"""
id: Optional[int] = Field(None, description="训练id")
train_version: Optional[str] = Field(None, description="训练版本号")
best_pt: Optional[str] = Field(None, description="最好")
last_pt: Optional[str] = Field(None, description="最后")
class Config:
orm_mode = True

View File

@ -1,16 +1,20 @@
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_train_crud as ptc
from app.model.crud import project_img_leafer_label_crud as pillc
from app.util import os_utils as os
from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
import yaml
import subprocess
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
@ -48,7 +52,7 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile],
path = os.save_images(images_url, project_info.project_no, file=file)
image.image_url = path
# 生成缩略图
thumb_image_url = images_url + "\\thumb\\" + project_info.project_no + "\\" + ru.random_str(10) + ".jpg"
thumb_image_url = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg")
os.create_thumbnail(path, thumb_image_url)
image.thumb_image_url = thumb_image_url
images.append(image)
@ -62,7 +66,9 @@ def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
:param session:
:return:
"""
img_leafer = ProjectImgLeafer(**img_leafer_label)
img_leafer = ProjectImgLeafer()
img_leafer.image_id = img_leafer_label.image_id
img_leafer.leafer = img_leafer_label.leafer
pillc.save_img_leafer(img_leafer, session)
label_infos = img_leafer_label.label_infos
img_labels = []
@ -83,3 +89,146 @@ def get_img_leafer(image_id: int, session: Session):
img_leafer = pillc.get_img_leafer(image_id, session)
img_leafer_out = ProjectImgLeaferOut.from_orm(img_leafer).dict()
return img_leafer_out
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
"""
yolov5执行训练任务
:param project_info: 项目信息
:param session: 数据库session
:return:
"""
# 先获取项目的所有图片
project_images = pimc.get_images(project_info.id, session)
# 将图片根据,根据3:1的比例将图片分成trainval的两个数组
project_images_train, project_images_val = split_array(project_images)
# 得到训练版本
version_path = 'v' + str(project_info.train_version + 1)
# 创建训练的根目录
train_path = os.create_folder(datasets_url, project_info.project_no, version_path)
# 查询项目所属标签,返回两个 idname一一对应的数组
label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)
# 在根目录创建classes.txt文件
classes_txt = os.file_path(train_path, 'classes.txt')
with open(classes_txt, 'w', encoding='utf-8') as file:
for label_name in label_name_list:
file.write(label_name + '\n')
# 创建图片的的两个文件夹
img_path_train = os.create_folder(train_path, 'images', 'train')
img_path_val = os.create_folder(train_path, 'images', 'val')
# 创建标签的两个文件夹
label_path_train = os.create_folder(train_path, 'labels', 'train')
label_path_val = os.create_folder(train_path, 'labels', 'val')
# 在根目录下创建yaml文件
yaml_file = os.file_path(train_path, project_info.project_no + '.yaml')
yaml_data = {
'path': train_path,
'train': 'images/train',
'val': 'images/val',
'test': None,
'names': {i: name for i, name in enumerate(label_name_list)}
}
with open(yaml_file, 'w', encoding='utf-8') as file:
yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False)
# 开始循环复制图片和生成label.txt
# 先操作train
operate_img_label(project_images_train, img_path_train, label_path_train, session, label_id_list)
# 再操作val
operate_img_label(project_images_val, img_path_val, label_path_val, session, label_id_list)
# 打包完成开始训练,训练前,更改项目的训练状态
pic.update_project_status(project_info.id, '1')
# 开始训练
data = yaml_file
project = os.file_path(runs_url, project_info.project_no, 'train')
name = version_path
epochs = 10
yolo_path = os.file_path(yolo_url, 'train.py')
# 启动子进程
with subprocess.Popen(
["python", yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, "--epochs=" + str(epochs)],
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
bufsize=1, # 行缓冲
) as process:
# 使用iter逐行读取stdout和stderr
for line in process.stdout:
yield f"stdout: {line.strip()} \n"
for line in process.stderr:
yield f"stderr: {line.strip()} \n"
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
pic.update_project_status(project_info.id, '-1', session)
else:
pic.update_project_status(project_info.id, '2', session)
# 然后保存版本训练信息
train = ProjectTrain()
train.project_id = project_info.id
train.train_version = version_path
bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt')
last_pt_path = os.file_path(project, name, 'weight', 'last.pt')
train.best_pt = bast_pt_path
train.last_pt = last_pt_path
ptc.add_train(train, session)
def operate_img_label(img_list: List[ProjectImgLabel],
img_path: str, label_path: str,
session: Session, label_id_list: []):
"""
生成图片和标签内容
:param label_id_list:
:param session:
:param img_list:
:param img_path:
:param label_path:
:return:
"""
for i in range(len(img_list)):
image = img_list[i]
# 先复制图片,并把图片改名,不改后缀
file_name = 'image' + str(i)
os.copy_and_rename_file(image.image_url, img_path, file_name)
# 查询这张图片的label信息然后生成这张照片的txt文件
img_label_list = pillc.get_img_label_list(image.id, session)
label_txt_path = os.file_path(label_path, file_name + '.txt')
with open(label_txt_path, 'w', encoding='utf-8') as file:
for image_label in img_label_list:
index = label_id_list.index(image_label.label_id)
file.write(str(index) + ' ' + image_label.mark_center_x + ' '
+ image_label.mark_center_y + ' '
+ image_label.mark_width + ' '
+ image_label.mark_height + '\n')
def split_array(data):
"""
将数组按照3:1的比例切分成两个数组。
:param data: 原始数组
:return: 按照3:1比例分割后的两个数组
"""
total_length = len(data)
if total_length < 4:
raise ValueError("数组长度至少需要为4才能进行3:1的分割")
# 计算分割点
split_index = (total_length * 3) // 4
# 使用切片分割数组
part1 = data[:split_index]
part2 = data[split_index:]
return part1, part2

View File

@ -1,6 +1,18 @@
import os
import shutil
from fastapi import UploadFile
from PIL import Image
"""文件处理相关的util"""
def file_path(*path):
"""
拼接返回文件路径
:param path:
:return:
"""
return_path = os.path.join(*path)
return return_path
def create_folder(*path):
@ -10,6 +22,7 @@ def create_folder(*path):
os.makedirs(folder_path, exist_ok=True)
except Exception as e:
print(f"创建文件夹时错误: {e}")
return folder_path
def save_images(*path, file: UploadFile):
@ -43,3 +56,26 @@ def create_thumbnail(input_image_path, out_image_path, size=(116, 70)):
os.makedirs(os.path.dirname(out_image_path), exist_ok=True)
# 保存生成的缩略图
image.save(out_image_path, 'JPEG')
def copy_and_rename_file(src_file_path, dst_dir, new_name):
"""
复制文件到指定目录并重命名,保持文件的后缀不变。
:param src_file_path: 源文件路径
:param dst_dir: 目标目录
:param new_name: 新文件名(不含扩展名)
"""
# 获取文件的完整文件名(包括扩展名)
base_name = os.path.basename(src_file_path)
# 分离文件名和扩展名
file_name, file_extension = os.path.splitext(base_name)
# 构建新的文件名和目标路径
new_file_name = f"{new_name}{file_extension}"
dst_file_path = os.path.join(dst_dir, new_file_name)
# 确保目标目录存在
os.makedirs(dst_dir, exist_ok=True)
# 复制文件到目标位置并重命名
shutil.copy(src_file_path, dst_file_path)