diff --git a/app/api/business/project_detect_api.py b/app/api/business/project_detect_api.py new file mode 100644 index 0000000..c15a556 --- /dev/null +++ b/app/api/business/project_detect_api.py @@ -0,0 +1,23 @@ +from typing import List +from fastapi import APIRouter, Depends, Request, UploadFile, File, Form +from fastapi.responses import StreamingResponse +from sqlalchemy.orm import Session + +from app.common import reponse_code as rc +from app.model.crud import project_detect_crud as pdc +from app.model.schemas.project_detect_schemas import ProjectDetectPager +from app.db.db_session import get_db + +detect = APIRouter() + + +@detect.post("/detect_pager") +def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_db)): + """ + 获取训练集的照片 + :param detect_pager: + :param session: + :return: + """ + pager = pdc.get_detect_pager(detect_pager, session) + return rc.response_success_pager(pager) \ No newline at end of file diff --git a/app/api/business/project_train_api.py b/app/api/business/project_train_api.py index acd2dc0..bc9d263 100644 --- a/app/api/business/project_train_api.py +++ b/app/api/business/project_train_api.py @@ -9,7 +9,7 @@ from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, Proje from app.model.bussiness_model import ProjectLabel as pl from app.common.jwt_check import get_user_id from app.common import reponse_code as rc -from app.service import project_service as ps +from app.service import project_train_service as ps from app.db.db_session import get_db from typing import List @@ -37,7 +37,7 @@ def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)): :param session: :return: """ - pager = pic.get_project_pager(info, session) + pager = pic.get_project_pager2(info, session) return rc.response_success_pager(pager) @@ -51,7 +51,7 @@ def add_project(request: Request, info: ProjectInfoIn, session: Session = Depend :return: """ if pic.check_project_name(info.project_name, session): - return rc.response_error("已经存在相同名称的项目") + return rc.response_error("已经存在相同名称的任务") user_id = get_user_id(request) project_id = ps.add_project(info, session, user_id) return rc.response_success(msg="新建成功", data=project_id) @@ -69,6 +69,18 @@ def get_project(project_id: int, session: Session = Depends(get_db)): return rc.response_success(data=project_info.dict()) +@project.get("/del/{project_id}") +def del_project(project_id: int, session: Session = Depends(get_db)): + """ + 删除项目,假删 + :param project_id: + :param session: + :return: + """ + pic.del_project(project_id, session) + return rc.response_success(msg="删除成功") + + @project.get("/label_list/{project_id}") def get_label_list(project_id: int, session: Session = Depends(get_db)): """ @@ -147,7 +159,19 @@ def upload_project_image(project_id: int = Form(...), return rc.response_success(msg="上传成功") -@project.get("/img_list") +@project.get("/del_img/{image_id}") +def del_image(image_id: int, session: Session = Depends(get_db)): + """ + 删除图片 + :param image_id: + :param session: + :return: + """ + ps.del_img(image_id, session) + return rc.response_success("删除成功") + + +@project.post("/img_list") def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db)): """ 获取项目图片列表 @@ -160,7 +184,7 @@ def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db)) result = jsonable_encoder(image_list) return rc.response_success(data=result) else: - pager = pimc.get_image_pager(image, session) + pager = pimc.get_image_pager2(image, session) return rc.response_success_pager(pager) @@ -206,7 +230,7 @@ async def run_train(project_id: int, session: Session = Depends(get_db)): return rc.response_error("项目当前存在训练进程,请稍后再试") data, project_name, name = ps.run_train_yolo(project_info, session) return StreamingResponse( - ps.run_commend(data, project_name, name, 10, project_id, session), + ps.run_commend(data, project_name, name, 100, project_id, session), media_type="text/plain") diff --git a/app/config/application_config_dev.ini b/app/config/application_config_dev.ini index 3857c15..dd71f0e 100644 --- a/app/config/application_config_dev.ini +++ b/app/config/application_config_dev.ini @@ -1,5 +1,5 @@ [mysql] -database_url = mysql+pymysql://root:Aicheck2025@1.92.105.242:3306/aicheckv2 +database_url = mysql+pymysql://root:root@localhost:3306/aicheckv2 [redis] host = localhost @@ -13,6 +13,7 @@ dir = D:\syg\workspace\logs [yolo] datasets_url = D:\syg\yolov5\datasets runs_url = D:\syg\yolov5\runs +detect_url = D:\syg\yolov5\detect yolo_url = D:\syg\workspace\aicheckv2\yolov5 [images] diff --git a/app/config/application_config_prod.ini b/app/config/application_config_prod.ini index 1e876f5..2c30229 100644 --- a/app/config/application_config_prod.ini +++ b/app/config/application_config_prod.ini @@ -13,6 +13,7 @@ dir = /home/aicheckv2/logs [yolo] datasets_url = /home/aicheckv2/yolov5/datasets runs_url = /home/aicheckv2/yolov5/runs +detect_url = /home/aicheckv2/yolov5/detect yolo_url = /home/aicheckv2/backend/yolov5 [images] diff --git a/app/config/config_reader.py b/app/config/config_reader.py index 5f14c8a..db27cec 100644 --- a/app/config/config_reader.py +++ b/app/config/config_reader.py @@ -24,6 +24,7 @@ log_dir = config.get('log', 'dir') datasets_url = config.get('yolo', 'datasets_url') runs_url = config.get('yolo', 'runs_url') +detect_url = config.get('yolo', 'detect_url') yolo_url = config.get('yolo', 'yolo_url') images_url = config.get('images', 'image_url') diff --git a/app/model/bussiness_model.py b/app/model/bussiness_model.py index cffd8b7..8f663d4 100644 --- a/app/model/bussiness_model.py +++ b/app/model/bussiness_model.py @@ -27,6 +27,7 @@ class ProjectInfo(DbCommon): project_status: Mapped[str] = mapped_column(String(10)) user_id: Mapped[int] = mapped_column(Integer) train_version: Mapped[int] = mapped_column(Integer) + del_flag: Mapped[int] = mapped_column(Integer) class ProjectLabel(DbCommon): @@ -73,7 +74,9 @@ class ProjectImgLabel(DbCommon): class ProjectTrain(DbCommon): - """项目训练版本信息表""" + """ + 项目训练版本信息表 + """ __tablename__ = "project_train" project_id: Mapped[int] = mapped_column(Integer, nullable=False) train_version: Mapped[str] = mapped_column(String(32), nullable=False) @@ -83,9 +86,48 @@ class ProjectTrain(DbCommon): class ProjectDetect(DbCommon): """ - 训练推理集合 + 项目推理集合 """ __tablename__ = "project_detect" - project_id: Mapped[str] = mapped_column(Integer, nullable=False) + project_id: Mapped[int] = mapped_column(Integer, nullable=False) + detect_name: Mapped[str] = mapped_column(String(64), nullable=False) + detect_version: Mapped[int] = mapped_column(Integer) + detect_no: Mapped[str] = mapped_column(String(32)) + detect_status: Mapped[int] = mapped_column(Integer) + file_type: Mapped[str] = mapped_column(String(10)) + folder_url: Mapped[str] = mapped_column(String(255)) +class ProjectDetectImg(DbCommon): + """ + 推理之前的图片 + """ + __tablename__ = "project_detect_img" + detect_id: Mapped[int] = mapped_column(Integer, nullable=False) + file_name: Mapped[str] = mapped_column(String(64), nullable=False) + image_url: Mapped[str] = mapped_column(String(255), nullable=False) + thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False) + + +class ProjectDetectLog(DbCommon): + """ + 项目推理记录 + """ + __tablename__ = "project_detect_log" + detect_id: Mapped[int] = mapped_column(Integer, nullable=False) + detect_version: Mapped[str] = mapped_column(String(10)) + train_id: Mapped[int] = mapped_column(Integer, nullable=False) + train_version: Mapped[str] = mapped_column(String(10)) + pt_type: Mapped[str] = mapped_column(String(32)) + folder_url: Mapped[str] = mapped_column(String(255)) + detect_folder_url: Mapped[str] = mapped_column(String(255)) + + +class ProjectDetectLogImg(DbCommon): + """ + 推理完成的图片 + """ + __tablename__ = "project_detect_log_img" + log_id: Mapped[int] = mapped_column(Integer, nullable=False) + file_name: Mapped[str] = mapped_column(String(64), nullable=False) + image_url: Mapped[str] = mapped_column(String(255), nullable=False) diff --git a/app/model/crud/project_detect_crud.py b/app/model/crud/project_detect_crud.py new file mode 100644 index 0000000..d4fc0de --- /dev/null +++ b/app/model/crud/project_detect_crud.py @@ -0,0 +1,88 @@ +from typing import List +from sqlalchemy.orm import Session +from sqlalchemy import asc, and_ +from fastapi.encoders import jsonable_encoder +from fastapi import UploadFile + +from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg +from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager +from app.db.page_util import get_pager + + +def add_detect(detect: ProjectDetect, session: Session): + """ + 新增推理集合 + :param detect: + :param session: + :return: + """ + session.add(detect) + session.commit() + return detect + + +def get_detect_pager(detect_pager: ProjectDetectPager, session: Session): + """ + 查询推理集合分页数据 + :param detect_pager: + :param session: + :return: + """ + query = session.query(ProjectDetect).order_by(asc(ProjectDetect.id)) + filters = [ProjectDetect.project_id == detect_pager.project_id] + if detect_pager.detect_name is not None: + filters.append(ProjectDetect.detect_name.ilike(f"%{detect_pager.detect_name}%")) + query = query.filter(and_(*filters)) + pager = get_pager(query, detect_pager.pagerNum, detect_pager.pagerSize) + pager.data = [ProjectDetectOut.from_orm(user) for user in pager.data] + pager.data = jsonable_encoder(pager.data) + return pager + + +def check_img_name(detect_id: int, file_name: str, session: Session): + """ + 校验上传的图片名称是否重名 + :param detect_id: + :param file_name: + :param session: + :return: + """ + image = session.query(ProjectDetectImg).filter_by(detect_id=detect_id).filter_by(file_name=file_name).first() + if image is None: + return True, None + else: + return False, image.file_name + + +def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session): + """ + 添加推理集合中的图片 + :param detect_imgs: + :param session: + :return: + """ + session.add_all(detect_imgs) + session.commit() + + +def add_detect_log(detect_log: ProjectDetectLog, session: Session): + """ + 新增推理记录 + :param detect_log: + :param session: + :return: + """ + session.add(detect_log) + session.commit() + return detect_log + + +def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Session): + """ + 新增推理记录中的照片 + :param detect_log_imgs: + :param session: + :return: + """ + session.add_all(detect_log_imgs) + session.commit() \ No newline at end of file diff --git a/app/model/crud/project_image_crud.py b/app/model/crud/project_image_crud.py index 22a90e8..aeb8316 100644 --- a/app/model/crud/project_image_crud.py +++ b/app/model/crud/project_image_crud.py @@ -1,9 +1,10 @@ from sqlalchemy.orm import Session -from sqlalchemy import asc +from sqlalchemy import asc, func from typing import List +from fastapi.encoders import jsonable_encoder -from app.model.bussiness_model import ProjectImage as piModel -from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager +from app.model.bussiness_model import ProjectImage as piModel, ProjectImgLabel +from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager, ProjectImageOut from app.db.page_util import get_pager @@ -14,6 +15,35 @@ def get_image_pager(image: ProjectImagePager, session: Session): return pager +def get_image_pager2(image: ProjectImage, session: Session): + # 1 子查询 + subquery = ( + session.query( + ProjectImgLabel.image_id, + func.ifnull(func.count(ProjectImgLabel.id), 0).label('label_count') + ) + .group_by(ProjectImgLabel.image_id) + .subquery() + ) + # 2 主查询 + query = ( + session.query( + piModel, + func.ifnull(subquery.c.label_count, 0).label('label_count') + ) + .outerjoin(subquery, piModel.id == subquery.c.image_id) + ) + query = query.filter(piModel.project_id == image.project_id).order_by(asc(piModel.id)) + pager = get_pager(query, image.pagerNum, image.pagerSize) + datas = [] + for result in pager.data: + data = ProjectImageOut.from_orm(result[0]) + data.label_count = result[1] + datas.append(data) + pager.data = jsonable_encoder(datas) + return pager + + def check_img_name(project_id: int, file_name: str, session: Session): """ 根据项目id和文件名称进行查重 diff --git a/app/model/crud/project_img_leafer_label_crud.py b/app/model/crud/project_img_leafer_label_crud.py index bd1ce28..5887618 100644 --- a/app/model/crud/project_img_leafer_label_crud.py +++ b/app/model/crud/project_img_leafer_label_crud.py @@ -23,7 +23,7 @@ def get_img_label_list(image_id, session: Session): :param session: :return: """ - img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).all() + img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).order_by(ProjectImgLabel.label_id).all() return img_label_list diff --git a/app/model/crud/project_info_crud.py b/app/model/crud/project_info_crud.py index 2804ccd..2904b6e 100644 --- a/app/model/crud/project_info_crud.py +++ b/app/model/crud/project_info_crud.py @@ -1,33 +1,72 @@ from sqlalchemy.orm import Session -from sqlalchemy import desc, update +from sqlalchemy import desc, and_, func, case +from fastapi.encoders import jsonable_encoder -from app.model.bussiness_model import ProjectInfo -from app.model.schemas.project_info_schemas import ProjectInfoOut -from app.model.schemas.project_info_schemas import ProjectInfoPager +from app.model.bussiness_model import ProjectInfo, ProjectImage, ProjectImgLeafer +from app.model.schemas.project_info_schemas import ProjectInfoOut, ProjectInfoPager, ProjectInfoPagerOut from app.db.page_util import get_pager def get_project_pager(info: ProjectInfoPager, session: Session): """分页查询项目信息""" query = session.query(ProjectInfo).order_by(desc(ProjectInfo.id)) - filters = [] - if info.project_name is not None: + filters = [ProjectInfo.del_flag == 0] + if info.project_name is not None and info.project_name != '': filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%")) if len(filters) > 0: - query.filter(*filters) + query = query.filter(and_(*filters)) pager = get_pager(query, info.pagerNum, info.pagerSize) pager.data = [ProjectInfoOut.from_orm(info).dict() for info in pager.data] return pager +def get_project_pager2(info: ProjectInfoPager, session: Session): + # 1. 定义子查询 + subquery = ( + session.query( + ProjectImage.project_id, + func.sum(case((ProjectImgLeafer.id.is_(None), 1), else_=0)).label('no_mark_count'), + func.sum(case((ProjectImgLeafer.id.isnot(None), 1), else_=0)).label('mark_count') + ) + .outerjoin(ProjectImgLeafer, ProjectImage.id == ProjectImgLeafer.image_id) + .group_by(ProjectImage.project_id) + .subquery() + ) + + # 2. 主查询 + query = ( + session.query( + ProjectInfo, + func.ifnull(subquery.c.mark_count, 0).label('mark_count'), + func.ifnull(subquery.c.no_mark_count, 0).label('no_mark_count') + ) + .outerjoin(subquery, ProjectInfo.id == subquery.c.project_id) + ) + query = query.order_by(desc(ProjectInfo.id)) + filters = [ProjectInfo.del_flag == 0] + if info.project_name is not None and info.project_name != '': + filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%")) + query = query.filter(and_(*filters)) + pager = get_pager(query, info.pagerNum, info.pagerSize) + datas = [] + for result in pager.data: + data = ProjectInfoPagerOut.from_orm(result[0]) + data.mark_count = result[1] + data.no_mark_count = result[2] + datas.append(data) + pager.data = jsonable_encoder(datas) + return pager + + def get_project_by_id(project_id: str, session: Session): - info = session.query(ProjectInfo).filter_by(id=project_id).first() + info = session.query(ProjectInfo).filter_by(id=project_id).filter_by(del_flag=0).first() info_out = ProjectInfoOut.from_orm(info) return info_out def add_project(info: ProjectInfo, session: Session): """新建项目,并在对应文件夹下面创建文件夹""" + info.del_flag = 0 session.add(info) session.commit() return info @@ -35,7 +74,8 @@ def add_project(info: ProjectInfo, session: Session): def check_project_name(project_name: str, session: Session): """检验是否存在重名的项目名称""" - count = session.query(ProjectInfo).filter(ProjectInfo.project_name == project_name).count() + count = session.query(ProjectInfo).filter(ProjectInfo.project_name == project_name)\ + .filter(ProjectInfo.del_flag==0).count() if count > 0: return True else: @@ -60,3 +100,16 @@ def update_project_status(project_id: int, project_status: str, session: Session 'project_status': project_status }) session.commit() + + +def del_project(project_id: int, session: Session): + """ + 删除项目 + :param project_id: + :param session: + :return: + """ + session.query(ProjectInfo).filter_by(id=project_id).update({ + 'del_flag': 1 + }) + session.commit() diff --git a/app/model/crud/project_label_crud.py b/app/model/crud/project_label_crud.py index 45e4970..4583cc8 100644 --- a/app/model/crud/project_label_crud.py +++ b/app/model/crud/project_label_crud.py @@ -20,7 +20,7 @@ def get_label_list(project_id: int, session: Session): def get_label_for_train(project_id: int, session: Session): id_list = [] name_list = [] - label_list = session.query(plModel).filter(plModel.project_id == project_id).all() + label_list = session.query(plModel).filter(plModel.project_id == project_id).order_by(plModel.id).all() for label in label_list: id_list.append(label.id) name_list.append(label.label_name) diff --git a/app/model/crud/project_train_crud.py b/app/model/crud/project_train_crud.py index 41ee35b..30346cb 100644 --- a/app/model/crud/project_train_crud.py +++ b/app/model/crud/project_train_crud.py @@ -26,3 +26,14 @@ def get_train_list(project_id: int, session: Session): query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id)) train_list = [ProjectTrainOut.from_orm(train) for train in query.all()] return train_list + + +def get_train(train_id: int, session: Session): + """ + 根据id查询训练信息 + :param train_id: + :param session: + :return: + """ + train = session.query(ProjectTrain).filter_by(id=train_id).first() + return train diff --git a/app/model/crud/project_type_crud.py b/app/model/crud/project_type_crud.py index 291687f..eeae2fa 100644 --- a/app/model/crud/project_type_crud.py +++ b/app/model/crud/project_type_crud.py @@ -7,7 +7,6 @@ from sqlalchemy.orm import Session def get_list(session: Session): """获取项目类型列表""" - query = session.query(ProjectType).order_by(asc(ProjectType.id)) - query.filter(ProjectType.type_status == "0") + query = session.query(ProjectType).filter(ProjectType.type_status == "0").order_by(asc(ProjectType.id)) result_list = [ProjectTypeOut.from_orm(project_type).dict() for project_type in query.all()] return result_list diff --git a/app/model/schemas/project_detect_schemas.py b/app/model/schemas/project_detect_schemas.py new file mode 100644 index 0000000..f3cdf46 --- /dev/null +++ b/app/model/schemas/project_detect_schemas.py @@ -0,0 +1,69 @@ +from pydantic import BaseModel, Field +from typing import Optional +from datetime import datetime + + +class ProjectDetectIn(BaseModel): + project_id: Optional[int] = Field(..., description="项目id") + file_type: Optional[str] = Field('img', description="推理集合文件类别") + detect_name: Optional[str] = Field(..., description="推理集合名称") + + +class ProjectDetectPager(BaseModel): + project_id: Optional[int] = Field(..., description="项目id") + detect_name: Optional[str] = Field(None, description="推理集合名称") + pagerNum: Optional[int] = Field(1, description="当前页码") + pagerSize: Optional[int] = Field(10, description="每页数量") + + +class ProjectDetectOut(BaseModel): + id: Optional[int] + project_id: Optional[int] + detect_name: Optional[int] + detect_no: Optional[str] + detect_version: Optional[int] + file_type: Optional[str] + folder_url: Optional[str] + create_time: Optional[datetime] + + class Config: + orm_mode = True + json_encoders = { + datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") + } + + +class ProjectDetectLogIn(BaseModel): + detect_id: Optional[int] = Field(..., description="推理集合id") + train_id: Optional[int] = Field(..., description="训练结果id") + pt_type: Optional[str] = Field('best', description="权重文件类型") + + +class ProjectDetectLogOut(BaseModel): + id: Optional[int] + detect_id: Optional[int] + detect_version: Optional[str] + train_id: Optional[int] + train_version: Optional[int] + pt_type: Optional[str] + create_time: Optional[datetime] + + class Config: + orm_mode = True + json_encoders = { + datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") + } + + +class ProjectDetectLogImgOut(BaseModel): + id: Optional[int] + file_name: Optional[str] + create_time: Optional[datetime] + + class Config: + orm_mode = True + json_encoders = { + datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") + } + + diff --git a/app/model/schemas/project_image_schemas.py b/app/model/schemas/project_image_schemas.py index 2c117c1..bcc8e79 100644 --- a/app/model/schemas/project_image_schemas.py +++ b/app/model/schemas/project_image_schemas.py @@ -16,10 +16,24 @@ class ProjectImage(BaseModel): } +class ProjectImageOut(BaseModel): + id: Optional[int] = Field(None, description="id") + project_id: Optional[int] = Field(..., description="项目id") + file_name: Optional[str] = Field(None, description="文件名称") + create_time: Optional[datetime] = Field(None, description="上传时间") + label_count: Optional[int] + + class Config: + orm_mode = True + json_encoders = { + datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") + } + + class ProjectImagePager(BaseModel): project_id: Optional[int] = Field(..., description="项目id") - pagerNum: Optional[int] = Field(1, description="当前页码") - pagerSize: Optional[int] = Field(10, description="每页数量") + pagerNum: Optional[int] = Field(None, description="当前页码") + pagerSize: Optional[int] = Field(None, description="每页数量") class ProjectImgLabelIn(BaseModel): diff --git a/app/model/schemas/project_info_schemas.py b/app/model/schemas/project_info_schemas.py index 7c489b5..64c37d7 100644 --- a/app/model/schemas/project_info_schemas.py +++ b/app/model/schemas/project_info_schemas.py @@ -24,6 +24,22 @@ class ProjectInfoOut(BaseModel): orm_mode = True +class ProjectInfoPagerOut(BaseModel): + """项目信息输出""" + id: Optional[int] = Field(None, description="项目id") + project_no: Optional[str] = Field(None, description="项目编号") + project_name: Optional[str] = Field(None, description="项目名称") + type_code: Optional[str] = Field(None, description="项目类型编码") + description: Optional[str] = Field(None, description="项目描述") + train_version: Optional[int] = Field(None, description="训练版本号") + project_status: Optional[str] = Field(None, description="项目状态") + mark_count: Optional[int] + no_mark_count: Optional[int] + + class Config: + orm_mode = True + + class ProjectInfoPager(BaseModel): project_name: Optional[str] = Field(None, description="项目名称") pagerNum: Optional[int] = Field(1, description="当前页码") diff --git a/app/model/schemas/project_train_schemas.py b/app/model/schemas/project_train_schemas.py index 7f38c77..e2a736a 100644 --- a/app/model/schemas/project_train_schemas.py +++ b/app/model/schemas/project_train_schemas.py @@ -13,4 +13,4 @@ class ProjectTrainOut(BaseModel): orm_mode = True json_encoders = { datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") - } \ No newline at end of file + } diff --git a/app/service/project_detect_service.py b/app/service/project_detect_service.py new file mode 100644 index 0000000..aa3268c --- /dev/null +++ b/app/service/project_detect_service.py @@ -0,0 +1,132 @@ +from sqlalchemy.orm import Session +from typing import List +from fastapi import UploadFile +import subprocess + +from app.model.crud import project_detect_crud as pdc +from app.model.schemas.project_detect_schemas import ProjectDetectIn, ProjectDetectOut, ProjectDetectLogIn +from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectTrain, ProjectDetectLog +from app.util.random_utils import random_str +from app.config.config_reader import detect_url +from app.util import os_utils as os +from app.util import random_utils as ru +from app.config.config_reader import yolo_url + + +def add_detect(detect_in: ProjectDetectIn, session: Session): + """ + 新增训练集合信息,并创建文件夹 + :param detect_in: + :param session: + :return: + """ + detect = ProjectDetect(**detect_in.dict()) + detect.detect_no = random_str(6) + detect.detect_version = 0 + url = os.create_folder(detect_url, detect.detect_no, 'images') + detect.folder_url = url + detect = pdc.add_detect(detect, session) + return detect + + +def check_image_name(detect_id: int, files: List[UploadFile], session: Session): + """ + 校验上传的文件名称是否重复 + :param detect_id: + :param files: + :param session: + :return: + """ + for file in files: + if not pdc.check_img_name(detect_id, file.filename, session): + return False, file.filename + return True, None + + +def upload_detect_imgs(detect: ProjectDetectOut, files: List[UploadFile], session: Session): + """ + 上传推理集合的照片,保存原图,并生成缩略图 + :param detect: + :param files: + :param session: + :return: + """ + images = [] + for file in files: + image = ProjectDetectImg() + image.detect_id = detect.id + image.file_name = file.filename + # 保存原图 + path = os.save_images(detect.folder_url, file=file) + image.image_url = path + # 生成缩略图 + thumb_image_url = os.file_path(detect.folder_url, 'thumb', ru.random_str(10) + ".jpg") + os.create_thumbnail(path, thumb_image_url) + image.thumb_image_url = thumb_image_url + images.append(image) + pdc.add_detect_imgs(images, session) + + +def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: ProjectTrain, session: Session): + """ + 开始推理 + :param detect: + :param detect_in: + :param train: + :param session: + :return: + """ + # 推理版本 + version_path = 'v' + str(detect.detect_version + 1) + + # 权重文件 + pt_url = train.best_pt if detect_in.pt_type == 'best' else train.last_pt + + # 推理集合文件路径 + img_url = detect.folder_url + + out_url = os.file_path(detect_url, detect.detect_no, 'detect') + + # 构建推理记录数据 + detect_log = ProjectDetectLog() + detect_log.detect_id = detect.id + detect_log.detect_version = version_path + detect_log.train_id = train.id + detect_log.train_version = train.train_version + detect_log.pt_type = detect_in.pt_type + detect_log.folder_url = detect.folder_url + detect_log.detect_folder_url = out_url + detect_log = pdc.add_detect_log(detect_log, session) + return detect_log + + +def run_commend(weights: str, source: str, project: str, name: str, + detect_log_id: int, session: Session): + yolo_path = os.file_path(yolo_url, 'detect.py') + + yield f"stdout: 模型推理开始,请稍等。。。 \n" + # 启动子进程 + with subprocess.Popen( + ["python", '-u', yolo_path, + "--weights =" + weights, + "--source =" + source, + "--name=" + name, + "--project=" + project, + "--view-img"], + bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存 + shell=False, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息 + text=True, # 缓存内容为文本,避免后续编码显示问题 + encoding='utf-8', + ) as process: + while process.poll() is None: + line = process.stdout.readline() + process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死 + if line != '\n': + yield line + + + + + diff --git a/app/service/project_service.py b/app/service/project_train_service.py similarity index 93% rename from app/service/project_service.py rename to app/service/project_train_service.py index 05b1c21..d8ed200 100644 --- a/app/service/project_service.py +++ b/app/service/project_train_service.py @@ -67,6 +67,22 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], pimc.add_image_batch(images, session) +def del_img(image_id: int, session: Session): + """ + 删除图片,并删除文件 + :param image_id: + :param session: + :return: + """ + image = session.query(ProjectImage).filter_by(id=image_id).first() + if image is None: + return 0 + os.delete_file_if_exists(image.image_url, image.thumb_image_url) + session.delete(image) + session.commit() + + + def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session): """ 保存图片的标签框选信息,每次保存都会针对图片的信息全部删除,然后重新保存 @@ -164,7 +180,8 @@ def run_commend(data: str, project: str, name: str, epochs: int, project_id: int, session: Session): yolo_path = os.file_path(yolo_url, 'train.py') - yield f"stdout: 模型训练开始,请稍等。。。" + + yield f"stdout: 模型训练开始,请稍等。。。\n" # 启动子进程 with subprocess.Popen( ["python", '-u', yolo_path, @@ -195,8 +212,8 @@ def run_commend(data: str, project: str, train = ProjectTrain() train.project_id = project_id train.train_version = name - bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt') - last_pt_path = os.file_path(project, name, 'weight', 'last.pt') + bast_pt_path = os.file_path(project, name, 'weights', 'bast.pt') + last_pt_path = os.file_path(project, name, 'weights', 'last.pt') train.best_pt = bast_pt_path train.last_pt = last_pt_path ptc.add_train(train, session) diff --git a/app/util/os_utils.py b/app/util/os_utils.py index 0788c33..378f0c9 100644 --- a/app/util/os_utils.py +++ b/app/util/os_utils.py @@ -79,3 +79,14 @@ def copy_and_rename_file(src_file_path, dst_dir, new_name): # 复制文件到目标位置并重命名 shutil.copy(src_file_path, dst_file_path) + + +def delete_file_if_exists(*file_paths: str): + """ + 删除文件 + :param file_path: + :return: + """ + for path in file_paths: + if os.path.exists(path): # 检查文件是否存在 + os.remove(path) # 删除文件