完成推理模块的主体功能

This commit is contained in:
2025-03-04 17:04:37 +08:00
parent 4262d3e908
commit fa6c344e84
9 changed files with 325 additions and 21 deletions

View File

@ -1,11 +1,14 @@
from typing import List from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form from fastapi import APIRouter, Depends, UploadFile, File, Form
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.common import reponse_code as rc from app.common import reponse_code as rc
from app.model.crud import project_detect_crud as pdc from app.model.crud import project_detect_crud as pdc
from app.model.schemas.project_detect_schemas import ProjectDetectPager from app.service import project_detect_service as pds
from app.model.crud.project_train_crud import get_train
from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn, ProjectDetectImgPager, ProjectDetectLogIn
from app.db.db_session import get_db from app.db.db_session import get_db
detect = APIRouter() detect = APIRouter()
@ -14,10 +17,121 @@ detect = APIRouter()
@detect.post("/detect_pager") @detect.post("/detect_pager")
def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_db)): def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_db)):
""" """
获取训练集的照片 获取训练集
:param detect_pager: :param detect_pager:
:param session: :param session:
:return: :return:
""" """
pager = pdc.get_detect_pager(detect_pager, session) pager = pdc.get_detect_pager(detect_pager, session)
return rc.response_success_pager(pager) return rc.response_success_pager(pager)
@detect.post("/add_detect")
def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)):
"""
新增训练集合
:param detect_in:
:param session:
:return:
"""
pds.add_detect(detect_in, session)
return rc.response_success("新增成功")
@detect.post("/get_img_list")
def get_img_list(detect_img_pager: ProjectDetectImgPager, session: Session = Depends(get_db)):
"""
查询训练集合中的图片列表
:param detect_img_pager:
:param session:
:return:
"""
if detect_img_pager.pagerNum is None and detect_img_pager.pagerSize is None:
img_list = pdc.get_img_list(detect_img_pager.detect_id, session)
img_list = jsonable_encoder(img_list)
return rc.response_success(data=img_list)
else:
pager = pdc.get_img_pager(detect_img_pager, session)
return rc.response_success_pager(pager)
@detect.post("/upload_detect_img")
def upload_detect_img(detect_id: int = Form(...), files: List[UploadFile] = File(...), session: Session = Depends(get_db)):
"""
上传训练集合中的照片
:param detect_id:
:param files:
:param session:
:return:
"""
detect_out = pdc.get_detect_by_id(detect_id, session)
if detect_out is None:
return rc.response_error("训练集合查询失败,请刷新后再试")
is_check, file_name = pds.check_image_name(detect_id, files, session)
if not is_check:
return rc.response_error(msg="存在重名的图片文件:" + file_name)
pds.upload_detect_imgs(detect_out, files, session)
return rc.response_success("上传成功")
@detect.get("/del_detect_img/{detect_img_id}")
def del_detect_img(detect_img_id: int, session: Session = Depends(get_db)):
"""
删除训练集合照片
:param detect_img_id:
:param session:
:return:
"""
result = pds.del_detect_img(detect_img_id, session)
if result > 0:
return rc.response_success(msg="删除成功")
else:
return rc.response_error(msg="删除失败")
@detect.get("/run_detect_yolo")
def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depends(get_db)):
"""
开始执行训练
:param detect_log_in:
:param session:
:return:
"""
detect = pdc.get_detect_by_id(detect_log_in.detect_id, session)
if detect is None:
return rc.response_error("训练集合不存在")
train = get_train(detect_log_in.train_id, session)
if train is None:
return rc.response_error("训练权重不存在")
detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session)
return StreamingResponse(pds.run_commend(detect_log.pt_url,
detect_log.folder_url,
detect_log.detect_folder_url,
detect_log.detect_version,
detect_log.id, detect_log.detect_id, session), media_type="text/plain")
@detect.get("/get_log_list/{detect_id}")
def get_log_list(detect_id: int, session: Session = Depends(get_db)):
"""
根据推理集合id获取推理记录
:param detect_id:
:param session:
:return:
"""
result = pdc.get_log_list(detect_id, session)
result = jsonable_encoder(result)
return rc.response_success(data=result)
@detect.get("/get_log_imgs/{log_id}")
def get_log_imgs(log_id: int, session: Session = Depends(get_db)):
"""
根据推理集合中的结果图片
:param log_id:
:param session:
:return:
"""
result = pdc.get_log_imgs(log_id, session)
result = jsonable_encoder(result)
return rc.response_success(data=result)

