Files
aicheckv2/app/model/crud/project_detect_crud.py

231 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List
from sqlalchemy.orm import Session
from sqlalchemy import asc, and_
from fastapi.encoders import jsonable_encoder
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager, \
ProjectDetectImageOut, ProjectDetectImgPager, ProjectDetectLogOut, ProjectDetectLogImgOut, \
ProjectDetectList, ProjectDetectLogPager
from app.db.page_util import get_pager
def add_detect(detect: ProjectDetect, session: Session):
    """
    Persist a new detect (inference) set and commit immediately.
    :param detect: the ProjectDetect instance to insert
    :param session: SQLAlchemy session; it is committed before returning
    :return: the same detect instance that was passed in
    """
    session.add(detect)
    session.commit()  # commit right away so the caller sees the persisted row
    return detect
def get_detect_by_id(detect_id: int, session: Session):
    """
    Query a single detect (inference) set by its id.
    :param detect_id: primary key of the ProjectDetect row
    :param session: SQLAlchemy session used for the query
    :return: the row converted to ProjectDetectOut, or None when no row has this id
    """
    detect = session.query(ProjectDetect).filter_by(id=detect_id).first()
    # first() yields None for an unknown id; passing None to from_orm would either
    # raise a validation error or build an object out of field defaults, neither of
    # which tells the caller "not found".
    if detect is None:
        return None
    return ProjectDetectOut.from_orm(detect)
def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
    """
    Return one page of detect sets of a project, ordered by ascending id.
    :param detect_pager: paging request; always filtered by project_id, and by a
                         case-insensitive substring match on detect_name when given
    :param session: SQLAlchemy session used for the query
    :return: pager object whose data is the JSON-encoded list of ProjectDetectOut
    """
    conditions = [ProjectDetect.project_id == detect_pager.project_id]
    name_keyword = detect_pager.detect_name
    if name_keyword is not None:
        conditions.append(ProjectDetect.detect_name.ilike(f"%{name_keyword}%"))

    query = (
        session.query(ProjectDetect)
        .order_by(asc(ProjectDetect.id))
        .filter(and_(*conditions))
    )
    pager = get_pager(query, detect_pager.pagerNum, detect_pager.pagerSize)
    page_items = [ProjectDetectOut.from_orm(row) for row in pager.data]
    pager.data = jsonable_encoder(page_items)
    return pager
def get_detect_list(project_id: int, session: Session):
    """
    List every detect set of a project as plain dicts, ordered by ascending id.
    :param project_id: id of the owning project
    :param session: SQLAlchemy session used for the query
    :return: list of dicts produced by ProjectDetectList.from_orm(...).dict()
    """
    rows = (
        session.query(ProjectDetect)
        .filter_by(project_id=project_id)
        .order_by(asc(ProjectDetect.id))
        .all()
    )
    return [ProjectDetectList.from_orm(row).dict() for row in rows]
def update_detect_status(detect_id: int, detect_status: int, session: Session):
    """
    Update the status of a detect set; when it reaches "finished" also bump detect_version.
    :param detect_id: primary key of the ProjectDetect row
    :param detect_status: 0 - not run, 1 - running, 2 - finished, -1 - failed
    :param session: SQLAlchemy session; it is committed before returning
    :return: None
    """
    values = {'detect_status': detect_status}
    if detect_status == 2:
        # Column expression: the increment is part of the UPDATE statement itself,
        # not a Python-side read-modify-write.
        values['detect_version'] = ProjectDetect.detect_version + 1
    session.query(ProjectDetect).filter_by(id=detect_id).update(values)
    session.commit()
def check_img_name(detect_id: int, file_name: str, session: Session):
    """
    Check whether an uploaded image name is still free within a detect set.
    :param detect_id: id of the detect set to search
    :param file_name: file name to look for
    :param session: SQLAlchemy session used for the query
    :return: (True, None) when no image in the set has this name,
             otherwise (False, file_name of the existing image)
    """
    existing = (
        session.query(ProjectDetectImg)
        .filter_by(detect_id=detect_id)
        .filter_by(file_name=file_name)
        .first()
    )
    if existing is not None:
        return False, existing.file_name
    return True, None
def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session):
    """
    Insert a batch of images into a detect set and commit.
    :param detect_imgs: ProjectDetectImg instances to insert
    :param session: SQLAlchemy session; it is committed before returning
    :return: None
    """
    session.add_all(detect_imgs)
    session.commit()  # one commit for the whole batch
def check_detect_img(detect_id: int, session: Session):
    """
    Count the images belonging to a detect set.
    :param detect_id: id of the detect set
    :param session: SQLAlchemy session used for the query
    :return: number of ProjectDetectImg rows with this detect_id
    """
    return session.query(ProjectDetectImg).filter_by(detect_id=detect_id).count()
def get_img_list(detect_id: int, session: Session):
    """
    Fetch all images of a detect set, ordered by ascending id.
    :param detect_id: id of the detect set
    :param session: SQLAlchemy session used for the query
    :return: list of ProjectDetectImageOut
    """
    rows = (
        session.query(ProjectDetectImg)
        .filter_by(detect_id=detect_id)
        .order_by(asc(ProjectDetectImg.id))
        .all()
    )
    return [ProjectDetectImageOut.from_orm(row) for row in rows]
