209 lines
6.4 KiB
Python
209 lines
6.4 KiB
Python
from typing import List
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import asc, and_
|
||
from fastapi.encoders import jsonable_encoder
|
||
|
||
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
|
||
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager, \
|
||
ProjectDetectImageOut, ProjectDetectImgPager, ProjectDetectLogOut, ProjectDetectLogImgOut, \
|
||
ProjectDetectList, ProjectDetectLogPager
|
||
from app.db.page_util import get_pager
|
||
|
||
|
||
def add_detect(detect: ProjectDetect, session: Session):
    """
    Persist a new detect (inference) set.

    :param detect: ProjectDetect instance to store
    :param session: SQLAlchemy session; committed before returning
    :return: the same ``detect`` object that was passed in
    """
    # Stage the new row, then flush and commit it in one go.
    session.add(detect)
    session.commit()
    return detect
|
||
|
||
|
||
def get_detect_by_id(detect_id: int, session: Session):
    """
    Look up a single detect (inference) set by its id.

    :param detect_id: id of the ProjectDetect row to load
    :param session: SQLAlchemy session used to run the query
    :return: the row converted with ``ProjectDetectOut.from_orm``, or None
        when no row has the given id
    """
    detect = session.query(ProjectDetect).filter_by(id=detect_id).first()
    if detect is None:
        # first() yields None for an empty result; from_orm is meant to be
        # fed a mapped row, so report "not found" explicitly instead.
        return None
    return ProjectDetectOut.from_orm(detect)
|
||
|
||
|
||
def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
    """
    Return one page of detect (inference) sets for a project.

    Rows are always restricted to ``detect_pager.project_id``. When
    ``detect_pager.detect_name`` is not None, an ``ilike`` substring match on
    the set name is added. Results are ordered by ascending id.

    :param detect_pager: paging request; ``project_id``, ``detect_name``,
        ``pagerNum`` and ``pagerSize`` are read from it
    :param session: SQLAlchemy session used to run the query
    :return: the object produced by ``get_pager``, with ``data`` replaced by
        the ``jsonable_encoder`` output of the ProjectDetectOut items
    """
    conditions = [ProjectDetect.project_id == detect_pager.project_id]
    name_filter = detect_pager.detect_name
    if name_filter is not None:
        conditions.append(ProjectDetect.detect_name.ilike(f"%{name_filter}%"))

    query = (
        session.query(ProjectDetect)
        .order_by(asc(ProjectDetect.id))
        .filter(and_(*conditions))
    )
    pager = get_pager(query, detect_pager.pagerNum, detect_pager.pagerSize)

    # Convert ORM rows to schema objects, then to JSON-compatible data.
    items = [ProjectDetectOut.from_orm(row) for row in pager.data]
    pager.data = jsonable_encoder(items)
    return pager
|
||
|
||
|
||
def get_detect_list(project_id: int, session: Session):
    """
    List every detect (inference) set of a project, ordered by ascending id.

    :param project_id: id of the project whose sets are listed
    :param session: SQLAlchemy session used to run the query
    :return: list of dicts, one per row, built via
        ``ProjectDetectList.from_orm(...).dict()``
    """
    rows = (
        session.query(ProjectDetect)
        .filter_by(project_id=project_id)
        .order_by(asc(ProjectDetect.id))
        .all()
    )
    return [ProjectDetectList.from_orm(row).dict() for row in rows]
|
||
|
||
|
||
def update_detect_status(detect_id: int, detect_status: int, session: Session):
    """
    Set the status of a detect (inference) set and commit.

    When the new status is 2 (finished), ``detect_version`` is incremented by
    one in the same UPDATE.

    :param detect_id: id of the ProjectDetect row to update
    :param detect_status: 0 - not run, 1 - running, 2 - finished, -1 - failed
    :param session: SQLAlchemy session; committed before returning
    :return: None
    """
    changes = {'detect_status': detect_status}
    if detect_status == 2:
        # Column expression: the increment is rendered into the UPDATE itself.
        changes['detect_version'] = ProjectDetect.detect_version + 1

    session.query(ProjectDetect).filter_by(id=detect_id).update(changes)
    session.commit()
|
||
|
||
|
||
def check_img_name(detect_id: int, file_name: str, session: Session):
    """
    Check whether an uploaded image name is still unused within a detect set.

    :param detect_id: id of the detect set the image would belong to
    :param file_name: image file name to look for
    :param session: SQLAlchemy session used to run the query
    :return: ``(True, None)`` when no image with that name exists in the set,
        otherwise ``(False, file_name_of_the_existing_image)``
    """
    existing = (
        session.query(ProjectDetectImg)
        .filter_by(detect_id=detect_id, file_name=file_name)
        .first()
    )
    if existing is not None:
        # Name already taken: report the stored name back to the caller.
        return False, existing.file_name
    return True, None
|
||
|
||
|
||
def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session):
    """
    Store a batch of images belonging to a detect (inference) set.

    :param detect_imgs: ProjectDetectImg instances to insert
    :param session: SQLAlchemy session; committed before returning
    :return: None
    """
    # All images are staged together and written with a single commit.
    session.add_all(detect_imgs)
    session.commit()