View File

@ -32,7 +32,7 @@ def get_type_list(session: Session = Depends(get_db)):
@project.post("/list") @project.post("/list")
def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)): def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
""" """
项目列表
:param info: :param info:
:param session: :param session:
:return: :return:

View File

@ -4,6 +4,7 @@ from starlette.responses import FileResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.model.crud.project_image_crud import get_img_url from app.model.crud.project_image_crud import get_img_url
from app.model.crud.project_detect_crud import get_detect_img_url
from app.config.config_reader import images_url from app.config.config_reader import images_url
from app.db.db_session import get_db from app.db.db_session import get_db
@ -40,3 +41,38 @@ def view_thumb(image_id: int, session: Session = Depends(get_db)):
if not os.path.isfile(image_path): if not os.path.isfile(image_path):
raise HTTPException(status_code=404, detail="Image not found") raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(image_path, media_type='image/jpeg') return FileResponse(image_path, media_type='image/jpeg')
@view.get("/view_detect_img/{image_id}")
def view_detect_img(image_id: int, session: Session = Depends(get_db)):
"""
查看图片
:param session:
:param image_id: 图片id
:return:
"""
sour_url, thumb_url = get_detect_img_url(image_id, session)
image_path = os.path.join(images_url, sour_url)
# 检查文件是否存在以及是否是文件
if not os.path.isfile(image_path):
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(image_path, media_type='image/jpeg')
@view.get("/view_detect_thumb/{image_id}")
def view_detect_thumb(image_id: int, session: Session = Depends(get_db)):
"""
查看图片
:param session:
:param image_id: 图片id
:return:
"""
sour_url, thumb_url = get_detect_img_url(image_id, session)
image_path = os.path.join(images_url, thumb_url)
# 检查文件是否存在以及是否是文件
if not os.path.isfile(image_path):
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(image_path, media_type='image/jpeg')

View File

@ -8,6 +8,7 @@ from app.api.sys.login_api import login
from app.api.sys.sys_user_api import user from app.api.sys.sys_user_api import user
from app.api.business.project_train_api import project from app.api.business.project_train_api import project
from app.api.common.view_img import view from app.api.common.view_img import view
from app.api.business.project_detect_api import detect
my_app = FastAPI() my_app = FastAPI()
@ -33,5 +34,6 @@ my_app.include_router(login, prefix="/login", tags=["用户登录接口"])
my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"]) my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
my_app.include_router(view, tags=["查看图片"]) my_app.include_router(view, tags=["查看图片"])
my_app.include_router(user, prefix="/user", tags=["用户管理API"]) my_app.include_router(user, prefix="/user", tags=["用户管理API"])
my_app.include_router(project, prefix="/proj", tags=["项目管理API"]) my_app.include_router(project, prefix="/proj", tags=["项目训练API"])
my_app.include_router(detect, prefix="/detect", tags=["项目推理API"])

View File

