from typing import List, Optional, Tuple

from sqlalchemy.orm import Session
from sqlalchemy import asc, and_
from fastapi.encoders import jsonable_encoder

from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
from app.model.schemas.project_detect_schemas import (
    ProjectDetectOut,
    ProjectDetectPager,
    ProjectDetectImageOut,
    ProjectDetectImgPager,
    ProjectDetectLogOut,
    ProjectDetectLogImgOut,
)
from app.db.page_util import get_pager


def add_detect(detect: ProjectDetect, session: Session) -> ProjectDetect:
    """
    Persist a new detect (inference) collection.

    :param detect: the ProjectDetect row to insert
    :param session: database session; committed by this call
    :return: the same ``detect`` object that was passed in
    """
    session.add(detect)
    session.commit()
    return detect


def get_detect_by_id(detect_id: int, session: Session) -> Optional[ProjectDetectOut]:
    """
    Look up a detect collection by its primary key.

    :param detect_id: id of the detect collection
    :param session: database session
    :return: the collection converted to ``ProjectDetectOut``, or ``None``
        when no row has this id
    """
    detect = session.query(ProjectDetect).filter_by(id=detect_id).first()
    # ``first()`` yields None when nothing matches; do not feed that to
    # ``from_orm``, which would fail with an unrelated validation error.
    if detect is None:
        return None
    return ProjectDetectOut.from_orm(detect)


def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
    """
    Query one page of detect collections for a project.

    Rows are restricted to ``detect_pager.project_id`` and, when
    ``detect_pager.detect_name`` is given, to names containing it
    (case-insensitive ``ILIKE`` match). Results are ordered by ascending id.

    :param detect_pager: paging request (project_id, optional detect_name,
        pagerNum, pagerSize)
    :param session: database session
    :return: the pager object produced by ``get_pager`` with ``data`` replaced
        by the JSON-encodable form of the ``ProjectDetectOut`` items
    """
    conditions = [ProjectDetect.project_id == detect_pager.project_id]
    if detect_pager.detect_name is not None:
        conditions.append(ProjectDetect.detect_name.ilike(f"%{detect_pager.detect_name}%"))

    query = (
        session.query(ProjectDetect)
        .order_by(asc(ProjectDetect.id))
        .filter(and_(*conditions))
    )
    pager = get_pager(query, detect_pager.pagerNum, detect_pager.pagerSize)
    pager.data = jsonable_encoder([ProjectDetectOut.from_orm(detect) for detect in pager.data])
    return pager


def get_detect_list(project_id: int, session: Session):
    """
    List every detect collection of a project, ordered by ascending id.

    :param project_id: id of the owning project
    :param session: database session
    :return: result of ``jsonable_encoder`` applied to the ORM rows
    """
    detects = (
        session.query(ProjectDetect)
        .filter_by(project_id=project_id)
        .order_by(asc(ProjectDetect.id))
        .all()
    )
    return jsonable_encoder(detects)


def update_detect_status(detect_id: int, detect_status: int, session: Session) -> None:
    """
    Update the run status of a detect collection.

    When the new status is 2 (finished), ``detect_version`` is incremented by
    one in the same UPDATE statement.

    :param detect_id: id of the detect collection
    :param detect_status: 0 - not running, 1 - running, 2 - finished,
        -1 - failed
    :param session: database session; committed by this call
    :return: None
    """
    values = {'detect_status': detect_status}
    if detect_status == 2:
        # Column expression, so the increment is evaluated by the database.
        values['detect_version'] = ProjectDetect.detect_version + 1
    session.query(ProjectDetect).filter_by(id=detect_id).update(values)
    session.commit()


def check_img_name(detect_id: int, file_name: str, session: Session) -> Tuple[bool, Optional[str]]:
    """
    Check whether an uploaded image name is already used in a collection.

    :param detect_id: id of the detect collection
    :param file_name: file name to check
    :param session: database session
    :return: ``(True, None)`` when the name is free, otherwise
        ``(False, <existing file name>)``
    """
    existing = (
        session.query(ProjectDetectImg)
        .filter_by(detect_id=detect_id)
        .filter_by(file_name=file_name)
        .first()
    )
    if existing is None:
        return True, None
    return False, existing.file_name


