完善相关问题,并增加推理的部分代码
This commit is contained in:
23
app/api/business/project_detect_api.py
Normal file
23
app/api/business/project_detect_api.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.common import reponse_code as rc
|
||||||
|
from app.model.crud import project_detect_crud as pdc
|
||||||
|
from app.model.schemas.project_detect_schemas import ProjectDetectPager
|
||||||
|
from app.db.db_session import get_db
|
||||||
|
|
||||||
|
detect = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@detect.post("/detect_pager")
|
||||||
|
def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
获取训练集的照片
|
||||||
|
:param detect_pager:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
pager = pdc.get_detect_pager(detect_pager, session)
|
||||||
|
return rc.response_success_pager(pager)
|
@ -9,7 +9,7 @@ from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, Proje
|
|||||||
from app.model.bussiness_model import ProjectLabel as pl
|
from app.model.bussiness_model import ProjectLabel as pl
|
||||||
from app.common.jwt_check import get_user_id
|
from app.common.jwt_check import get_user_id
|
||||||
from app.common import reponse_code as rc
|
from app.common import reponse_code as rc
|
||||||
from app.service import project_service as ps
|
from app.service import project_train_service as ps
|
||||||
from app.db.db_session import get_db
|
from app.db.db_session import get_db
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
@ -37,7 +37,7 @@ def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
|
|||||||
:param session:
|
:param session:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
pager = pic.get_project_pager(info, session)
|
pager = pic.get_project_pager2(info, session)
|
||||||
return rc.response_success_pager(pager)
|
return rc.response_success_pager(pager)
|
||||||
|
|
||||||
|
|
||||||
@ -51,7 +51,7 @@ def add_project(request: Request, info: ProjectInfoIn, session: Session = Depend
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if pic.check_project_name(info.project_name, session):
|
if pic.check_project_name(info.project_name, session):
|
||||||
return rc.response_error("已经存在相同名称的项目")
|
return rc.response_error("已经存在相同名称的任务")
|
||||||
user_id = get_user_id(request)
|
user_id = get_user_id(request)
|
||||||
project_id = ps.add_project(info, session, user_id)
|
project_id = ps.add_project(info, session, user_id)
|
||||||
return rc.response_success(msg="新建成功", data=project_id)
|
return rc.response_success(msg="新建成功", data=project_id)
|
||||||
@ -69,6 +69,18 @@ def get_project(project_id: int, session: Session = Depends(get_db)):
|
|||||||
return rc.response_success(data=project_info.dict())
|
return rc.response_success(data=project_info.dict())
|
||||||
|
|
||||||
|
|
||||||
|
@project.get("/del/{project_id}")
|
||||||
|
def del_project(project_id: int, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
删除项目,假删
|
||||||
|
:param project_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
pic.del_project(project_id, session)
|
||||||
|
return rc.response_success(msg="删除成功")
|
||||||
|
|
||||||
|
|
||||||
@project.get("/label_list/{project_id}")
|
@project.get("/label_list/{project_id}")
|
||||||
def get_label_list(project_id: int, session: Session = Depends(get_db)):
|
def get_label_list(project_id: int, session: Session = Depends(get_db)):
|
||||||
"""
|
"""
|
||||||
@ -147,7 +159,19 @@ def upload_project_image(project_id: int = Form(...),
|
|||||||
return rc.response_success(msg="上传成功")
|
return rc.response_success(msg="上传成功")
|
||||||
|
|
||||||
|
|
||||||
@project.get("/img_list")
|
@project.get("/del_img/{image_id}")
|
||||||
|
def del_image(image_id: int, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
删除图片
|
||||||
|
:param image_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
ps.del_img(image_id, session)
|
||||||
|
return rc.response_success("删除成功")
|
||||||
|
|
||||||
|
|
||||||
|
@project.post("/img_list")
|
||||||
def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db)):
|
def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db)):
|
||||||
"""
|
"""
|
||||||
获取项目图片列表
|
获取项目图片列表
|
||||||
@ -160,7 +184,7 @@ def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db))
|
|||||||
result = jsonable_encoder(image_list)
|
result = jsonable_encoder(image_list)
|
||||||
return rc.response_success(data=result)
|
return rc.response_success(data=result)
|
||||||
else:
|
else:
|
||||||
pager = pimc.get_image_pager(image, session)
|
pager = pimc.get_image_pager2(image, session)
|
||||||
return rc.response_success_pager(pager)
|
return rc.response_success_pager(pager)
|
||||||
|
|
||||||
|
|
||||||
@ -206,7 +230,7 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
|
|||||||
return rc.response_error("项目当前存在训练进程,请稍后再试")
|
return rc.response_error("项目当前存在训练进程,请稍后再试")
|
||||||
data, project_name, name = ps.run_train_yolo(project_info, session)
|
data, project_name, name = ps.run_train_yolo(project_info, session)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
ps.run_commend(data, project_name, name, 10, project_id, session),
|
ps.run_commend(data, project_name, name, 100, project_id, session),
|
||||||
media_type="text/plain")
|
media_type="text/plain")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
[mysql]
|
[mysql]
|
||||||
database_url = mysql+pymysql://root:Aicheck2025@1.92.105.242:3306/aicheckv2
|
database_url = mysql+pymysql://root:root@localhost:3306/aicheckv2
|
||||||
|
|
||||||
[redis]
|
[redis]
|
||||||
host = localhost
|
host = localhost
|
||||||
@ -13,6 +13,7 @@ dir = D:\syg\workspace\logs
|
|||||||
[yolo]
|
[yolo]
|
||||||
datasets_url = D:\syg\yolov5\datasets
|
datasets_url = D:\syg\yolov5\datasets
|
||||||
runs_url = D:\syg\yolov5\runs
|
runs_url = D:\syg\yolov5\runs
|
||||||
|
detect_url = D:\syg\yolov5\detect
|
||||||
yolo_url = D:\syg\workspace\aicheckv2\yolov5
|
yolo_url = D:\syg\workspace\aicheckv2\yolov5
|
||||||
|
|
||||||
[images]
|
[images]
|
||||||
|
@ -13,6 +13,7 @@ dir = /home/aicheckv2/logs
|
|||||||
[yolo]
|
[yolo]
|
||||||
datasets_url = /home/aicheckv2/yolov5/datasets
|
datasets_url = /home/aicheckv2/yolov5/datasets
|
||||||
runs_url = /home/aicheckv2/yolov5/runs
|
runs_url = /home/aicheckv2/yolov5/runs
|
||||||
|
detect_url = /home/aicheckv2/yolov5/detect
|
||||||
yolo_url = /home/aicheckv2/backend/yolov5
|
yolo_url = /home/aicheckv2/backend/yolov5
|
||||||
|
|
||||||
[images]
|
[images]
|
||||||
|
@ -24,6 +24,7 @@ log_dir = config.get('log', 'dir')
|
|||||||
|
|
||||||
datasets_url = config.get('yolo', 'datasets_url')
|
datasets_url = config.get('yolo', 'datasets_url')
|
||||||
runs_url = config.get('yolo', 'runs_url')
|
runs_url = config.get('yolo', 'runs_url')
|
||||||
|
detect_url = config.get('yolo', 'detect_url')
|
||||||
yolo_url = config.get('yolo', 'yolo_url')
|
yolo_url = config.get('yolo', 'yolo_url')
|
||||||
|
|
||||||
images_url = config.get('images', 'image_url')
|
images_url = config.get('images', 'image_url')
|
||||||
|
@ -27,6 +27,7 @@ class ProjectInfo(DbCommon):
|
|||||||
project_status: Mapped[str] = mapped_column(String(10))
|
project_status: Mapped[str] = mapped_column(String(10))
|
||||||
user_id: Mapped[int] = mapped_column(Integer)
|
user_id: Mapped[int] = mapped_column(Integer)
|
||||||
train_version: Mapped[int] = mapped_column(Integer)
|
train_version: Mapped[int] = mapped_column(Integer)
|
||||||
|
del_flag: Mapped[int] = mapped_column(Integer)
|
||||||
|
|
||||||
|
|
||||||
class ProjectLabel(DbCommon):
|
class ProjectLabel(DbCommon):
|
||||||
@ -73,7 +74,9 @@ class ProjectImgLabel(DbCommon):
|
|||||||
|
|
||||||
|
|
||||||
class ProjectTrain(DbCommon):
|
class ProjectTrain(DbCommon):
|
||||||
"""项目训练版本信息表"""
|
"""
|
||||||
|
项目训练版本信息表
|
||||||
|
"""
|
||||||
__tablename__ = "project_train"
|
__tablename__ = "project_train"
|
||||||
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
@ -83,9 +86,48 @@ class ProjectTrain(DbCommon):
|
|||||||
|
|
||||||
class ProjectDetect(DbCommon):
|
class ProjectDetect(DbCommon):
|
||||||
"""
|
"""
|
||||||
训练推理集合
|
项目推理集合
|
||||||
"""
|
"""
|
||||||
__tablename__ = "project_detect"
|
__tablename__ = "project_detect"
|
||||||
project_id: Mapped[str] = mapped_column(Integer, nullable=False)
|
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
detect_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
detect_version: Mapped[int] = mapped_column(Integer)
|
||||||
|
detect_no: Mapped[str] = mapped_column(String(32))
|
||||||
|
detect_status: Mapped[int] = mapped_column(Integer)
|
||||||
|
file_type: Mapped[str] = mapped_column(String(10))
|
||||||
|
folder_url: Mapped[str] = mapped_column(String(255))
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectImg(DbCommon):
|
||||||
|
"""
|
||||||
|
推理之前的图片
|
||||||
|
"""
|
||||||
|
__tablename__ = "project_detect_img"
|
||||||
|
detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectLog(DbCommon):
|
||||||
|
"""
|
||||||
|
项目推理记录
|
||||||
|
"""
|
||||||
|
__tablename__ = "project_detect_log"
|
||||||
|
detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
detect_version: Mapped[str] = mapped_column(String(10))
|
||||||
|
train_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
train_version: Mapped[str] = mapped_column(String(10))
|
||||||
|
pt_type: Mapped[str] = mapped_column(String(32))
|
||||||
|
folder_url: Mapped[str] = mapped_column(String(255))
|
||||||
|
detect_folder_url: Mapped[str] = mapped_column(String(255))
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectLogImg(DbCommon):
|
||||||
|
"""
|
||||||
|
推理完成的图片
|
||||||
|
"""
|
||||||
|
__tablename__ = "project_detect_log_img"
|
||||||
|
log_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
88
app/model/crud/project_detect_crud.py
Normal file
88
app/model/crud/project_detect_crud.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from typing import List
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import asc, and_
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi import UploadFile
|
||||||
|
|
||||||
|
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
|
||||||
|
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager
|
||||||
|
from app.db.page_util import get_pager
|
||||||
|
|
||||||
|
|
||||||
|
def add_detect(detect: ProjectDetect, session: Session):
|
||||||
|
"""
|
||||||
|
新增推理集合
|
||||||
|
:param detect:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
session.add(detect)
|
||||||
|
session.commit()
|
||||||
|
return detect
|
||||||
|
|
||||||
|
|
||||||
|
def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
|
||||||
|
"""
|
||||||
|
查询推理集合分页数据
|
||||||
|
:param detect_pager:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
query = session.query(ProjectDetect).order_by(asc(ProjectDetect.id))
|
||||||
|
filters = [ProjectDetect.project_id == detect_pager.project_id]
|
||||||
|
if detect_pager.detect_name is not None:
|
||||||
|
filters.append(ProjectDetect.detect_name.ilike(f"%{detect_pager.detect_name}%"))
|
||||||
|
query = query.filter(and_(*filters))
|
||||||
|
pager = get_pager(query, detect_pager.pagerNum, detect_pager.pagerSize)
|
||||||
|
pager.data = [ProjectDetectOut.from_orm(user) for user in pager.data]
|
||||||
|
pager.data = jsonable_encoder(pager.data)
|
||||||
|
return pager
|
||||||
|
|
||||||
|
|
||||||
|
def check_img_name(detect_id: int, file_name: str, session: Session):
|
||||||
|
"""
|
||||||
|
校验上传的图片名称是否重名
|
||||||
|
:param detect_id:
|
||||||
|
:param file_name:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
image = session.query(ProjectDetectImg).filter_by(detect_id=detect_id).filter_by(file_name=file_name).first()
|
||||||
|
if image is None:
|
||||||
|
return True, None
|
||||||
|
else:
|
||||||
|
return False, image.file_name
|
||||||
|
|
||||||
|
|
||||||
|
def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session):
|
||||||
|
"""
|
||||||
|
添加推理集合中的图片
|
||||||
|
:param detect_imgs:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
session.add_all(detect_imgs)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def add_detect_log(detect_log: ProjectDetectLog, session: Session):
|
||||||
|
"""
|
||||||
|
新增推理记录
|
||||||
|
:param detect_log:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
session.add(detect_log)
|
||||||
|
session.commit()
|
||||||
|
return detect_log
|
||||||
|
|
||||||
|
|
||||||
|
def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Session):
|
||||||
|
"""
|
||||||
|
新增推理记录中的照片
|
||||||
|
:param detect_log_imgs:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
session.add_all(detect_log_imgs)
|
||||||
|
session.commit()
|
@ -1,9 +1,10 @@
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import asc
|
from sqlalchemy import asc, func
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
|
||||||
from app.model.bussiness_model import ProjectImage as piModel
|
from app.model.bussiness_model import ProjectImage as piModel, ProjectImgLabel
|
||||||
from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager
|
from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager, ProjectImageOut
|
||||||
from app.db.page_util import get_pager
|
from app.db.page_util import get_pager
|
||||||
|
|
||||||
|
|
||||||
@ -14,6 +15,35 @@ def get_image_pager(image: ProjectImagePager, session: Session):
|
|||||||
return pager
|
return pager
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_pager2(image: ProjectImage, session: Session):
|
||||||
|
# 1 子查询
|
||||||
|
subquery = (
|
||||||
|
session.query(
|
||||||
|
ProjectImgLabel.image_id,
|
||||||
|
func.ifnull(func.count(ProjectImgLabel.id), 0).label('label_count')
|
||||||
|
)
|
||||||
|
.group_by(ProjectImgLabel.image_id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
# 2 主查询
|
||||||
|
query = (
|
||||||
|
session.query(
|
||||||
|
piModel,
|
||||||
|
func.ifnull(subquery.c.label_count, 0).label('label_count')
|
||||||
|
)
|
||||||
|
.outerjoin(subquery, piModel.id == subquery.c.image_id)
|
||||||
|
)
|
||||||
|
query = query.filter(piModel.project_id == image.project_id).order_by(asc(piModel.id))
|
||||||
|
pager = get_pager(query, image.pagerNum, image.pagerSize)
|
||||||
|
datas = []
|
||||||
|
for result in pager.data:
|
||||||
|
data = ProjectImageOut.from_orm(result[0])
|
||||||
|
data.label_count = result[1]
|
||||||
|
datas.append(data)
|
||||||
|
pager.data = jsonable_encoder(datas)
|
||||||
|
return pager
|
||||||
|
|
||||||
|
|
||||||
def check_img_name(project_id: int, file_name: str, session: Session):
|
def check_img_name(project_id: int, file_name: str, session: Session):
|
||||||
"""
|
"""
|
||||||
根据项目id和文件名称进行查重
|
根据项目id和文件名称进行查重
|
||||||
|
@ -23,7 +23,7 @@ def get_img_label_list(image_id, session: Session):
|
|||||||
:param session:
|
:param session:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).all()
|
img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).order_by(ProjectImgLabel.label_id).all()
|
||||||
return img_label_list
|
return img_label_list
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,33 +1,72 @@
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import desc, update
|
from sqlalchemy import desc, and_, func, case
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
|
||||||
from app.model.bussiness_model import ProjectInfo
|
from app.model.bussiness_model import ProjectInfo, ProjectImage, ProjectImgLeafer
|
||||||
from app.model.schemas.project_info_schemas import ProjectInfoOut
|
from app.model.schemas.project_info_schemas import ProjectInfoOut, ProjectInfoPager, ProjectInfoPagerOut
|
||||||
from app.model.schemas.project_info_schemas import ProjectInfoPager
|
|
||||||
from app.db.page_util import get_pager
|
from app.db.page_util import get_pager
|
||||||
|
|
||||||
|
|
||||||
def get_project_pager(info: ProjectInfoPager, session: Session):
|
def get_project_pager(info: ProjectInfoPager, session: Session):
|
||||||
"""分页查询项目信息"""
|
"""分页查询项目信息"""
|
||||||
query = session.query(ProjectInfo).order_by(desc(ProjectInfo.id))
|
query = session.query(ProjectInfo).order_by(desc(ProjectInfo.id))
|
||||||
filters = []
|
filters = [ProjectInfo.del_flag == 0]
|
||||||
if info.project_name is not None:
|
if info.project_name is not None and info.project_name != '':
|
||||||
filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
|
filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
|
||||||
if len(filters) > 0:
|
if len(filters) > 0:
|
||||||
query.filter(*filters)
|
query = query.filter(and_(*filters))
|
||||||
pager = get_pager(query, info.pagerNum, info.pagerSize)
|
pager = get_pager(query, info.pagerNum, info.pagerSize)
|
||||||
pager.data = [ProjectInfoOut.from_orm(info).dict() for info in pager.data]
|
pager.data = [ProjectInfoOut.from_orm(info).dict() for info in pager.data]
|
||||||
return pager
|
return pager
|
||||||
|
|
||||||
|
|
||||||
|
def get_project_pager2(info: ProjectInfoPager, session: Session):
|
||||||
|
# 1. 定义子查询
|
||||||
|
subquery = (
|
||||||
|
session.query(
|
||||||
|
ProjectImage.project_id,
|
||||||
|
func.sum(case((ProjectImgLeafer.id.is_(None), 1), else_=0)).label('no_mark_count'),
|
||||||
|
func.sum(case((ProjectImgLeafer.id.isnot(None), 1), else_=0)).label('mark_count')
|
||||||
|
)
|
||||||
|
.outerjoin(ProjectImgLeafer, ProjectImage.id == ProjectImgLeafer.image_id)
|
||||||
|
.group_by(ProjectImage.project_id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 主查询
|
||||||
|
query = (
|
||||||
|
session.query(
|
||||||
|
ProjectInfo,
|
||||||
|
func.ifnull(subquery.c.mark_count, 0).label('mark_count'),
|
||||||
|
func.ifnull(subquery.c.no_mark_count, 0).label('no_mark_count')
|
||||||
|
)
|
||||||
|
.outerjoin(subquery, ProjectInfo.id == subquery.c.project_id)
|
||||||
|
)
|
||||||
|
query = query.order_by(desc(ProjectInfo.id))
|
||||||
|
filters = [ProjectInfo.del_flag == 0]
|
||||||
|
if info.project_name is not None and info.project_name != '':
|
||||||
|
filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
|
||||||
|
query = query.filter(and_(*filters))
|
||||||
|
pager = get_pager(query, info.pagerNum, info.pagerSize)
|
||||||
|
datas = []
|
||||||
|
for result in pager.data:
|
||||||
|
data = ProjectInfoPagerOut.from_orm(result[0])
|
||||||
|
data.mark_count = result[1]
|
||||||
|
data.no_mark_count = result[2]
|
||||||
|
datas.append(data)
|
||||||
|
pager.data = jsonable_encoder(datas)
|
||||||
|
return pager
|
||||||
|
|
||||||
|
|
||||||
def get_project_by_id(project_id: str, session: Session):
|
def get_project_by_id(project_id: str, session: Session):
|
||||||
info = session.query(ProjectInfo).filter_by(id=project_id).first()
|
info = session.query(ProjectInfo).filter_by(id=project_id).filter_by(del_flag=0).first()
|
||||||
info_out = ProjectInfoOut.from_orm(info)
|
info_out = ProjectInfoOut.from_orm(info)
|
||||||
return info_out
|
return info_out
|
||||||
|
|
||||||
|
|
||||||
def add_project(info: ProjectInfo, session: Session):
|
def add_project(info: ProjectInfo, session: Session):
|
||||||
"""新建项目,并在对应文件夹下面创建文件夹"""
|
"""新建项目,并在对应文件夹下面创建文件夹"""
|
||||||
|
info.del_flag = 0
|
||||||
session.add(info)
|
session.add(info)
|
||||||
session.commit()
|
session.commit()
|
||||||
return info
|
return info
|
||||||
@ -35,7 +74,8 @@ def add_project(info: ProjectInfo, session: Session):
|
|||||||
|
|
||||||
def check_project_name(project_name: str, session: Session):
|
def check_project_name(project_name: str, session: Session):
|
||||||
"""检验是否存在重名的项目名称"""
|
"""检验是否存在重名的项目名称"""
|
||||||
count = session.query(ProjectInfo).filter(ProjectInfo.project_name == project_name).count()
|
count = session.query(ProjectInfo).filter(ProjectInfo.project_name == project_name)\
|
||||||
|
.filter(ProjectInfo.del_flag==0).count()
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
@ -60,3 +100,16 @@ def update_project_status(project_id: int, project_status: str, session: Session
|
|||||||
'project_status': project_status
|
'project_status': project_status
|
||||||
})
|
})
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def del_project(project_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
删除项目
|
||||||
|
:param project_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
session.query(ProjectInfo).filter_by(id=project_id).update({
|
||||||
|
'del_flag': 1
|
||||||
|
})
|
||||||
|
session.commit()
|
||||||
|
@ -20,7 +20,7 @@ def get_label_list(project_id: int, session: Session):
|
|||||||
def get_label_for_train(project_id: int, session: Session):
|
def get_label_for_train(project_id: int, session: Session):
|
||||||
id_list = []
|
id_list = []
|
||||||
name_list = []
|
name_list = []
|
||||||
label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
|
label_list = session.query(plModel).filter(plModel.project_id == project_id).order_by(plModel.id).all()
|
||||||
for label in label_list:
|
for label in label_list:
|
||||||
id_list.append(label.id)
|
id_list.append(label.id)
|
||||||
name_list.append(label.label_name)
|
name_list.append(label.label_name)
|
||||||
|
@ -26,3 +26,14 @@ def get_train_list(project_id: int, session: Session):
|
|||||||
query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id))
|
query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id))
|
||||||
train_list = [ProjectTrainOut.from_orm(train) for train in query.all()]
|
train_list = [ProjectTrainOut.from_orm(train) for train in query.all()]
|
||||||
return train_list
|
return train_list
|
||||||
|
|
||||||
|
|
||||||
|
def get_train(train_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
根据id查询训练信息
|
||||||
|
:param train_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
train = session.query(ProjectTrain).filter_by(id=train_id).first()
|
||||||
|
return train
|
||||||
|
@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
def get_list(session: Session):
|
def get_list(session: Session):
|
||||||
"""获取项目类型列表"""
|
"""获取项目类型列表"""
|
||||||
query = session.query(ProjectType).order_by(asc(ProjectType.id))
|
query = session.query(ProjectType).filter(ProjectType.type_status == "0").order_by(asc(ProjectType.id))
|
||||||
query.filter(ProjectType.type_status == "0")
|
|
||||||
result_list = [ProjectTypeOut.from_orm(project_type).dict() for project_type in query.all()]
|
result_list = [ProjectTypeOut.from_orm(project_type).dict() for project_type in query.all()]
|
||||||
return result_list
|
return result_list
|
||||||
|
69
app/model/schemas/project_detect_schemas.py
Normal file
69
app/model/schemas/project_detect_schemas.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectIn(BaseModel):
|
||||||
|
project_id: Optional[int] = Field(..., description="项目id")
|
||||||
|
file_type: Optional[str] = Field('img', description="推理集合文件类别")
|
||||||
|
detect_name: Optional[str] = Field(..., description="推理集合名称")
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectPager(BaseModel):
|
||||||
|
project_id: Optional[int] = Field(..., description="项目id")
|
||||||
|
detect_name: Optional[str] = Field(None, description="推理集合名称")
|
||||||
|
pagerNum: Optional[int] = Field(1, description="当前页码")
|
||||||
|
pagerSize: Optional[int] = Field(10, description="每页数量")
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectOut(BaseModel):
|
||||||
|
id: Optional[int]
|
||||||
|
project_id: Optional[int]
|
||||||
|
detect_name: Optional[int]
|
||||||
|
detect_no: Optional[str]
|
||||||
|
detect_version: Optional[int]
|
||||||
|
file_type: Optional[str]
|
||||||
|
folder_url: Optional[str]
|
||||||
|
create_time: Optional[datetime]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
json_encoders = {
|
||||||
|
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectLogIn(BaseModel):
|
||||||
|
detect_id: Optional[int] = Field(..., description="推理集合id")
|
||||||
|
train_id: Optional[int] = Field(..., description="训练结果id")
|
||||||
|
pt_type: Optional[str] = Field('best', description="权重文件类型")
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectLogOut(BaseModel):
|
||||||
|
id: Optional[int]
|
||||||
|
detect_id: Optional[int]
|
||||||
|
detect_version: Optional[str]
|
||||||
|
train_id: Optional[int]
|
||||||
|
train_version: Optional[int]
|
||||||
|
pt_type: Optional[str]
|
||||||
|
create_time: Optional[datetime]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
json_encoders = {
|
||||||
|
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDetectLogImgOut(BaseModel):
|
||||||
|
id: Optional[int]
|
||||||
|
file_name: Optional[str]
|
||||||
|
create_time: Optional[datetime]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
json_encoders = {
|
||||||
|
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -16,10 +16,24 @@ class ProjectImage(BaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectImageOut(BaseModel):
|
||||||
|
id: Optional[int] = Field(None, description="id")
|
||||||
|
project_id: Optional[int] = Field(..., description="项目id")
|
||||||
|
file_name: Optional[str] = Field(None, description="文件名称")
|
||||||
|
create_time: Optional[datetime] = Field(None, description="上传时间")
|
||||||
|
label_count: Optional[int]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
json_encoders = {
|
||||||
|
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ProjectImagePager(BaseModel):
|
class ProjectImagePager(BaseModel):
|
||||||
project_id: Optional[int] = Field(..., description="项目id")
|
project_id: Optional[int] = Field(..., description="项目id")
|
||||||
pagerNum: Optional[int] = Field(1, description="当前页码")
|
pagerNum: Optional[int] = Field(None, description="当前页码")
|
||||||
pagerSize: Optional[int] = Field(10, description="每页数量")
|
pagerSize: Optional[int] = Field(None, description="每页数量")
|
||||||
|
|
||||||
|
|
||||||
class ProjectImgLabelIn(BaseModel):
|
class ProjectImgLabelIn(BaseModel):
|
||||||
|
@ -24,6 +24,22 @@ class ProjectInfoOut(BaseModel):
|
|||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectInfoPagerOut(BaseModel):
|
||||||
|
"""项目信息输出"""
|
||||||
|
id: Optional[int] = Field(None, description="项目id")
|
||||||
|
project_no: Optional[str] = Field(None, description="项目编号")
|
||||||
|
project_name: Optional[str] = Field(None, description="项目名称")
|
||||||
|
type_code: Optional[str] = Field(None, description="项目类型编码")
|
||||||
|
description: Optional[str] = Field(None, description="项目描述")
|
||||||
|
train_version: Optional[int] = Field(None, description="训练版本号")
|
||||||
|
project_status: Optional[str] = Field(None, description="项目状态")
|
||||||
|
mark_count: Optional[int]
|
||||||
|
no_mark_count: Optional[int]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
|
||||||
|
|
||||||
class ProjectInfoPager(BaseModel):
|
class ProjectInfoPager(BaseModel):
|
||||||
project_name: Optional[str] = Field(None, description="项目名称")
|
project_name: Optional[str] = Field(None, description="项目名称")
|
||||||
pagerNum: Optional[int] = Field(1, description="当前页码")
|
pagerNum: Optional[int] = Field(1, description="当前页码")
|
||||||
|
132
app/service/project_detect_service.py
Normal file
132
app/service/project_detect_service.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List
|
||||||
|
from fastapi import UploadFile
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from app.model.crud import project_detect_crud as pdc
|
||||||
|
from app.model.schemas.project_detect_schemas import ProjectDetectIn, ProjectDetectOut, ProjectDetectLogIn
|
||||||
|
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectTrain, ProjectDetectLog
|
||||||
|
from app.util.random_utils import random_str
|
||||||
|
from app.config.config_reader import detect_url
|
||||||
|
from app.util import os_utils as os
|
||||||
|
from app.util import random_utils as ru
|
||||||
|
from app.config.config_reader import yolo_url
|
||||||
|
|
||||||
|
|
||||||
|
def add_detect(detect_in: ProjectDetectIn, session: Session):
|
||||||
|
"""
|
||||||
|
新增训练集合信息,并创建文件夹
|
||||||
|
:param detect_in:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
detect = ProjectDetect(**detect_in.dict())
|
||||||
|
detect.detect_no = random_str(6)
|
||||||
|
detect.detect_version = 0
|
||||||
|
url = os.create_folder(detect_url, detect.detect_no, 'images')
|
||||||
|
detect.folder_url = url
|
||||||
|
detect = pdc.add_detect(detect, session)
|
||||||
|
return detect
|
||||||
|
|
||||||
|
|
||||||
|
def check_image_name(detect_id: int, files: List[UploadFile], session: Session):
|
||||||
|
"""
|
||||||
|
校验上传的文件名称是否重复
|
||||||
|
:param detect_id:
|
||||||
|
:param files:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for file in files:
|
||||||
|
if not pdc.check_img_name(detect_id, file.filename, session):
|
||||||
|
return False, file.filename
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def upload_detect_imgs(detect: ProjectDetectOut, files: List[UploadFile], session: Session):
|
||||||
|
"""
|
||||||
|
上传推理集合的照片,保存原图,并生成缩略图
|
||||||
|
:param detect:
|
||||||
|
:param files:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
images = []
|
||||||
|
for file in files:
|
||||||
|
image = ProjectDetectImg()
|
||||||
|
image.detect_id = detect.id
|
||||||
|
image.file_name = file.filename
|
||||||
|
# 保存原图
|
||||||
|
path = os.save_images(detect.folder_url, file=file)
|
||||||
|
image.image_url = path
|
||||||
|
# 生成缩略图
|
||||||
|
thumb_image_url = os.file_path(detect.folder_url, 'thumb', ru.random_str(10) + ".jpg")
|
||||||
|
os.create_thumbnail(path, thumb_image_url)
|
||||||
|
image.thumb_image_url = thumb_image_url
|
||||||
|
images.append(image)
|
||||||
|
pdc.add_detect_imgs(images, session)
|
||||||
|
|
||||||
|
|
||||||
|
def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: ProjectTrain, session: Session):
|
||||||
|
"""
|
||||||
|
开始推理
|
||||||
|
:param detect:
|
||||||
|
:param detect_in:
|
||||||
|
:param train:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# 推理版本
|
||||||
|
version_path = 'v' + str(detect.detect_version + 1)
|
||||||
|
|
||||||
|
# 权重文件
|
||||||
|
pt_url = train.best_pt if detect_in.pt_type == 'best' else train.last_pt
|
||||||
|
|
||||||
|
# 推理集合文件路径
|
||||||
|
img_url = detect.folder_url
|
||||||
|
|
||||||
|
out_url = os.file_path(detect_url, detect.detect_no, 'detect')
|
||||||
|
|
||||||
|
# 构建推理记录数据
|
||||||
|
detect_log = ProjectDetectLog()
|
||||||
|
detect_log.detect_id = detect.id
|
||||||
|
detect_log.detect_version = version_path
|
||||||
|
detect_log.train_id = train.id
|
||||||
|
detect_log.train_version = train.train_version
|
||||||
|
detect_log.pt_type = detect_in.pt_type
|
||||||
|
detect_log.folder_url = detect.folder_url
|
||||||
|
detect_log.detect_folder_url = out_url
|
||||||
|
detect_log = pdc.add_detect_log(detect_log, session)
|
||||||
|
return detect_log
|
||||||
|
|
||||||
|
|
||||||
|
def run_commend(weights: str, source: str, project: str, name: str,
|
||||||
|
detect_log_id: int, session: Session):
|
||||||
|
yolo_path = os.file_path(yolo_url, 'detect.py')
|
||||||
|
|
||||||
|
yield f"stdout: 模型推理开始,请稍等。。。 \n"
|
||||||
|
# 启动子进程
|
||||||
|
with subprocess.Popen(
|
||||||
|
["python", '-u', yolo_path,
|
||||||
|
"--weights =" + weights,
|
||||||
|
"--source =" + source,
|
||||||
|
"--name=" + name,
|
||||||
|
"--project=" + project,
|
||||||
|
"--view-img"],
|
||||||
|
bufsize=1, # bufsize=0时,为不缓存;bufsize=1时,按行缓存;bufsize为其他正整数时,为按照近似该正整数的字节数缓存
|
||||||
|
shell=False,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT, # 这里可以显示yolov5训练过程中出现的进度条等信息
|
||||||
|
text=True, # 缓存内容为文本,避免后续编码显示问题
|
||||||
|
encoding='utf-8',
|
||||||
|
) as process:
|
||||||
|
while process.poll() is None:
|
||||||
|
line = process.stdout.readline()
|
||||||
|
process.stdout.flush() # 刷新缓存,防止缓存过多造成卡死
|
||||||
|
if line != '\n':
|
||||||
|
yield line
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -67,6 +67,22 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile],
|
|||||||
pimc.add_image_batch(images, session)
|
pimc.add_image_batch(images, session)
|
||||||
|
|
||||||
|
|
||||||
|
def del_img(image_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
删除图片,并删除文件
|
||||||
|
:param image_id:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
image = session.query(ProjectImage).filter_by(id=image_id).first()
|
||||||
|
if image is None:
|
||||||
|
return 0
|
||||||
|
os.delete_file_if_exists(image.image_url, image.thumb_image_url)
|
||||||
|
session.delete(image)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
|
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session):
|
||||||
"""
|
"""
|
||||||
保存图片的标签框选信息,每次保存都会针对图片的信息全部删除,然后重新保存
|
保存图片的标签框选信息,每次保存都会针对图片的信息全部删除,然后重新保存
|
||||||
@ -164,7 +180,8 @@ def run_commend(data: str, project: str,
|
|||||||
name: str, epochs: int,
|
name: str, epochs: int,
|
||||||
project_id: int, session: Session):
|
project_id: int, session: Session):
|
||||||
yolo_path = os.file_path(yolo_url, 'train.py')
|
yolo_path = os.file_path(yolo_url, 'train.py')
|
||||||
yield f"stdout: 模型训练开始,请稍等。。。"
|
|
||||||
|
yield f"stdout: 模型训练开始,请稍等。。。\n"
|
||||||
# 启动子进程
|
# 启动子进程
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(
|
||||||
["python", '-u', yolo_path,
|
["python", '-u', yolo_path,
|
||||||
@ -195,8 +212,8 @@ def run_commend(data: str, project: str,
|
|||||||
train = ProjectTrain()
|
train = ProjectTrain()
|
||||||
train.project_id = project_id
|
train.project_id = project_id
|
||||||
train.train_version = name
|
train.train_version = name
|
||||||
bast_pt_path = os.file_path(project, name, 'weight', 'bast.pt')
|
bast_pt_path = os.file_path(project, name, 'weights', 'bast.pt')
|
||||||
last_pt_path = os.file_path(project, name, 'weight', 'last.pt')
|
last_pt_path = os.file_path(project, name, 'weights', 'last.pt')
|
||||||
train.best_pt = bast_pt_path
|
train.best_pt = bast_pt_path
|
||||||
train.last_pt = last_pt_path
|
train.last_pt = last_pt_path
|
||||||
ptc.add_train(train, session)
|
ptc.add_train(train, session)
|
@ -79,3 +79,14 @@ def copy_and_rename_file(src_file_path, dst_dir, new_name):
|
|||||||
|
|
||||||
# 复制文件到目标位置并重命名
|
# 复制文件到目标位置并重命名
|
||||||
shutil.copy(src_file_path, dst_file_path)
|
shutil.copy(src_file_path, dst_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_file_if_exists(*file_paths: str):
|
||||||
|
"""
|
||||||
|
删除文件
|
||||||
|
:param file_path:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for path in file_paths:
|
||||||
|
if os.path.exists(path): # 检查文件是否存在
|
||||||
|
os.remove(path) # 删除文件
|
||||||
|
Reference in New Issue
Block a user