@ -116,9 +116,11 @@ class ProjectDetectLog(DbCommon):
__tablename__ = "project_detect_log" __tablename__ = "project_detect_log"
detect_id: Mapped[int] = mapped_column(Integer, nullable=False) detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
detect_version: Mapped[str] = mapped_column(String(10)) detect_version: Mapped[str] = mapped_column(String(10))
detect_name: Mapped[str] = mapped_column(String(64), nullable=False)
train_id: Mapped[int] = mapped_column(Integer, nullable=False) train_id: Mapped[int] = mapped_column(Integer, nullable=False)
train_version: Mapped[str] = mapped_column(String(10)) train_version: Mapped[str] = mapped_column(String(10))
pt_type: Mapped[str] = mapped_column(String(32)) pt_type: Mapped[str] = mapped_column(String(10))
pt_url: Mapped[str] = mapped_column(String(255))
folder_url: Mapped[str] = mapped_column(String(255)) folder_url: Mapped[str] = mapped_column(String(255))
detect_folder_url: Mapped[str] = mapped_column(String(255)) detect_folder_url: Mapped[str] = mapped_column(String(255))

View File

@ -2,10 +2,10 @@ from typing import List
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import asc, and_ from sqlalchemy import asc, and_
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi import UploadFile
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager, \
ProjectDetectImageOut, ProjectDetectImgPager, ProjectDetectLogOut, ProjectDetectLogImgOut
from app.db.page_util import get_pager from app.db.page_util import get_pager
@ -21,6 +21,18 @@ def add_detect(detect: ProjectDetect, session: Session):
return detect return detect
def get_detect_by_id(detect_id: int, session: Session):
"""
根据id查询训练集合
:param detect_id:
:param session:
:return:
"""
detect = session.query(ProjectDetect).filter_by(id=detect_id).first()
detect_out = ProjectDetectOut.from_orm(detect)
return detect_out
def get_detect_pager(detect_pager: ProjectDetectPager, session: Session): def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
""" """
查询推理集合分页数据 查询推理集合分页数据
@ -39,6 +51,26 @@ def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
return pager return pager
def update_detect_status(detect_id: int, detect_status: int, session: Session):
"""
更新项目训练状态,如果是已完成的话train_version自动+1
:param detect_id:
:param detect_status: 0-未运行1-运行中2-已完成,-1-执行失败
:param session:
:return:
"""
if detect_status == 2:
session.query(ProjectDetect).filter_by(id=detect_id).update({
'detect_status': detect_status,
'detect_version': ProjectDetect.detect_version + 1
})
else:
session.query(ProjectDetect).filter_by(id=detect_id).update({
'detect_status': detect_status
})
session.commit()
def check_img_name(detect_id: int, file_name: str, session: Session): def check_img_name(detect_id: int, file_name: str, session: Session):
""" """
校验上传的图片名称是否重名 校验上传的图片名称是否重名
@ -65,6 +97,40 @@ def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session):
session.commit() session.commit()
def get_img_list(detect_id: int, session: Session):
"""
获取训练集合中的图片列表
:param detect_id:
:param session:
:return:
"""
query = session.query(ProjectDetectImg).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectImg.id))
image_list = [ProjectDetectImageOut.from_orm(image) for image in query.all()]
return image_list
def get_img_pager(detect_img_pager: ProjectDetectImgPager, session: Session):
"""
获取训练集合中的图片列表,返回pager对象
:param detect_img_pager:
:param session:
:return:
"""
query = session.query(ProjectDetectImg).filter_by(detect_id=detect_img_pager.detect_id)\
.order_by(asc(ProjectDetectImg.id))
pager = get_pager(query, detect_img_pager.pagerNum, detect_img_pager.pagerSize)
pager.data = [ProjectDetectImageOut.from_orm(image) for image in pager.data]
pager.data = jsonable_encoder(pager.data)
return pager
def get_detect_img_url(image_id: int, session: Session):
result = session.query(ProjectDetectImg).filter_by(id=image_id).first()
sour_url = result.image_url
thumb_url = result.thumb_image_url
return sour_url, thumb_url
def add_detect_log(detect_log: ProjectDetectLog, session: Session): def add_detect_log(detect_log: ProjectDetectLog, session: Session):
""" """
新增推理记录 新增推理记录
@ -86,3 +152,22 @@ def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Ses
""" """
session.add_all(detect_log_imgs) session.add_all(detect_log_imgs)
session.commit() session.commit()
def get_log_list(detect_id: int, session: Session):
"""
获取推理记录
:param detect_id:
:param session:
:return:
"""
query = session.query(ProjectDetectLog).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectLog.id))
result = [ProjectDetectLogOut.from_orm(log) for log in query.all()]
return result
def get_log_imgs(log_id: int, session: Session):
query = session.query(ProjectDetectLogImg).filter_by(log_id=log_id).order_by(asc(ProjectDetectLogImg.id))
result = [ProjectDetectLogImgOut.from_orm(img) for img in query.all()]
return result