def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session) -> None:
    """
    Insert the given images into a detect collection.

    :param detect_imgs: image rows to insert
    :param session: database session; committed by this call
    :return: None
    """
    session.add_all(detect_imgs)
    session.commit()


def get_img_list(detect_id: int, session: Session) -> List[ProjectDetectImageOut]:
    """
    List all images of a detect collection, ordered by ascending id.

    :param detect_id: id of the detect collection
    :param session: database session
    :return: list of ``ProjectDetectImageOut``
    """
    images = (
        session.query(ProjectDetectImg)
        .filter_by(detect_id=detect_id)
        .order_by(asc(ProjectDetectImg.id))
        .all()
    )
    return [ProjectDetectImageOut.from_orm(image) for image in images]


def get_img_pager(detect_img_pager: ProjectDetectImgPager, session: Session):
    """
    Query one page of images of a detect collection, ordered by ascending id.

    :param detect_img_pager: paging request (detect_id, pagerNum, pagerSize)
    :param session: database session
    :return: the pager object produced by ``get_pager`` with ``data`` replaced
        by the JSON-encodable form of the ``ProjectDetectImageOut`` items
    """
    query = (
        session.query(ProjectDetectImg)
        .filter_by(detect_id=detect_img_pager.detect_id)
        .order_by(asc(ProjectDetectImg.id))
    )
    pager = get_pager(query, detect_img_pager.pagerNum, detect_img_pager.pagerSize)
    pager.data = jsonable_encoder([ProjectDetectImageOut.from_orm(image) for image in pager.data])
    return pager


def get_detect_img_url(image_id: int, session: Session):
    """
    Fetch the source and thumbnail URLs of a detect image.

    :param image_id: id of the ProjectDetectImg row
    :param session: database session
    :return: ``(image_url, thumb_image_url)``; ``(None, None)`` when no image
        has this id, so callers can still unpack the result
    """
    image = session.query(ProjectDetectImg).filter_by(id=image_id).first()
    if image is None:
        # Missing row: mirror get_log_img_url's "not found -> None" behaviour
        # instead of raising AttributeError on None.
        return None, None
    return image.image_url, image.thumb_image_url


def add_detect_log(detect_log: ProjectDetectLog, session: Session) -> ProjectDetectLog:
    """
    Persist a new detect (inference) log record.

    :param detect_log: the ProjectDetectLog row to insert
    :param session: database session; committed by this call
    :return: the same ``detect_log`` object that was passed in
    """
    session.add(detect_log)
    session.commit()
    return detect_log


def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Session) -> None:
    """
    Insert the images belonging to a detect log record.

    :param detect_log_imgs: log image rows to insert
    :param session: database session; committed by this call
    :return: None
    """
    session.add_all(detect_log_imgs)
    session.commit()


def get_log_list(detect_id: int, session: Session) -> List[ProjectDetectLogOut]:
    """
    List the log records of a detect collection, ordered by ascending id.

    :param detect_id: id of the detect collection
    :param session: database session
    :return: list of ``ProjectDetectLogOut``
    """
    logs = (
        session.query(ProjectDetectLog)
        .filter_by(detect_id=detect_id)
        .order_by(asc(ProjectDetectLog.id))
        .all()
    )
    return [ProjectDetectLogOut.from_orm(log) for log in logs]


def get_log_imgs(log_id: int, session: Session) -> List[ProjectDetectLogImgOut]:
    """
    List the images of a detect log record, ordered by ascending id.

    :param log_id: id of the detect log record
    :param session: database session
    :return: list of ``ProjectDetectLogImgOut``
    """
    log_imgs = (
        session.query(ProjectDetectLogImg)
        .filter_by(log_id=log_id)
        .order_by(asc(ProjectDetectLogImg.id))
        .all()
    )
    return [ProjectDetectLogImgOut.from_orm(img) for img in log_imgs]


def get_log_img_url(img_id: int, session: Session):
    """
    Fetch the image URL of a detect log image.

    :param img_id: id of the ProjectDetectLogImg row
    :param session: database session
    :return: the row's ``image_url``, or ``None`` when no row has this id
    """
    log_img = session.query(ProjectDetectLogImg).filter_by(id=img_id).first()
    return None if log_img is None else log_img.image_url