def get_img_pager(detect_img_pager: ProjectDetectImgPager, session: Session):
    """
    Return one page of images of a detect set, ordered by ascending id.
    :param detect_img_pager: paging request carrying detect_id, pagerNum and pagerSize
    :param session: SQLAlchemy session used for the query
    :return: pager object whose data is the JSON-encoded list of ProjectDetectImageOut
    """
    query = (
        session.query(ProjectDetectImg)
        .filter_by(detect_id=detect_img_pager.detect_id)
        .order_by(asc(ProjectDetectImg.id))
    )
    pager = get_pager(query, detect_img_pager.pagerNum, detect_img_pager.pagerSize)
    page_items = [ProjectDetectImageOut.from_orm(row) for row in pager.data]
    pager.data = jsonable_encoder(page_items)
    return pager
def get_detect_img_url(image_id: int, session: Session):
    """
    Look up the stored URLs of one image in a detect set.
    :param image_id: primary key of the ProjectDetectImg row
    :param session: SQLAlchemy session used for the query
    :return: tuple (image_url, thumb_image_url); (None, None) when no image has this id
    """
    result = session.query(ProjectDetectImg).filter_by(id=image_id).first()
    # Unknown id: first() returns None. Return a 2-tuple of None instead of
    # raising AttributeError, so callers that unpack two values keep working.
    if result is None:
        return None, None
    return result.image_url, result.thumb_image_url
def add_detect_log(detect_log: ProjectDetectLog, session: Session):
    """
    Persist a new detect (inference) log record and commit immediately.
    :param detect_log: the ProjectDetectLog instance to insert
    :param session: SQLAlchemy session; it is committed before returning
    :return: the same detect_log instance that was passed in
    """
    session.add(detect_log)
    session.commit()  # commit right away so the caller sees the persisted row
    return detect_log
def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Session):
    """
    Insert a batch of images attached to a detect log record and commit.
    :param detect_log_imgs: ProjectDetectLogImg instances to insert
    :param session: SQLAlchemy session; it is committed before returning
    :return: None
    """
    session.add_all(detect_log_imgs)
    session.commit()  # one commit for the whole batch
def get_log_list(detect_id: int, session: Session):
    """
    Fetch the log records of a detect set as ProjectDetectLogOut, ordered by ascending id.
    :param detect_id: id of the detect set
    :param session: SQLAlchemy session used for the query
    :return: list of ProjectDetectLogOut
    """
    rows = (
        session.query(ProjectDetectLog)
        .filter_by(detect_id=detect_id)
        .order_by(asc(ProjectDetectLog.id))
        .all()
    )
    return [ProjectDetectLogOut.from_orm(row) for row in rows]
def get_logs(detect_id: int, session: Session):
    """
    Fetch the raw ProjectDetectLog ORM rows of a detect set, ordered by ascending id.
    :param detect_id: id of the detect set
    :param session: SQLAlchemy session used for the query
    :return: list of ProjectDetectLog ORM instances (not converted to schemas)
    """
    return (
        session.query(ProjectDetectLog)
        .filter_by(detect_id=detect_id)
        .order_by(asc(ProjectDetectLog.id))
        .all()
    )
def get_log_pager(detect_log_pager: ProjectDetectLogPager, session: Session):
    """
    Return one page of log records of a detect set, ordered by ascending id.
    :param detect_log_pager: paging request carrying detect_id, pagerNum and pagerSize
    :param session: SQLAlchemy session used for the query
    :return: pager object whose data is the JSON-encoded list of ProjectDetectLogOut
    """
    query = (
        session.query(ProjectDetectLog)
        .filter_by(detect_id=detect_log_pager.detect_id)
        .order_by(asc(ProjectDetectLog.id))
    )
    pager = get_pager(query, detect_log_pager.pagerNum, detect_log_pager.pagerSize)
    page_items = [ProjectDetectLogOut.from_orm(row) for row in pager.data]
    pager.data = jsonable_encoder(page_items)
    return pager
def get_log_imgs(log_id: int, session: Session):
    """
    Fetch the images attached to one detect log record, ordered by ascending id.
    :param log_id: id of the ProjectDetectLog record
    :param session: SQLAlchemy session used for the query
    :return: list of ProjectDetectLogImgOut
    """
    rows = (
        session.query(ProjectDetectLogImg)
        .filter_by(log_id=log_id)
        .order_by(asc(ProjectDetectLogImg.id))
        .all()
    )
    return [ProjectDetectLogImgOut.from_orm(row) for row in rows]
def get_log_img_url(img_id: int, session: Session):
    """
    Look up the stored URL of one detect-log image.
    :param img_id: primary key of the ProjectDetectLogImg row
    :param session: SQLAlchemy session used for the query
    :return: the image_url, or None when no row has this id
    """
    log_img = session.query(ProjectDetectLogImg).filter_by(id=img_id).first()
    return None if log_img is None else log_img.image_url