重构基本完成
This commit is contained in:
@ -13,6 +13,7 @@ from app.common import reponse_code as rc
|
||||
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
"""项目管理API"""
|
||||
@ -150,3 +151,12 @@ def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
|
||||
img_leafer_out = ps.get_img_leafer(image_id, session)
|
||||
return rc.response_success(data=img_leafer_out)
|
||||
|
||||
|
||||
@project.get("/run_train/{project_id}")
|
||||
def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
project_info = pic.get_project_by_id(project_id, session)
|
||||
if project_info is None:
|
||||
return rc.response_error("项目查询错误")
|
||||
if project_info.project_status == '1':
|
||||
return rc.response_error("项目当前存在训练进程,请稍后再试")
|
||||
return StreamingResponse(ps.run_train_yolo(project_info, session), media_type="text/plain")
|
||||
|
42
app/api/common/test_api.py
Normal file
42
app/api/common/test_api.py
Normal file
@ -0,0 +1,42 @@
|
||||
import asyncio
|
||||
import subprocess
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
test = APIRouter()
|
||||
|
||||
|
||||
async def generate_data():
|
||||
for i in range(1, 10): # 生成 5 行数据
|
||||
await asyncio.sleep(1) # 等待 1 秒
|
||||
yield f"data: This is line {i}\n\n" # 返回 SSE 格式的数据
|
||||
|
||||
|
||||
def run_command(command):
|
||||
"""执行命令并实时打印每一行输出"""
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
command,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
|
||||
bufsize=1, # 行缓冲
|
||||
) as process:
|
||||
# 使用iter逐行读取stdout和stderr
|
||||
for line in process.stdout:
|
||||
yield f"stdout: {line.strip()} \n"
|
||||
|
||||
for line in process.stderr:
|
||||
yield f"stderr: {line.strip()} \n"
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
print(f"Process exited with non-zero code: {return_code}")
|
||||
|
||||
|
||||
@test.get("/stream")
|
||||
async def stream_response():
|
||||
return StreamingResponse(run_command(["ping", "-n", "10", "127.0.0.1"]), media_type="text/plain")
|
@ -9,6 +9,7 @@ from app.api.sys.login_api import login
|
||||
from app.api.sys.sys_user_api import user
|
||||
from app.api.business.project_api import project
|
||||
from app.api.common.view_img import view
|
||||
from app.api.common.test_api import test
|
||||
|
||||
my_app = FastAPI()
|
||||
|
||||
@ -35,5 +36,5 @@ my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
|
||||
my_app.include_router(view, prefix="/view_img", tags=["查看图片"])
|
||||
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
|
||||
my_app.include_router(project, prefix="/proj", tags=["项目管理API"])
|
||||
|
||||
my_app.include_router(test, prefix="/test", tags=["测试用API"])
|
||||
|
||||
|
@ -33,7 +33,7 @@ class TokenMiddleware(BaseHTTPMiddleware):
|
||||
return rc.response_code_view(status.HTTP_401_UNAUTHORIZED, "Token错误或失效,请重新验证")
|
||||
|
||||
|
||||
green = ['/login', '/view_img']
|
||||
green = ['/login', '/view_img', 'test']
|
||||
|
||||
|
||||
def check_green(s: str):
|
||||
|
@ -13,6 +13,7 @@ dir = D:\syg\workspace\logs
|
||||
[yolo]
|
||||
datasets_url = D:\syg\yolov5\datasets
|
||||
runs_url = D:\syg\yolov5\runs
|
||||
yolo_url = D:\syg\workspace\aicheckv2\yolov5
|
||||
|
||||
[images]
|
||||
image_url = D:\syg\images
|
@ -13,6 +13,7 @@ dir = /home/aicheckv2/logs
|
||||
[yolo]
|
||||
datasets_url = /home/aicheckv2/yolov5/datasets
|
||||
runs_url = /home/aicheckv2/yolov5/runs
|
||||
yolo_url = /home/aicheckv2/backend/yolov5
|
||||
|
||||
[images]
|
||||
image_url = /home/aicheckv2/images
|
@ -24,5 +24,6 @@ log_dir = config.get('log', 'dir')
|
||||
|
||||
datasets_url = config.get('yolo', 'datasets_url')
|
||||
runs_url = config.get('yolo', 'runs_url')
|
||||
yolo_url = config.get('yolo', 'yolo_url')
|
||||
|
||||
images_url = config.get('images', 'image_url')
|
||||
images_url = config.get('images', 'image_url')
|
||||
|
@ -59,3 +59,12 @@ class ProjectImgLabel(DbCommon):
|
||||
mark_center_y: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
mark_width: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
mark_height: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
|
||||
|
||||
class ProjectTrain(DbCommon):
|
||||
"""项目训练版本信息表"""
|
||||
__tablename__ = "project_train"
|
||||
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
last_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
@ -20,6 +20,11 @@ def get_image_list(project_id: int, session: Session):
|
||||
return image_list
|
||||
|
||||
|
||||
def get_images(project_id: int, session: Session):
|
||||
query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
|
||||
return query.all()
|
||||
|
||||
|
||||
def add_image(image: ProjectImage, session: Session):
|
||||
session.add(image)
|
||||
session.commit()
|
||||
|
@ -16,6 +16,17 @@ def get_img_leafer(image_id: int, session: Session):
|
||||
return img_leafer
|
||||
|
||||
|
||||
def get_img_label_list(image_id, session: Session):
|
||||
"""
|
||||
根据图片id获取图片标签信息
|
||||
:param image_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).all()
|
||||
return img_label_list
|
||||
|
||||
|
||||
def save_img_leafer(leafer: ProjectImgLeafer, session: Session):
|
||||
leafer_saved = session.query(ProjectImgLeafer).filter_by(image_id=leafer.image_id).first()
|
||||
if leafer_saved is not None:
|
||||
|
@ -1,5 +1,5 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import desc, update
|
||||
|
||||
from app.model.bussiness_model import ProjectInfo
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoOut
|
||||
@ -41,3 +41,24 @@ def check_project_name(project_name: str, session: Session):
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def update_project_status(project_id: int, project_status: str, session: Session):
|
||||
"""
|
||||
更新项目训练状态,如果是已完成的话,train_version自动+1
|
||||
:param project_id:
|
||||
:param project_status: 0-未运行,1-运行中,2-已完成,-1-执行失败
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
if project_status == '2':
|
||||
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
|
||||
'train_status': project_status,
|
||||
'train_version': ProjectInfo.train_version + 1
|
||||
})
|
||||
session.execute(stmt)
|
||||
else:
|
||||
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
|
||||
'train_status': project_status
|
||||
})
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
@ -17,6 +17,16 @@ def get_label_list(project_id: int, session: Session):
|
||||
return label_list
|
||||
|
||||
|
||||
def get_label_for_train(project_id: int, session: Session):
|
||||
id_list = []
|
||||
name_list = []
|
||||
label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
|
||||
for label in label_list:
|
||||
id_list.append(label.id)
|
||||
name_list.append(label.label_name)
|
||||
return id_list, name_list
|
||||
|
||||
|
||||
def add_label(label: plModel, session: Session):
|
||||
"""
|
||||
新增标签
|
||||
|
28
app/model/crud/project_train_crud.py
Normal file
28
app/model/crud/project_train_crud.py
Normal file
@ -0,0 +1,28 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import asc
|
||||
|
||||
from app.model.bussiness_model import ProjectTrain
|
||||
from app.model.schemas.project_train_schemas import ProjectTrainOut
|
||||
|
||||
|
||||
def add_train(train: ProjectTrain, session: Session):
|
||||
"""
|
||||
新增训练结果
|
||||
:param train:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
session.add(train)
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_train_list(project_id: int, session: Session):
|
||||
"""
|
||||
根据项目id查询训练列表
|
||||
:param project_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id))
|
||||
train_list = [ProjectTrainOut.from_orm(train) for train in query.all()]
|
||||
return train_list
|
13
app/model/schemas/project_train_schemas.py
Normal file
13
app/model/schemas/project_train_schemas.py
Normal file
@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ProjectTrainOut(BaseModel):
|
||||
"""项目训练版本信息表"""
|
||||
id: Optional[int] = Field(None, description="训练id")
|
||||
train_version: Optional[str] = Field(None, description="训练版本号")
|
||||
best_pt: Optional[str] = Field(None, description="最好")
|
||||
last_pt: Optional[str] = Field(None, description="最后")
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
@ -1,16 +1,20 @@
|
||||
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel
|
||||
from app.model.bussiness_model import ProjectImage, ProjectInfo, ProjectImgLeafer, ProjectImgLabel, ProjectTrain
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
|
||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImgLeaferOut
|
||||
from app.model.crud import project_info_crud as pic
|
||||
from app.model.crud import project_image_crud as pimc
|
||||
from app.model.crud import project_label_crud as plc
|
||||
from app.model.crud import project_train_crud as ptc
|
||||
from app.model.crud import project_img_leafer_label_crud as pillc
|
||||
from app.util import os_utils as os
|
||||
from app.util import random_utils as ru
|
||||
from app.config.config_reader import datasets_url, runs_url, images_url
|
||||
from app.config.config_reader import datasets_url, runs_url, images_url, yolo_url
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
import yaml
|
||||
import subprocess
|
||||
|
||||
|
||||
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||
@ -48,7 +52,7 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile],
|
||||
path = os.save_images(images_url, project_info.project_no, file=file)
|
||||
image.image_url = path
|
||||
# 生成缩略图
|
||||
thumb_image_url = images_url + "\\thumb\\" + project_info.project_no + "\\" + ru.random_str(10) + ".jpg"
|
||||
thumb_image_url = os.file_path(images_url, 'thumb', project_info.project_no, ru.random_str(10) + ".jpg")
|
||||
os.create_thumbnail(path, thumb_image_url)
|
||||
image.thumb_image_url = thumb_image_url
|
||||
images.append(image)
|
||||
@ -62,7 +66,9 @@ def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
img_leafer = ProjectImgLeafer(**img_leafer_label)
|
||||
img_leafer = ProjectImgLeafer()
|
||||
img_leafer.image_id = img_leafer_label.image_id
|
||||
img_leafer.leafer = img_leafer_label.leafer
|
||||
pillc.save_img_leafer(img_leafer, session)
|
||||
label_infos = img_leafer_label.label_infos
|
||||
img_labels = []
|
||||
@ -83,3 +89,146 @@ def get_img_leafer(image_id: int, session: Session):
|
||||
img_leafer = pillc.get_img_leafer(image_id, session)
|
||||
img_leafer_out = ProjectImgLeaferOut.from_orm(img_leafer).dict()
|
||||
return img_leafer_out
|
||||
|
||||
|
||||
def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
"""
|
||||
yolov5执行训练任务
|
||||
:param project_info: 项目信息
|
||||
:param session: 数据库session
|
||||
:return:
|
||||
"""
|
||||
# 先获取项目的所有图片
|
||||
project_images = pimc.get_images(project_info.id, session)
|
||||
|
||||
# 将图片根据,根据3:1的比例将图片分成train:val的两个数组
|
||||
project_images_train, project_images_val = split_array(project_images)
|
||||
|
||||
# 得到训练版本
|
||||
version_path = 'v' + str(project_info.train_version + 1)
|
||||
|
||||
# 创建训练的根目录
|
||||
train_path = os.create_folder(datasets_url, project_info.project_no, version_path)
|
||||
|
||||
# 查询项目所属标签,返回两个 id,name一一对应的数组
|
||||
label_id_list, label_name_list = plc.get_label_for_train(project_info.id, session)
|
||||
|
||||
# 在根目录创建classes.txt文件
|
||||
classes_txt = os.file_path(train_path, 'classes.txt')
|
||||
with open(classes_txt, 'w', encoding='utf-8') as file:
|
||||
for label_name in label_name_list:
|
||||
file.write(label_name + '\n')
|
||||
|
||||
# 创建图片的的两个文件夹
|
||||
img_path_train = os.create_folder(train_path, 'images', 'train')
|
||||
img_path_val = os.create_folder(train_path, 'images', 'val')
|
||||
|
||||
# 创建标签的两个文件夹
|
||||
label_path_train = os.create_folder(train_path, 'labels', 'train')
|
||||
label_path_val = os.create_folder(train_path, 'labels', 'val')
|
||||
|
||||
# 在根目录下创建yaml文件
|
||||
yaml_file = os.file_path(train_path, project_info.project_no + '.yaml')
|
||||
yaml_data = {
|
||||
'path': train_path,
|
||||
'train': 'images/train',
|
||||
'val': 'images/val',
|
||||
'test': None,
|
||||
'names': {i: name for i, name in enumerate(label_name_list)}
|
||||
}
|
||||
with open(yaml_file, 'w', encoding='utf-8') as file:
|
||||
yaml.dump(yaml_data, file, allow_unicode=True, default_flow_style=False)
|
||||
|
||||
# 开始循环复制图片和生成label.txt
|
||||
# 先操作train
|
||||
operate_img_label(project_images_train, img_path_train, label_path_train, session, label_id_list)
|
||||
# 再操作val
|
||||
operate_img_label(project_images_val, img_path_val, label_path_val, session, label_id_list)
|
||||
|
||||
# 打包完成开始训练,训练前,更改项目的训练状态
|
||||
pic.update_project_status(project_info.id, '1')
|
||||
# 开始训练
|
||||
data = yaml_file
|
||||
project = os.file_path(runs_url, project_info.project_no, 'train')
|
||||
name = version_path
|
||||
epochs = 10
|
||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||
|
||||
# 启动子进程
|
||||
with subprocess.Popen(
|
||||
["python", yolo_path, "--data=" + data, "--project=" + project, "--name=" + name, "--epochs=" + str(epochs)],
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True, # 确保输出以字符串形式返回而不是字节
|
||||
bufsize=1, # 行缓冲
|
||||
) as process:
|
||||
# 使用iter逐行读取stdout和stderr
|
||||
for line in process.stdout:
|
||||
yield f"stdout: {line.strip()} \n"
|
||||
|
||||
for line in process.stderr:
|
||||
yield f"stderr: {line.strip()} \n"
|
||||
|
||||
# 等待进程结束并获取返回码
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
pic.update_project_status(project_info.id, '-1', session)
|
||||
else:
|
||||
pic.update_project_status(project_info.id, '2', session)
|
||||
# 然后保存版本训练信息
|
||||
train = ProjectTrain()
|
||||
train.project_id = project_info.id
|
||||
train.train_version = version_path
|
||||
bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt')
|
||||
last_pt_path = os.file_path(project, name, 'weight', 'last.pt')
|
||||
train.best_pt = bast_pt_path
|
||||
train.last_pt = last_pt_path
|
||||
ptc.add_train(train, session)
|
||||
|
||||
|
||||
def operate_img_label(img_list: List[ProjectImgLabel],
|
||||
img_path: str, label_path: str,
|
||||
session: Session, label_id_list: []):
|
||||
"""
|
||||
生成图片和标签内容
|
||||
:param label_id_list:
|
||||
:param session:
|
||||
:param img_list:
|
||||
:param img_path:
|
||||
:param label_path:
|
||||
:return:
|
||||
"""
|
||||
for i in range(len(img_list)):
|
||||
image = img_list[i]
|
||||
# 先复制图片,并把图片改名,不改后缀
|
||||
file_name = 'image' + str(i)
|
||||
os.copy_and_rename_file(image.image_url, img_path, file_name)
|
||||
# 查询这张图片的label信息然后生成这张照片的txt文件
|
||||
img_label_list = pillc.get_img_label_list(image.id, session)
|
||||
label_txt_path = os.file_path(label_path, file_name + '.txt')
|
||||
with open(label_txt_path, 'w', encoding='utf-8') as file:
|
||||
for image_label in img_label_list:
|
||||
index = label_id_list.index(image_label.label_id)
|
||||
file.write(str(index) + ' ' + image_label.mark_center_x + ' '
|
||||
+ image_label.mark_center_y + ' '
|
||||
+ image_label.mark_width + ' '
|
||||
+ image_label.mark_height + '\n')
|
||||
|
||||
|
||||
def split_array(data):
|
||||
"""
|
||||
将数组按照3:1的比例切分成两个数组。
|
||||
:param data: 原始数组
|
||||
:return: 按照3:1比例分割后的两个数组
|
||||
"""
|
||||
total_length = len(data)
|
||||
if total_length < 4:
|
||||
raise ValueError("数组长度至少需要为4才能进行3:1的分割")
|
||||
# 计算分割点
|
||||
split_index = (total_length * 3) // 4
|
||||
# 使用切片分割数组
|
||||
part1 = data[:split_index]
|
||||
part2 = data[split_index:]
|
||||
return part1, part2
|
||||
|
||||
|
@ -1,6 +1,18 @@
|
||||
import os
|
||||
import shutil
|
||||
from fastapi import UploadFile
|
||||
from PIL import Image
|
||||
"""文件处理相关的util"""
|
||||
|
||||
|
||||
def file_path(*path):
|
||||
"""
|
||||
拼接返回文件路径
|
||||
:param path:
|
||||
:return:
|
||||
"""
|
||||
return_path = os.path.join(*path)
|
||||
return return_path
|
||||
|
||||
|
||||
def create_folder(*path):
|
||||
@ -10,6 +22,7 @@ def create_folder(*path):
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
except Exception as e:
|
||||
print(f"创建文件夹时错误: {e}")
|
||||
return folder_path
|
||||
|
||||
|
||||
def save_images(*path, file: UploadFile):
|
||||
@ -43,3 +56,26 @@ def create_thumbnail(input_image_path, out_image_path, size=(116, 70)):
|
||||
os.makedirs(os.path.dirname(out_image_path), exist_ok=True)
|
||||
# 保存生成的缩略图
|
||||
image.save(out_image_path, 'JPEG')
|
||||
|
||||
|
||||
def copy_and_rename_file(src_file_path, dst_dir, new_name):
|
||||
"""
|
||||
复制文件到指定目录并重命名,保持文件的后缀不变。
|
||||
:param src_file_path: 源文件路径
|
||||
:param dst_dir: 目标目录
|
||||
:param new_name: 新文件名(不含扩展名)
|
||||
"""
|
||||
# 获取文件的完整文件名(包括扩展名)
|
||||
base_name = os.path.basename(src_file_path)
|
||||
# 分离文件名和扩展名
|
||||
file_name, file_extension = os.path.splitext(base_name)
|
||||
|
||||
# 构建新的文件名和目标路径
|
||||
new_file_name = f"{new_name}{file_extension}"
|
||||
dst_file_path = os.path.join(dst_dir, new_file_name)
|
||||
|
||||
# 确保目标目录存在
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
|
||||
# 复制文件到目标位置并重命名
|
||||
shutil.copy(src_file_path, dst_file_path)
|
||||
|
Reference in New Issue
Block a user