From fa6c344e849c589c720678d26d8023c0342285a6 Mon Sep 17 00:00:00 2001 From: sunyugang Date: Tue, 4 Mar 2025 17:04:37 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=8E=A8=E7=90=86=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=E7=9A=84=E4=B8=BB=E4=BD=93=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/business/project_detect_api.py | 122 +++++++++++++++++++- app/api/business/project_train_api.py | 2 +- app/api/common/view_img.py | 36 ++++++ app/application/app.py | 4 +- app/model/bussiness_model.py | 4 +- app/model/crud/project_detect_crud.py | 91 ++++++++++++++- app/model/schemas/project_detect_schemas.py | 22 +++- app/service/project_detect_service.py | 63 ++++++++-- app/service/project_train_service.py | 2 +- 9 files changed, 325 insertions(+), 21 deletions(-) diff --git a/app/api/business/project_detect_api.py b/app/api/business/project_detect_api.py index c15a556..bc71394 100644 --- a/app/api/business/project_detect_api.py +++ b/app/api/business/project_detect_api.py @@ -1,11 +1,14 @@ from typing import List -from fastapi import APIRouter, Depends, Request, UploadFile, File, Form +from fastapi import APIRouter, Depends, UploadFile, File, Form from fastapi.responses import StreamingResponse +from fastapi.encoders import jsonable_encoder from sqlalchemy.orm import Session from app.common import reponse_code as rc from app.model.crud import project_detect_crud as pdc -from app.model.schemas.project_detect_schemas import ProjectDetectPager +from app.service import project_detect_service as pds +from app.model.crud.project_train_crud import get_train +from app.model.schemas.project_detect_schemas import ProjectDetectPager, ProjectDetectIn, ProjectDetectImgPager, ProjectDetectLogIn from app.db.db_session import get_db detect = APIRouter() @@ -14,10 +17,121 @@ detect = APIRouter() @detect.post("/detect_pager") def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_db)): """ - 获取训练集的照片 + 获取训练集合 :param detect_pager: :param session: :return: """ pager = pdc.get_detect_pager(detect_pager, session) - return rc.response_success_pager(pager) \ No newline at end of file + return rc.response_success_pager(pager) + + +@detect.post("/add_detect") +def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)): + """ + 新增训练集合 + :param detect_in: + :param session: + :return: + """ + pds.add_detect(detect_in, session) + return rc.response_success("新增成功") + + +@detect.post("/get_img_list") +def get_img_list(detect_img_pager: ProjectDetectImgPager, session: Session = Depends(get_db)): + """ + 查询训练集合中的图片列表 + :param detect_img_pager: + :param session: + :return: + """ + if detect_img_pager.pagerNum is None and detect_img_pager.pagerSize is None: + img_list = pdc.get_img_list(detect_img_pager.detect_id, session) + img_list = jsonable_encoder(img_list) + return rc.response_success(data=img_list) + else: + pager = pdc.get_img_pager(detect_img_pager, session) + return rc.response_success_pager(pager) + + +@detect.post("/upload_detect_img") +def upload_detect_img(detect_id: int = Form(...), files: List[UploadFile] = File(...), session: Session = Depends(get_db)): + """ + 上传训练集合中的照片 + :param detect_id: + :param files: + :param session: + :return: + """ + detect_out = pdc.get_detect_by_id(detect_id, session) + if detect_out is None: + return rc.response_error("训练集合查询失败,请刷新后再试") + is_check, file_name = pds.check_image_name(detect_id, files, session) + if not is_check: + return rc.response_error(msg="存在重名的图片文件:" + file_name) + pds.upload_detect_imgs(detect_out, files, session) + return rc.response_success("上传成功") + + +@detect.get("/del_detect_img/{detect_img_id}") +def del_detect_img(detect_img_id: int, session: Session = Depends(get_db)): + """ + 删除训练集合照片 + :param detect_img_id: + :param session: + :return: + """ + result = pds.del_detect_img(detect_img_id, session) + if result > 0: + return rc.response_success(msg="删除成功") + else: + return rc.response_error(msg="删除失败") + + +@detect.get("/run_detect_yolo") +def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depends(get_db)): + """ + 开始执行训练 + :param detect_log_in: + :param session: + :return: + """ + detect = pdc.get_detect_by_id(detect_log_in.detect_id, session) + if detect is None: + return rc.response_error("训练集合不存在") + train = get_train(detect_log_in.train_id, session) + if train is None: + return rc.response_error("训练权重不存在") + detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session) + return StreamingResponse(pds.run_commend(detect_log.pt_url, + detect_log.folder_url, + detect_log.detect_folder_url, + detect_log.detect_version, + detect_log.id, detect_log.detect_id, session), media_type="text/plain") + + +@detect.get("/get_log_list/{detect_id}") +def get_log_list(detect_id: int, session: Session = Depends(get_db)): + """ + 根据推理集合id获取推理记录 + :param detect_id: + :param session: + :return: + """ + result = pdc.get_log_list(detect_id, session) + result = jsonable_encoder(result) + return rc.response_success(data=result) + + +@detect.get("/get_log_imgs/{log_id}") +def get_log_imgs(log_id: int, session: Session = Depends(get_db)): + """ + 根据推理集合中的结果图片 + :param log_id: + :param session: + :return: + """ + result = pdc.get_log_imgs(log_id, session) + result = jsonable_encoder(result) + return rc.response_success(data=result) diff --git a/app/api/business/project_train_api.py b/app/api/business/project_train_api.py index bc9d263..ecae2db 100644 --- a/app/api/business/project_train_api.py +++ b/app/api/business/project_train_api.py @@ -32,7 +32,7 @@ def get_type_list(session: Session = Depends(get_db)): @project.post("/list") def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)): """ - + 项目列表 :param info: :param session: :return: diff --git a/app/api/common/view_img.py b/app/api/common/view_img.py index f6f5a04..1758742 100644 --- a/app/api/common/view_img.py +++ b/app/api/common/view_img.py @@ -4,6 +4,7 @@ from starlette.responses import FileResponse from sqlalchemy.orm import Session from app.model.crud.project_image_crud import get_img_url +from app.model.crud.project_detect_crud import get_detect_img_url from app.config.config_reader import images_url from app.db.db_session import get_db @@ -40,3 +41,38 @@ def view_thumb(image_id: int, session: Session = Depends(get_db)): if not os.path.isfile(image_path): raise HTTPException(status_code=404, detail="Image not found") return FileResponse(image_path, media_type='image/jpeg') + + +@view.get("/view_detect_img/{image_id}") +def view_detect_img(image_id: int, session: Session = Depends(get_db)): + """ + 查看图片 + :param session: + :param image_id: 图片id + :return: + """ + sour_url, thumb_url = get_detect_img_url(image_id, session) + image_path = os.path.join(images_url, sour_url) + # 检查文件是否存在以及是否是文件 + if not os.path.isfile(image_path): + raise HTTPException(status_code=404, detail="Image not found") + return FileResponse(image_path, media_type='image/jpeg') + + +@view.get("/view_detect_thumb/{image_id}") +def view_detect_thumb(image_id: int, session: Session = Depends(get_db)): + """ + 查看图片 + :param session: + :param image_id: 图片id + :return: + """ + sour_url, thumb_url = get_detect_img_url(image_id, session) + image_path = os.path.join(images_url, thumb_url) + # 检查文件是否存在以及是否是文件 + if not os.path.isfile(image_path): + raise HTTPException(status_code=404, detail="Image not found") + return FileResponse(image_path, media_type='image/jpeg') + + + diff --git a/app/application/app.py b/app/application/app.py index da5e783..dc5af44 100644 --- a/app/application/app.py +++ b/app/application/app.py @@ -8,6 +8,7 @@ from app.api.sys.login_api import login from app.api.sys.sys_user_api import user from app.api.business.project_train_api import project from app.api.common.view_img import view +from app.api.business.project_detect_api import detect my_app = FastAPI() @@ -33,5 +34,6 @@ my_app.include_router(login, prefix="/login", tags=["用户登录接口"]) my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"]) my_app.include_router(view, tags=["查看图片"]) my_app.include_router(user, prefix="/user", tags=["用户管理API"]) -my_app.include_router(project, prefix="/proj", tags=["项目管理API"]) +my_app.include_router(project, prefix="/proj", tags=["项目训练API"]) +my_app.include_router(detect, prefix="/detect", tags=["项目推理API"]) diff --git a/app/model/bussiness_model.py b/app/model/bussiness_model.py index 8f663d4..16195be 100644 --- a/app/model/bussiness_model.py +++ b/app/model/bussiness_model.py @@ -116,9 +116,11 @@ class ProjectDetectLog(DbCommon): __tablename__ = "project_detect_log" detect_id: Mapped[int] = mapped_column(Integer, nullable=False) detect_version: Mapped[str] = mapped_column(String(10)) + detect_name: Mapped[str] = mapped_column(String(64), nullable=False) train_id: Mapped[int] = mapped_column(Integer, nullable=False) train_version: Mapped[str] = mapped_column(String(10)) - pt_type: Mapped[str] = mapped_column(String(32)) + pt_type: Mapped[str] = mapped_column(String(10)) + pt_url: Mapped[str] = mapped_column(String(255)) folder_url: Mapped[str] = mapped_column(String(255)) detect_folder_url: Mapped[str] = mapped_column(String(255)) diff --git a/app/model/crud/project_detect_crud.py b/app/model/crud/project_detect_crud.py index d4fc0de..03980bd 100644 --- a/app/model/crud/project_detect_crud.py +++ b/app/model/crud/project_detect_crud.py @@ -2,10 +2,10 @@ from typing import List from sqlalchemy.orm import Session from sqlalchemy import asc, and_ from fastapi.encoders import jsonable_encoder -from fastapi import UploadFile from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg -from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager +from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager, \ + ProjectDetectImageOut, ProjectDetectImgPager, ProjectDetectLogOut, ProjectDetectLogImgOut from app.db.page_util import get_pager @@ -21,6 +21,18 @@ def add_detect(detect: ProjectDetect, session: Session): return detect +def get_detect_by_id(detect_id: int, session: Session): + """ + 根据id查询训练集合 + :param detect_id: + :param session: + :return: + """ + detect = session.query(ProjectDetect).filter_by(id=detect_id).first() + detect_out = ProjectDetectOut.from_orm(detect) + return detect_out + + def get_detect_pager(detect_pager: ProjectDetectPager, session: Session): """ 查询推理集合分页数据 @@ -39,6 +51,26 @@ def get_detect_pager(detect_pager: ProjectDetectPager, session: Session): return pager +def update_detect_status(detect_id: int, detect_status: int, session: Session): + """ + 更新项目训练状态,如果是已完成的话,train_version自动+1 + :param detect_id: + :param detect_status: 0-未运行,1-运行中,2-已完成,-1-执行失败 + :param session: + :return: + """ + if detect_status == 2: + session.query(ProjectDetect).filter_by(id=detect_id).update({ + 'detect_status': detect_status, + 'detect_version': ProjectDetect.detect_version + 1 + }) + else: + session.query(ProjectDetect).filter_by(id=detect_id).update({ + 'detect_status': detect_status + }) + session.commit() + + def check_img_name(detect_id: int, file_name: str, session: Session): """ 校验上传的图片名称是否重名 @@ -65,6 +97,40 @@ def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session): session.commit() +def get_img_list(detect_id: int, session: Session): + """ + 获取训练集合中的图片列表 + :param detect_id: + :param session: + :return: + """ + query = session.query(ProjectDetectImg).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectImg.id)) + image_list = [ProjectDetectImageOut.from_orm(image) for image in query.all()] + return image_list + + +def get_img_pager(detect_img_pager: ProjectDetectImgPager, session: Session): + """ + 获取训练集合中的图片列表,返回pager对象 + :param detect_img_pager: + :param session: + :return: + """ + query = session.query(ProjectDetectImg).filter_by(detect_id=detect_img_pager.detect_id)\ + .order_by(asc(ProjectDetectImg.id)) + pager = get_pager(query, detect_img_pager.pagerNum, detect_img_pager.pagerSize) + pager.data = [ProjectDetectImageOut.from_orm(image) for image in pager.data] + pager.data = jsonable_encoder(pager.data) + return pager + + +def get_detect_img_url(image_id: int, session: Session): + result = session.query(ProjectDetectImg).filter_by(id=image_id).first() + sour_url = result.image_url + thumb_url = result.thumb_image_url + return sour_url, thumb_url + + def add_detect_log(detect_log: ProjectDetectLog, session: Session): """ 新增推理记录 @@ -85,4 +151,23 @@ def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Ses :return: """ session.add_all(detect_log_imgs) - session.commit() \ No newline at end of file + session.commit() + + +def get_log_list(detect_id: int, session: Session): + """ + 获取推理记录 + :param detect_id: + :param session: + :return: + """ + query = session.query(ProjectDetectLog).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectLog.id)) + result = [ProjectDetectLogOut.from_orm(log) for log in query.all()] + return result + + +def get_log_imgs(log_id: int, session: Session): + query = session.query(ProjectDetectLogImg).filter_by(log_id=log_id).order_by(asc(ProjectDetectLogImg.id)) + result = [ProjectDetectLogImgOut.from_orm(img) for img in query.all()] + return result + diff --git a/app/model/schemas/project_detect_schemas.py b/app/model/schemas/project_detect_schemas.py index f3cdf46..c4a05b5 100644 --- a/app/model/schemas/project_detect_schemas.py +++ b/app/model/schemas/project_detect_schemas.py @@ -19,7 +19,7 @@ class ProjectDetectPager(BaseModel): class ProjectDetectOut(BaseModel): id: Optional[int] project_id: Optional[int] - detect_name: Optional[int] + detect_name: Optional[str] detect_no: Optional[str] detect_version: Optional[int] file_type: Optional[str] @@ -33,6 +33,25 @@ class ProjectDetectOut(BaseModel): } +class ProjectDetectImgPager(BaseModel): + detect_id: Optional[int] = Field(..., description="训练集合id") + pagerNum: Optional[int] = Field(None, description="当前页码") + pagerSize: Optional[int] = Field(None, description="每页数量") + + +class ProjectDetectImageOut(BaseModel): + id: Optional[int] = Field(None, description="id") + detect_id: Optional[int] = Field(..., description="训练集合id") + file_name: Optional[str] = Field(None, description="文件名称") + create_time: Optional[datetime] = Field(None, description="上传时间") + + class Config: + orm_mode = True + json_encoders = { + datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") + } + + class ProjectDetectLogIn(BaseModel): detect_id: Optional[int] = Field(..., description="推理集合id") train_id: Optional[int] = Field(..., description="训练结果id") @@ -43,6 +62,7 @@ class ProjectDetectLogOut(BaseModel): id: Optional[int] detect_id: Optional[int] detect_version: Optional[str] + detect_name: Optional[str] train_id: Optional[int] train_version: Optional[int] pt_type: Optional[str] diff --git a/app/service/project_detect_service.py b/app/service/project_detect_service.py index aa3268c..9fac1bc 100644 --- a/app/service/project_detect_service.py +++ b/app/service/project_detect_service.py @@ -5,7 +5,7 @@ import subprocess from app.model.crud import project_detect_crud as pdc from app.model.schemas.project_detect_schemas import ProjectDetectIn, ProjectDetectOut, ProjectDetectLogIn -from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectTrain, ProjectDetectLog +from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectTrain, ProjectDetectLog, ProjectDetectLogImg from app.util.random_utils import random_str from app.config.config_reader import detect_url from app.util import os_utils as os @@ -23,6 +23,7 @@ def add_detect(detect_in: ProjectDetectIn, session: Session): detect = ProjectDetect(**detect_in.dict()) detect.detect_no = random_str(6) detect.detect_version = 0 + detect.detect_status = '0' url = os.create_folder(detect_url, detect.detect_no, 'images') detect.folder_url = url detect = pdc.add_detect(detect, session) @@ -67,6 +68,22 @@ def upload_detect_imgs(detect: ProjectDetectOut, files: List[UploadFile], sessio pdc.add_detect_imgs(images, session) +def del_detect_img(detect_img_id: int, session: Session): + """ + 删除训练集合图片 + :param detect_img_id: + :param session: + :return: + """ + detect_img = session.query(ProjectDetectImg).filter_by(id=detect_img_id).first() + if detect_img is None: + return 0 + os.delete_file_if_exists(detect_img.image_url, detect_img.thumb_image_url) + session.delete(detect_img) + session.commit() + return 1 + + def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: ProjectTrain, session: Session): """ 开始推理 @@ -94,25 +111,35 @@ def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: detect_log.train_id = train.id detect_log.train_version = train.train_version detect_log.pt_type = detect_in.pt_type - detect_log.folder_url = detect.folder_url + detect_log.pt_url = pt_url + detect_log.folder_url = img_url detect_log.detect_folder_url = out_url detect_log = pdc.add_detect_log(detect_log, session) return detect_log def run_commend(weights: str, source: str, project: str, name: str, - detect_log_id: int, session: Session): + log_id: int, detect_id: int, session: Session): + """ + 执行yolov5的推理 + :param weights: 权重文件 + :param source: 图片所在文件 + :param project: 推理完成的文件位置 + :param name: 版本名称 + :param log_id: 日志id + :param detect_id: 推理集合id + :param session: + :return: + """ yolo_path = os.file_path(yolo_url, 'detect.py') - yield f"stdout: 模型推理开始,请稍等。。。 \n" # 启动子进程 with subprocess.Popen( ["python", '-u', yolo_path, - "--weights =" + weights, - "--source =" + source, - "--name=" + name, - "--project=" + project, - "--view-img"], + "--weights", weights, + "--source", source, + "--name", name, + "--project", project], bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存 shell=False, stdout=subprocess.PIPE, @@ -126,6 +153,24 @@ def run_commend(weights: str, source: str, project: str, name: str, if line != '\n': yield line + # 等待进程结束并获取返回码 + return_code = process.wait() + if return_code != 0: + pdc.update_detect_status(detect_id, -1, session) + else: + pdc.update_detect_status(detect_id, 2, session) + detect_imgs = pdc.get_img_list(detect_id, session) + detect_log_imgs = [] + for detect_img in detect_imgs: + detect_log_img = ProjectDetectLogImg() + detect_log_img.log_id = log_id + image_url = os.file_path(project, name, detect_img.file_name) + detect_log_img.image_url = image_url + detect_log_img.file_name = detect_img.file_name + detect_log_imgs.append(detect_log_img) + pdc.add_detect_imgs(detect_log_imgs, session) + + diff --git a/app/service/project_train_service.py b/app/service/project_train_service.py index d8ed200..6d1de57 100644 --- a/app/service/project_train_service.py +++ b/app/service/project_train_service.py @@ -212,7 +212,7 @@ def run_commend(data: str, project: str, train = ProjectTrain() train.project_id = project_id train.train_version = name - bast_pt_path = os.file_path(project, name, 'weights', 'bast.pt') + bast_pt_path = os.file_path(project, name, 'weights', 'best.pt') last_pt_path = os.file_path(project, name, 'weights', 'last.pt') train.best_pt = bast_pt_path train.last_pt = last_pt_path