from sqlalchemy.orm import Session
from sqlalchemy import desc, and_, func, case
from fastapi.encoders import jsonable_encoder

from app.model.bussiness_model import ProjectInfo, ProjectImage, ProjectImgLeafer
from app.model.schemas.project_info_schemas import ProjectInfoOut, ProjectInfoPager, ProjectInfoPagerOut
from app.db.page_util import get_pager


def _build_project_filters(info: ProjectInfoPager) -> list:
    """Build the WHERE conditions shared by the project pager queries.

    Always restricts to rows with ``del_flag == 0``. When ``info.project_name``
    is a non-empty value, also adds a case-insensitive substring match on
    ``ProjectInfo.project_name``.
    """
    filters = [ProjectInfo.del_flag == 0]
    # None and '' both mean "no name filter".
    if info.project_name:
        # NOTE(review): '%' and '_' typed by the user are not escaped, so they
        # act as LIKE wildcards. Confirm whether that is intended.
        filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
    return filters


def get_project_pager(info: ProjectInfoPager, session: Session):
    """Return one page of non-deleted projects, highest id first.

    :param info: paging/filter parameters; ``project_name``, ``pagerNum`` and
        ``pagerSize`` are read.
    :param session: SQLAlchemy session used to run the query.
    :return: the pager object produced by ``get_pager``, with ``data``
        replaced by a list of ``ProjectInfoOut`` dicts.
    """
    query = (
        session.query(ProjectInfo)
        .filter(and_(*_build_project_filters(info)))
        .order_by(desc(ProjectInfo.id))
    )
    pager = get_pager(query, info.pagerNum, info.pagerSize)
    pager.data = [ProjectInfoOut.from_orm(project).dict() for project in pager.data]
    return pager


def get_project_pager2(info: ProjectInfoPager, session: Session):
    """Return one page of non-deleted projects, each with mark/no-mark counts.

    :param info: paging/filter parameters; ``project_name``, ``pagerNum`` and
        ``pagerSize`` are read.
    :param session: SQLAlchemy session used to run the query.
    :return: the pager object produced by ``get_pager``, with ``data``
        replaced by the JSON-encoded list of ``ProjectInfoPagerOut`` items
        carrying ``mark_count`` and ``no_mark_count``.
    """
    # 1. Subquery: per project, left-join images to leafer rows and count
    #    joined rows without a leafer (no_mark) and with one (mark).
    # NOTE(review): an image with several ProjectImgLeafer rows contributes
    # several to mark_count. Confirm whether per-image counting is intended.
    subquery = (
        session.query(
            ProjectImage.project_id,
            func.sum(case((ProjectImgLeafer.id.is_(None), 1), else_=0)).label('no_mark_count'),
            func.sum(case((ProjectImgLeafer.id.isnot(None), 1), else_=0)).label('mark_count'),
        )
        .outerjoin(ProjectImgLeafer, ProjectImage.id == ProjectImgLeafer.image_id)
        .group_by(ProjectImage.project_id)
        .subquery()
    )

    # 2. Main query: left-join the counts so projects without any images still
    #    appear. Their counts become 0 via COALESCE (portable form of IFNULL).
    query = (
        session.query(
            ProjectInfo,
            func.coalesce(subquery.c.mark_count, 0).label('mark_count'),
            func.coalesce(subquery.c.no_mark_count, 0).label('no_mark_count'),
        )
        .outerjoin(subquery, ProjectInfo.id == subquery.c.project_id)
        .filter(and_(*_build_project_filters(info)))
        .order_by(desc(ProjectInfo.id))
    )

    pager = get_pager(query, info.pagerNum, info.pagerSize)
    datas = []
    for project, mark_count, no_mark_count in pager.data:
        data = ProjectInfoPagerOut.from_orm(project)
        data.mark_count = mark_count
        data.no_mark_count = no_mark_count
        datas.append(data)
    pager.data = jsonable_encoder(datas)
    return pager


def get_project_by_id(project_id: str, session: Session):
    """Fetch a single non-deleted project by id.

    :param project_id: id of the project to look up.
    :param session: SQLAlchemy session used to run the query.
    :return: a ``ProjectInfoOut`` for the project, or ``None`` when no
        non-deleted project has this id (instead of passing ``None`` to
        ``from_orm``, which fails schema validation).
    """
    info = session.query(ProjectInfo).filter_by(id=project_id, del_flag=0).first()
    if info is None:
        return None
    return ProjectInfoOut.from_orm(info)


def add_project(info: ProjectInfo, session: Session):
    """Persist a new project with ``del_flag`` set to 0 and commit it.

    NOTE(review): the original docstring said this also creates a folder for
    the project, but no filesystem work happens here. Confirm where (or if)
    that is done.

    :param info: the ``ProjectInfo`` instance to insert (mutated in place).
    :param session: SQLAlchemy session; committed before returning.
    :return: the same ``info`` instance after commit.
    """
    info.del_flag = 0
    session.add(info)
    session.commit()
    return info


def check_project_name(project_name: str, session: Session) -> bool:
    """Return True if a non-deleted project with exactly this name exists.

    :param project_name: name to check for duplicates.
    :param session: SQLAlchemy session used to run the query.
    """
    count = (
        session.query(ProjectInfo)
        .filter(ProjectInfo.project_name == project_name)
        .filter(ProjectInfo.del_flag == 0)
        .count()
    )
    return count > 0


def update_project_status(project_id: int, project_status: str, session: Session):
    """Update a project's training status; on completion also bump train_version.

    :param project_id: id of the project to update.
    :param project_status: '0' = not run, '1' = running, '2' = completed,
        '-1' = failed. Only '2' increments ``train_version``.
    :param session: SQLAlchemy session; committed before returning.
    """
    values = {'project_status': project_status}
    if project_status == '2':
        # Expressed as a column expression, so the database evaluates
        # train_version + 1 inside the UPDATE statement itself.
        values['train_version'] = ProjectInfo.train_version + 1
    session.query(ProjectInfo).filter_by(id=project_id).update(values)
    session.commit()


def del_project(project_id: int, session: Session):
    """Soft-delete a project by setting its ``del_flag`` to 1, then commit.

    :param project_id: id of the project to delete.
    :param session: SQLAlchemy session; committed before returning.
    """
    session.query(ProjectInfo).filter_by(id=project_id).update({
        'del_flag': 1
    })
    session.commit()