完善相关问题,并增加推理的部分代码
This commit is contained in:
@ -27,6 +27,7 @@ class ProjectInfo(DbCommon):
|
||||
project_status: Mapped[str] = mapped_column(String(10))
|
||||
user_id: Mapped[int] = mapped_column(Integer)
|
||||
train_version: Mapped[int] = mapped_column(Integer)
|
||||
del_flag: Mapped[int] = mapped_column(Integer)
|
||||
|
||||
|
||||
class ProjectLabel(DbCommon):
|
||||
@ -73,7 +74,9 @@ class ProjectImgLabel(DbCommon):
|
||||
|
||||
|
||||
class ProjectTrain(DbCommon):
|
||||
"""项目训练版本信息表"""
|
||||
"""
|
||||
项目训练版本信息表
|
||||
"""
|
||||
__tablename__ = "project_train"
|
||||
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
@ -83,9 +86,48 @@ class ProjectTrain(DbCommon):
|
||||
|
||||
class ProjectDetect(DbCommon):
|
||||
"""
|
||||
训练推理集合
|
||||
项目推理集合
|
||||
"""
|
||||
__tablename__ = "project_detect"
|
||||
project_id: Mapped[str] = mapped_column(Integer, nullable=False)
|
||||
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
detect_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
detect_version: Mapped[int] = mapped_column(Integer)
|
||||
detect_no: Mapped[str] = mapped_column(String(32))
|
||||
detect_status: Mapped[int] = mapped_column(Integer)
|
||||
file_type: Mapped[str] = mapped_column(String(10))
|
||||
folder_url: Mapped[str] = mapped_column(String(255))
|
||||
|
||||
|
||||
class ProjectDetectImg(DbCommon):
|
||||
"""
|
||||
推理之前的图片
|
||||
"""
|
||||
__tablename__ = "project_detect_img"
|
||||
detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
|
||||
class ProjectDetectLog(DbCommon):
|
||||
"""
|
||||
项目推理记录
|
||||
"""
|
||||
__tablename__ = "project_detect_log"
|
||||
detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
detect_version: Mapped[str] = mapped_column(String(10))
|
||||
train_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
train_version: Mapped[str] = mapped_column(String(10))
|
||||
pt_type: Mapped[str] = mapped_column(String(32))
|
||||
folder_url: Mapped[str] = mapped_column(String(255))
|
||||
detect_folder_url: Mapped[str] = mapped_column(String(255))
|
||||
|
||||
|
||||
class ProjectDetectLogImg(DbCommon):
|
||||
"""
|
||||
推理完成的图片
|
||||
"""
|
||||
__tablename__ = "project_detect_log_img"
|
||||
log_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
88
app/model/crud/project_detect_crud.py
Normal file
88
app/model/crud/project_detect_crud.py
Normal file
@ -0,0 +1,88 @@
|
||||
from typing import List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import asc, and_
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
|
||||
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager
|
||||
from app.db.page_util import get_pager
|
||||
|
||||
|
||||
def add_detect(detect: ProjectDetect, session: Session):
|
||||
"""
|
||||
新增推理集合
|
||||
:param detect:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
session.add(detect)
|
||||
session.commit()
|
||||
return detect
|
||||
|
||||
|
||||
def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
|
||||
"""
|
||||
查询推理集合分页数据
|
||||
:param detect_pager:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
query = session.query(ProjectDetect).order_by(asc(ProjectDetect.id))
|
||||
filters = [ProjectDetect.project_id == detect_pager.project_id]
|
||||
if detect_pager.detect_name is not None:
|
||||
filters.append(ProjectDetect.detect_name.ilike(f"%{detect_pager.detect_name}%"))
|
||||
query = query.filter(and_(*filters))
|
||||
pager = get_pager(query, detect_pager.pagerNum, detect_pager.pagerSize)
|
||||
pager.data = [ProjectDetectOut.from_orm(user) for user in pager.data]
|
||||
pager.data = jsonable_encoder(pager.data)
|
||||
return pager
|
||||
|
||||
|
||||
def check_img_name(detect_id: int, file_name: str, session: Session):
|
||||
"""
|
||||
校验上传的图片名称是否重名
|
||||
:param detect_id:
|
||||
:param file_name:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
image = session.query(ProjectDetectImg).filter_by(detect_id=detect_id).filter_by(file_name=file_name).first()
|
||||
if image is None:
|
||||
return True, None
|
||||
else:
|
||||
return False, image.file_name
|
||||
|
||||
|
||||
def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session):
|
||||
"""
|
||||
添加推理集合中的图片
|
||||
:param detect_imgs:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
session.add_all(detect_imgs)
|
||||
session.commit()
|
||||
|
||||
|
||||
def add_detect_log(detect_log: ProjectDetectLog, session: Session):
|
||||
"""
|
||||
新增推理记录
|
||||
:param detect_log:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
session.add(detect_log)
|
||||
session.commit()
|
||||
return detect_log
|
||||
|
||||
|
||||
def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Session):
|
||||
"""
|
||||
新增推理记录中的照片
|
||||
:param detect_log_imgs:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
session.add_all(detect_log_imgs)
|
||||
session.commit()
|
@ -1,9 +1,10 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import asc
|
||||
from sqlalchemy import asc, func
|
||||
from typing import List
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from app.model.bussiness_model import ProjectImage as piModel
|
||||
from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager
|
||||
from app.model.bussiness_model import ProjectImage as piModel, ProjectImgLabel
|
||||
from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager, ProjectImageOut
|
||||
from app.db.page_util import get_pager
|
||||
|
||||
|
||||
@ -14,6 +15,35 @@ def get_image_pager(image: ProjectImagePager, session: Session):
|
||||
return pager
|
||||
|
||||
|
||||
def get_image_pager2(image: ProjectImage, session: Session):
|
||||
# 1 子查询
|
||||
subquery = (
|
||||
session.query(
|
||||
ProjectImgLabel.image_id,
|
||||
func.ifnull(func.count(ProjectImgLabel.id), 0).label('label_count')
|
||||
)
|
||||
.group_by(ProjectImgLabel.image_id)
|
||||
.subquery()
|
||||
)
|
||||
# 2 主查询
|
||||
query = (
|
||||
session.query(
|
||||
piModel,
|
||||
func.ifnull(subquery.c.label_count, 0).label('label_count')
|
||||
)
|
||||
.outerjoin(subquery, piModel.id == subquery.c.image_id)
|
||||
)
|
||||
query = query.filter(piModel.project_id == image.project_id).order_by(asc(piModel.id))
|
||||
pager = get_pager(query, image.pagerNum, image.pagerSize)
|
||||
datas = []
|
||||
for result in pager.data:
|
||||
data = ProjectImageOut.from_orm(result[0])
|
||||
data.label_count = result[1]
|
||||
datas.append(data)
|
||||
pager.data = jsonable_encoder(datas)
|
||||
return pager
|
||||
|
||||
|
||||
def check_img_name(project_id: int, file_name: str, session: Session):
|
||||
"""
|
||||
根据项目id和文件名称进行查重
|
||||
|
@ -23,7 +23,7 @@ def get_img_label_list(image_id, session: Session):
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).all()
|
||||
img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).order_by(ProjectImgLabel.label_id).all()
|
||||
return img_label_list
|
||||
|
||||
|
||||
|
@ -1,33 +1,72 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, update
|
||||
from sqlalchemy import desc, and_, func, case
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from app.model.bussiness_model import ProjectInfo
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoOut
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoPager
|
||||
from app.model.bussiness_model import ProjectInfo, ProjectImage, ProjectImgLeafer
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoOut, ProjectInfoPager, ProjectInfoPagerOut
|
||||
from app.db.page_util import get_pager
|
||||
|
||||
|
||||
def get_project_pager(info: ProjectInfoPager, session: Session):
|
||||
"""分页查询项目信息"""
|
||||
query = session.query(ProjectInfo).order_by(desc(ProjectInfo.id))
|
||||
filters = []
|
||||
if info.project_name is not None:
|
||||
filters = [ProjectInfo.del_flag == 0]
|
||||
if info.project_name is not None and info.project_name != '':
|
||||
filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
|
||||
if len(filters) > 0:
|
||||
query.filter(*filters)
|
||||
query = query.filter(and_(*filters))
|
||||
pager = get_pager(query, info.pagerNum, info.pagerSize)
|
||||
pager.data = [ProjectInfoOut.from_orm(info).dict() for info in pager.data]
|
||||
return pager
|
||||
|
||||
|
||||
def get_project_pager2(info: ProjectInfoPager, session: Session):
|
||||
# 1. 定义子查询
|
||||
subquery = (
|
||||
session.query(
|
||||
ProjectImage.project_id,
|
||||
func.sum(case((ProjectImgLeafer.id.is_(None), 1), else_=0)).label('no_mark_count'),
|
||||
func.sum(case((ProjectImgLeafer.id.isnot(None), 1), else_=0)).label('mark_count')
|
||||
)
|
||||
.outerjoin(ProjectImgLeafer, ProjectImage.id == ProjectImgLeafer.image_id)
|
||||
.group_by(ProjectImage.project_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# 2. 主查询
|
||||
query = (
|
||||
session.query(
|
||||
ProjectInfo,
|
||||
func.ifnull(subquery.c.mark_count, 0).label('mark_count'),
|
||||
func.ifnull(subquery.c.no_mark_count, 0).label('no_mark_count')
|
||||
)
|
||||
.outerjoin(subquery, ProjectInfo.id == subquery.c.project_id)
|
||||
)
|
||||
query = query.order_by(desc(ProjectInfo.id))
|
||||
filters = [ProjectInfo.del_flag == 0]
|
||||
if info.project_name is not None and info.project_name != '':
|
||||
filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
|
||||
query = query.filter(and_(*filters))
|
||||
pager = get_pager(query, info.pagerNum, info.pagerSize)
|
||||
datas = []
|
||||
for result in pager.data:
|
||||
data = ProjectInfoPagerOut.from_orm(result[0])
|
||||
data.mark_count = result[1]
|
||||
data.no_mark_count = result[2]
|
||||
datas.append(data)
|
||||
pager.data = jsonable_encoder(datas)
|
||||
return pager
|
||||
|
||||
|
||||
def get_project_by_id(project_id: str, session: Session):
|
||||
info = session.query(ProjectInfo).filter_by(id=project_id).first()
|
||||
info = session.query(ProjectInfo).filter_by(id=project_id).filter_by(del_flag=0).first()
|
||||
info_out = ProjectInfoOut.from_orm(info)
|
||||
return info_out
|
||||
|
||||
|
||||
def add_project(info: ProjectInfo, session: Session):
|
||||
"""新建项目,并在对应文件夹下面创建文件夹"""
|
||||
info.del_flag = 0
|
||||
session.add(info)
|
||||
session.commit()
|
||||
return info
|
||||
@ -35,7 +74,8 @@ def add_project(info: ProjectInfo, session: Session):
|
||||
|
||||
def check_project_name(project_name: str, session: Session):
|
||||
"""检验是否存在重名的项目名称"""
|
||||
count = session.query(ProjectInfo).filter(ProjectInfo.project_name == project_name).count()
|
||||
count = session.query(ProjectInfo).filter(ProjectInfo.project_name == project_name)\
|
||||
.filter(ProjectInfo.del_flag==0).count()
|
||||
if count > 0:
|
||||
return True
|
||||
else:
|
||||
@ -60,3 +100,16 @@ def update_project_status(project_id: int, project_status: str, session: Session
|
||||
'project_status': project_status
|
||||
})
|
||||
session.commit()
|
||||
|
||||
|
||||
def del_project(project_id: int, session: Session):
|
||||
"""
|
||||
删除项目
|
||||
:param project_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
session.query(ProjectInfo).filter_by(id=project_id).update({
|
||||
'del_flag': 1
|
||||
})
|
||||
session.commit()
|
||||
|
@ -20,7 +20,7 @@ def get_label_list(project_id: int, session: Session):
|
||||
def get_label_for_train(project_id: int, session: Session):
|
||||
id_list = []
|
||||
name_list = []
|
||||
label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
|
||||
label_list = session.query(plModel).filter(plModel.project_id == project_id).order_by(plModel.id).all()
|
||||
for label in label_list:
|
||||
id_list.append(label.id)
|
||||
name_list.append(label.label_name)
|
||||
|
@ -26,3 +26,14 @@ def get_train_list(project_id: int, session: Session):
|
||||
query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id))
|
||||
train_list = [ProjectTrainOut.from_orm(train) for train in query.all()]
|
||||
return train_list
|
||||
|
||||
|
||||
def get_train(train_id: int, session: Session):
|
||||
"""
|
||||
根据id查询训练信息
|
||||
:param train_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
train = session.query(ProjectTrain).filter_by(id=train_id).first()
|
||||
return train
|
||||
|
@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
def get_list(session: Session):
|
||||
"""获取项目类型列表"""
|
||||
query = session.query(ProjectType).order_by(asc(ProjectType.id))
|
||||
query.filter(ProjectType.type_status == "0")
|
||||
query = session.query(ProjectType).filter(ProjectType.type_status == "0").order_by(asc(ProjectType.id))
|
||||
result_list = [ProjectTypeOut.from_orm(project_type).dict() for project_type in query.all()]
|
||||
return result_list
|
||||
|
69
app/model/schemas/project_detect_schemas.py
Normal file
69
app/model/schemas/project_detect_schemas.py
Normal file
@ -0,0 +1,69 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ProjectDetectIn(BaseModel):
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
file_type: Optional[str] = Field('img', description="推理集合文件类别")
|
||||
detect_name: Optional[str] = Field(..., description="推理集合名称")
|
||||
|
||||
|
||||
class ProjectDetectPager(BaseModel):
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
detect_name: Optional[str] = Field(None, description="推理集合名称")
|
||||
pagerNum: Optional[int] = Field(1, description="当前页码")
|
||||
pagerSize: Optional[int] = Field(10, description="每页数量")
|
||||
|
||||
|
||||
class ProjectDetectOut(BaseModel):
|
||||
id: Optional[int]
|
||||
project_id: Optional[int]
|
||||
detect_name: Optional[int]
|
||||
detect_no: Optional[str]
|
||||
detect_version: Optional[int]
|
||||
file_type: Optional[str]
|
||||
folder_url: Optional[str]
|
||||
create_time: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
|
||||
class ProjectDetectLogIn(BaseModel):
|
||||
detect_id: Optional[int] = Field(..., description="推理集合id")
|
||||
train_id: Optional[int] = Field(..., description="训练结果id")
|
||||
pt_type: Optional[str] = Field('best', description="权重文件类型")
|
||||
|
||||
|
||||
class ProjectDetectLogOut(BaseModel):
|
||||
id: Optional[int]
|
||||
detect_id: Optional[int]
|
||||
detect_version: Optional[str]
|
||||
train_id: Optional[int]
|
||||
train_version: Optional[int]
|
||||
pt_type: Optional[str]
|
||||
create_time: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
|
||||
class ProjectDetectLogImgOut(BaseModel):
|
||||
id: Optional[int]
|
||||
file_name: Optional[str]
|
||||
create_time: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
|
@ -16,10 +16,24 @@ class ProjectImage(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class ProjectImageOut(BaseModel):
|
||||
id: Optional[int] = Field(None, description="id")
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
file_name: Optional[str] = Field(None, description="文件名称")
|
||||
create_time: Optional[datetime] = Field(None, description="上传时间")
|
||||
label_count: Optional[int]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
|
||||
class ProjectImagePager(BaseModel):
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
pagerNum: Optional[int] = Field(1, description="当前页码")
|
||||
pagerSize: Optional[int] = Field(10, description="每页数量")
|
||||
pagerNum: Optional[int] = Field(None, description="当前页码")
|
||||
pagerSize: Optional[int] = Field(None, description="每页数量")
|
||||
|
||||
|
||||
class ProjectImgLabelIn(BaseModel):
|
||||
|
@ -24,6 +24,22 @@ class ProjectInfoOut(BaseModel):
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class ProjectInfoPagerOut(BaseModel):
|
||||
"""项目信息输出"""
|
||||
id: Optional[int] = Field(None, description="项目id")
|
||||
project_no: Optional[str] = Field(None, description="项目编号")
|
||||
project_name: Optional[str] = Field(None, description="项目名称")
|
||||
type_code: Optional[str] = Field(None, description="项目类型编码")
|
||||
description: Optional[str] = Field(None, description="项目描述")
|
||||
train_version: Optional[int] = Field(None, description="训练版本号")
|
||||
project_status: Optional[str] = Field(None, description="项目状态")
|
||||
mark_count: Optional[int]
|
||||
no_mark_count: Optional[int]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class ProjectInfoPager(BaseModel):
|
||||
project_name: Optional[str] = Field(None, description="项目名称")
|
||||
pagerNum: Optional[int] = Field(1, description="当前页码")
|
||||
|
@ -13,4 +13,4 @@ class ProjectTrainOut(BaseModel):
|
||||
orm_mode = True
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user