Files
aicheckv2/app/model/crud/project_info_crud.py

116 lines
4.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sqlalchemy.orm import Session
from sqlalchemy import desc, and_, func, case
from fastapi.encoders import jsonable_encoder
from app.model.bussiness_model import ProjectInfo, ProjectImage, ProjectImgLeafer
from app.model.schemas.project_info_schemas import ProjectInfoOut, ProjectInfoPager, ProjectInfoPagerOut
from app.db.page_util import get_pager
def get_project_pager(info: ProjectInfoPager, session: Session):
    """Return one page of non-deleted projects, ordered by id descending.

    :param info: paging/filter parameters; a non-empty ``project_name`` is
        matched as a case-insensitive substring (``ilike``), and
        ``pagerNum``/``pagerSize`` select the page.
    :param session: SQLAlchemy session used for the query.
    :return: the pager built by ``get_pager`` with its ``data`` replaced by a
        list of ``ProjectInfoOut`` dicts.
    """
    # del_flag == 0 marks rows that have not been soft-deleted (del_project sets 1).
    filters = [ProjectInfo.del_flag == 0]
    if info.project_name:
        # NOTE(review): '%' and '_' in the search term are not escaped, so
        # they act as LIKE wildcards.
        filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
    # `filters` always contains the del_flag condition, so it is applied
    # unconditionally (the old `len(filters) > 0` check was always true).
    query = (
        session.query(ProjectInfo)
        .filter(and_(*filters))
        .order_by(desc(ProjectInfo.id))
    )
    pager = get_pager(query, info.pagerNum, info.pagerSize)
    # Use a distinct loop name so the `info` parameter is not shadowed.
    pager.data = [ProjectInfoOut.from_orm(project).dict() for project in pager.data]
    return pager


def get_project_pager2(info: ProjectInfoPager, session: Session):
    """Return one page of non-deleted projects together with annotation counts.

    For every project, ``mark_count`` is the number of image/leafer join rows
    that have a ``ProjectImgLeafer`` record, and ``no_mark_count`` is the
    number of images without one. Projects with no images report 0 for both.

    :param info: paging/filter parameters; a non-empty ``project_name`` is
        matched as a case-insensitive substring (``ilike``), and
        ``pagerNum``/``pagerSize`` select the page.
    :param session: SQLAlchemy session used for the query.
    :return: the pager built by ``get_pager`` with its ``data`` replaced by a
        JSON-compatible list of ``ProjectInfoPagerOut`` objects.
    """
    # Per-project annotation statistics. The outer join keeps images without a
    # ProjectImgLeafer row; for those rows the leafer id is NULL.
    # NOTE(review): an image with several leafer rows is counted once per row
    # in mark_count - confirm whether images should be counted distinctly.
    counts = (
        session.query(
            ProjectImage.project_id,
            func.sum(case((ProjectImgLeafer.id.is_(None), 1), else_=0)).label('no_mark_count'),
            func.sum(case((ProjectImgLeafer.id.isnot(None), 1), else_=0)).label('mark_count'),
        )
        .outerjoin(ProjectImgLeafer, ProjectImage.id == ProjectImgLeafer.image_id)
        .group_by(ProjectImage.project_id)
        .subquery()
    )
    # Projects without images have no row in `counts`, so both sums come back
    # NULL; COALESCE (standard SQL, unlike MySQL/SQLite-only IFNULL) maps them to 0.
    query = (
        session.query(
            ProjectInfo,
            func.coalesce(counts.c.mark_count, 0).label('mark_count'),
            func.coalesce(counts.c.no_mark_count, 0).label('no_mark_count'),
        )
        .outerjoin(counts, ProjectInfo.id == counts.c.project_id)
        .order_by(desc(ProjectInfo.id))
    )
    # del_flag == 0 marks rows that have not been soft-deleted (del_project sets 1).
    filters = [ProjectInfo.del_flag == 0]
    if info.project_name:
        filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
    query = query.filter(and_(*filters))

    pager = get_pager(query, info.pagerNum, info.pagerSize)
    rows_out = []
    # Each result row is (ProjectInfo, mark_count, no_mark_count), matching
    # the three entities selected above.
    for project, mark_count, no_mark_count in pager.data:
        row_out = ProjectInfoPagerOut.from_orm(project)
        row_out.mark_count = mark_count
        row_out.no_mark_count = no_mark_count
        rows_out.append(row_out)
    pager.data = jsonable_encoder(rows_out)
    return pager


def get_project_by_id(project_id: str, session: Session):
    """Fetch a single non-deleted project by id.

    :param project_id: id of the project to look up.
    :param session: SQLAlchemy session used for the query.
    :return: the project as a ``ProjectInfoOut``, or ``None`` when no
        non-deleted project has this id.
    """
    info = (
        session.query(ProjectInfo)
        .filter_by(id=project_id, del_flag=0)
        .first()
    )
    if info is None:
        # `.first()` returns None when nothing matches; passing None on to
        # ProjectInfoOut.from_orm would not yield a meaningful result, so
        # report "not found" explicitly instead.
        return None
    return ProjectInfoOut.from_orm(info)


def add_project(info: ProjectInfo, session: Session):
    """Insert a new project row and commit it.

    Only the database record is created here; this function does not touch
    the filesystem (any project folder must be created by the caller).

    :param info: the new project; its ``del_flag`` is set to 0 (not deleted)
        before insertion.
    :param session: SQLAlchemy session; it is committed by this function.
    :return: the same ``info`` instance after the commit.
    """
    info.del_flag = 0
    session.add(info)
    session.commit()
    return info


def check_project_name(project_name: str, session: Session):
    """Return whether a non-deleted project already uses ``project_name``.

    :param project_name: project name to check for an exact (``==``) match.
    :param session: SQLAlchemy session used for the query.
    :return: ``True`` if a project with that name exists and is not deleted,
        otherwise ``False``.
    """
    duplicates = session.query(ProjectInfo).filter(
        ProjectInfo.project_name == project_name,
        ProjectInfo.del_flag == 0,
    )
    # EXISTS lets the database stop at the first match instead of counting
    # every duplicate; bool() normalises the scalar result.
    return bool(session.query(duplicates.exists()).scalar())


def update_project_status(project_id: int, project_status: str, session: Session):
    """
    Set a project's training status and commit.

    When the new status is '2' (completed), train_version is also
    incremented by 1 in the same UPDATE statement.
    :param project_id: id of the project to update
    :param project_status: '0' not running, '1' running, '2' completed, '-1' failed
    :param session: SQLAlchemy session; committed by this function
    :return: None
    """
    new_values = {'project_status': project_status}
    if project_status == '2':
        # Let the database compute train_version + 1 rather than reading it first.
        new_values['train_version'] = ProjectInfo.train_version + 1
    session.query(ProjectInfo).filter_by(id=project_id).update(new_values)
    session.commit()


def del_project(project_id: int, session: Session):
    """
    Soft-delete a project by setting its del_flag to 1, then commit.

    The row stays in the table; queries in this module filter on
    del_flag == 0 to exclude it.
    :param project_id: id of the project to delete
    :param session: SQLAlchemy session; committed by this function
    :return: None
    """
    soft_delete = {'del_flag': 1}
    target = session.query(ProjectInfo).filter(ProjectInfo.id == project_id)
    target.update(soft_delete)
    session.commit()