View File

@ -19,7 +19,7 @@ class ProjectDetectPager(BaseModel):
class ProjectDetectOut(BaseModel): class ProjectDetectOut(BaseModel):
id: Optional[int] id: Optional[int]
project_id: Optional[int] project_id: Optional[int]
detect_name: Optional[int] detect_name: Optional[str]
detect_no: Optional[str] detect_no: Optional[str]
detect_version: Optional[int] detect_version: Optional[int]
file_type: Optional[str] file_type: Optional[str]
@ -33,6 +33,25 @@ class ProjectDetectOut(BaseModel):
} }
class ProjectDetectImgPager(BaseModel):
detect_id: Optional[int] = Field(..., description="训练集合id")
pagerNum: Optional[int] = Field(None, description="当前页码")
pagerSize: Optional[int] = Field(None, description="每页数量")
class ProjectDetectImageOut(BaseModel):
id: Optional[int] = Field(None, description="id")
detect_id: Optional[int] = Field(..., description="训练集合id")
file_name: Optional[str] = Field(None, description="文件名称")
create_time: Optional[datetime] = Field(None, description="上传时间")
class Config:
orm_mode = True
json_encoders = {
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
}
class ProjectDetectLogIn(BaseModel): class ProjectDetectLogIn(BaseModel):
detect_id: Optional[int] = Field(..., description="推理集合id") detect_id: Optional[int] = Field(..., description="推理集合id")
train_id: Optional[int] = Field(..., description="训练结果id") train_id: Optional[int] = Field(..., description="训练结果id")
@ -43,6 +62,7 @@ class ProjectDetectLogOut(BaseModel):
id: Optional[int] id: Optional[int]
detect_id: Optional[int] detect_id: Optional[int]
detect_version: Optional[str] detect_version: Optional[str]
detect_name: Optional[str]
train_id: Optional[int] train_id: Optional[int]
train_version: Optional[int] train_version: Optional[int]
pt_type: Optional[str] pt_type: Optional[str]

View File