|
||
|
||
|
||
def get_img_list(detect_id: int, session: Session):
    """
    Return all images of a detect set, ordered by ascending id.

    :param detect_id: id of the detect set whose images are fetched
    :param session: SQLAlchemy session used to run the query
    :return: list of ProjectDetectImageOut objects, one per image row
    """
    rows = (
        session.query(ProjectDetectImg)
        .filter_by(detect_id=detect_id)
        .order_by(asc(ProjectDetectImg.id))
        .all()
    )
    return [ProjectDetectImageOut.from_orm(row) for row in rows]
|
||
|
||
|
||
def get_img_pager(detect_img_pager: ProjectDetectImgPager, session: Session):
    """
    Return one page of images of a detect set, ordered by ascending id.

    :param detect_img_pager: paging request; ``detect_id``, ``pagerNum`` and
        ``pagerSize`` are read from it
    :param session: SQLAlchemy session used to run the query
    :return: the object produced by ``get_pager``, with ``data`` replaced by
        the ``jsonable_encoder`` output of the ProjectDetectImageOut items
    """
    query = (
        session.query(ProjectDetectImg)
        .filter_by(detect_id=detect_img_pager.detect_id)
        .order_by(asc(ProjectDetectImg.id))
    )
    pager = get_pager(query, detect_img_pager.pagerNum, detect_img_pager.pagerSize)

    # Convert ORM rows to schema objects, then to JSON-compatible data.
    items = [ProjectDetectImageOut.from_orm(row) for row in pager.data]
    pager.data = jsonable_encoder(items)
    return pager
|
||
|
||
|
||
def get_detect_img_url(image_id: int, session: Session):
    """
    Fetch the source and thumbnail URLs stored for a detect image.

    :param image_id: id of the ProjectDetectImg row
    :param session: SQLAlchemy session used to run the query
    :return: ``(image_url, thumb_image_url)`` of the row, or ``(None, None)``
        when no image has the given id
    """
    image = session.query(ProjectDetectImg).filter_by(id=image_id).first()
    if image is None:
        # Unknown id: keep the two-item shape so callers that unpack the
        # result keep working, instead of failing on attribute access.
        return None, None
    return image.image_url, image.thumb_image_url
|
||
|
||
|
||
def add_detect_log(detect_log: ProjectDetectLog, session: Session):
    """
    Persist a new detect (inference) log record.

    :param detect_log: ProjectDetectLog instance to store
    :param session: SQLAlchemy session; committed before returning
    :return: the same ``detect_log`` object that was passed in
    """
    # Stage the new record, then commit it.
    session.add(detect_log)
    session.commit()
    return detect_log
|
||
|
||
|
||
def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Session):
    """
    Store a batch of images attached to a detect (inference) log record.

    :param detect_log_imgs: ProjectDetectLogImg instances to insert
    :param session: SQLAlchemy session; committed before returning
    :return: None
    """
    # All images are staged together and written with a single commit.
    session.add_all(detect_log_imgs)
    session.commit()
|
||
|
||
|
||
def get_log_list(detect_id: int, session: Session):
    """
    Return all log records of a detect set, ordered by ascending id.

    :param detect_id: id of the detect set whose logs are fetched
    :param session: SQLAlchemy session used to run the query
    :return: list of ProjectDetectLogOut objects, one per log row
    """
    rows = (
        session.query(ProjectDetectLog)
        .filter_by(detect_id=detect_id)
        .order_by(asc(ProjectDetectLog.id))
        .all()
    )
    return [ProjectDetectLogOut.from_orm(row) for row in rows]
|
||
|
||
|
||
def get_log_pager(detect_log_pager: ProjectDetectLogPager, session: Session):
    """
    Return one page of log records of a detect set, ordered by ascending id.

    :param detect_log_pager: paging request; ``detect_id``, ``pagerNum`` and
        ``pagerSize`` are read from it
    :param session: SQLAlchemy session used to run the query
    :return: the object produced by ``get_pager``, with ``data`` replaced by
        the ``jsonable_encoder`` output of the ProjectDetectLogOut items
    """
    query = (
        session.query(ProjectDetectLog)
        .filter_by(detect_id=detect_log_pager.detect_id)
        .order_by(asc(ProjectDetectLog.id))
    )
    pager = get_pager(query, detect_log_pager.pagerNum, detect_log_pager.pagerSize)

    # Convert ORM rows to schema objects, then to JSON-compatible data.
    items = [ProjectDetectLogOut.from_orm(row) for row in pager.data]
    pager.data = jsonable_encoder(items)
    return pager
|
||
|
||
|
||
def get_log_imgs(log_id: int, session: Session):
    """
    Return all images attached to a log record, ordered by ascending id.

    :param log_id: id of the log record whose images are fetched
    :param session: SQLAlchemy session used to run the query
    :return: list of ProjectDetectLogImgOut objects, one per image row
    """
    rows = (
        session.query(ProjectDetectLogImg)
        .filter_by(log_id=log_id)
        .order_by(asc(ProjectDetectLogImg.id))
        .all()
    )
    return [ProjectDetectLogImgOut.from_orm(row) for row in rows]
|
||
|
||
|
||
def get_log_img_url(img_id: int, session: Session):
    """
    Fetch the image URL stored for a log image.

    :param img_id: id of the ProjectDetectLogImg row
    :param session: SQLAlchemy session used to run the query
    :return: the row's ``image_url``, or None when no row has the given id
    """
    log_img = session.query(ProjectDetectLogImg).filter_by(id=img_id).first()
    return log_img.image_url if log_img is not None else None
|