@ -5,7 +5,7 @@ import subprocess
from app.model.crud import project_detect_crud as pdc from app.model.crud import project_detect_crud as pdc
from app.model.schemas.project_detect_schemas import ProjectDetectIn, ProjectDetectOut, ProjectDetectLogIn from app.model.schemas.project_detect_schemas import ProjectDetectIn, ProjectDetectOut, ProjectDetectLogIn
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectTrain, ProjectDetectLog from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectTrain, ProjectDetectLog, ProjectDetectLogImg
from app.util.random_utils import random_str from app.util.random_utils import random_str
from app.config.config_reader import detect_url from app.config.config_reader import detect_url
from app.util import os_utils as os from app.util import os_utils as os
@ -23,6 +23,7 @@ def add_detect(detect_in: ProjectDetectIn, session: Session):
detect = ProjectDetect(**detect_in.dict()) detect = ProjectDetect(**detect_in.dict())
detect.detect_no = random_str(6) detect.detect_no = random_str(6)
detect.detect_version = 0 detect.detect_version = 0
detect.detect_status = '0'
url = os.create_folder(detect_url, detect.detect_no, 'images') url = os.create_folder(detect_url, detect.detect_no, 'images')
detect.folder_url = url detect.folder_url = url
detect = pdc.add_detect(detect, session) detect = pdc.add_detect(detect, session)
@ -67,6 +68,22 @@ def upload_detect_imgs(detect: ProjectDetectOut, files: List[UploadFile], sessio
pdc.add_detect_imgs(images, session) pdc.add_detect_imgs(images, session)
def del_detect_img(detect_img_id: int, session: Session):
"""
删除训练集合图片
:param detect_img_id:
:param session:
:return:
"""
detect_img = session.query(ProjectDetectImg).filter_by(id=detect_img_id).first()
if detect_img is None:
return 0
os.delete_file_if_exists(detect_img.image_url, detect_img.thumb_image_url)
session.delete(detect_img)
session.commit()
return 1
def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: ProjectTrain, session: Session): def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: ProjectTrain, session: Session):
""" """
开始推理 开始推理
@ -94,25 +111,35 @@ def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train:
detect_log.train_id = train.id detect_log.train_id = train.id
detect_log.train_version = train.train_version detect_log.train_version = train.train_version
detect_log.pt_type = detect_in.pt_type detect_log.pt_type = detect_in.pt_type
detect_log.folder_url = detect.folder_url detect_log.pt_url = pt_url
detect_log.folder_url = img_url
detect_log.detect_folder_url = out_url detect_log.detect_folder_url = out_url
detect_log = pdc.add_detect_log(detect_log, session) detect_log = pdc.add_detect_log(detect_log, session)
return detect_log return detect_log
def run_commend(weights: str, source: str, project: str, name: str, def run_commend(weights: str, source: str, project: str, name: str,
detect_log_id: int, session: Session): log_id: int, detect_id: int, session: Session):
"""
执行yolov5的推理
:param weights: 权重文件
:param source: 图片所在文件
:param project: 推理完成的文件位置
:param name: 版本名称
:param log_id: 日志id
:param detect_id: 推理集合id
:param session:
:return:
"""
yolo_path = os.file_path(yolo_url, 'detect.py') yolo_path = os.file_path(yolo_url, 'detect.py')
yield f"stdout: 模型推理开始,请稍等。。。 \n" yield f"stdout: 模型推理开始,请稍等。。。 \n"
# 启动子进程 # 启动子进程
with subprocess.Popen( with subprocess.Popen(
["python", '-u', yolo_path, ["python", '-u', yolo_path,
"--weights =" + weights, "--weights", weights,
"--source =" + source, "--source", source,
"--name=" + name, "--name", name,
"--project=" + project, "--project", project],
"--view-img"],
bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存 bufsize=1, # bufsize=0时为不缓存bufsize=1时按行缓存bufsize为其他正整数时为按照近似该正整数的字节数缓存
shell=False, shell=False,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
@ -126,6 +153,24 @@ def run_commend(weights: str, source: str, project: str, name: str,
if line != '\n': if line != '\n':
yield line yield line
# 等待进程结束并获取返回码
return_code = process.wait()
if return_code != 0:
pdc.update_detect_status(detect_id, -1, session)
else:
pdc.update_detect_status(detect_id, 2, session)
detect_imgs = pdc.get_img_list(detect_id, session)
detect_log_imgs = []
for detect_img in detect_imgs:
detect_log_img = ProjectDetectLogImg()
detect_log_img.log_id = log_id
image_url = os.file_path(project, name, detect_img.file_name)
detect_log_img.image_url = image_url
detect_log_img.file_name = detect_img.file_name
detect_log_imgs.append(detect_log_img)
pdc.add_detect_imgs(detect_log_imgs, session)

View File

@ -212,7 +212,7 @@ def run_commend(data: str, project: str,
train = ProjectTrain() train = ProjectTrain()
train.project_id = project_id train.project_id = project_id
train.train_version = name train.train_version = name
bast_pt_path = os.file_path(project, name, 'weights', 'bast.pt') bast_pt_path = os.file_path(project, name, 'weights', 'best.pt')
last_pt_path = os.file_path(project, name, 'weights', 'last.pt') last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
train.best_pt = bast_pt_path train.best_pt = bast_pt_path
train.last_pt = last_pt_path train.last_pt = last_pt_path