完成推理模块的主体功能

This commit is contained in:
2025-03-04 17:04:37 +08:00
parent 4262d3e908
commit fa6c344e84
9 changed files with 325 additions and 21 deletions

View File

@ -2,10 +2,10 @@ from typing import List
from sqlalchemy.orm import Session
from sqlalchemy import asc, and_
from fastapi.encoders import jsonable_encoder
from fastapi import UploadFile
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager, \
ProjectDetectImageOut, ProjectDetectImgPager, ProjectDetectLogOut, ProjectDetectLogImgOut
from app.db.page_util import get_pager
@ -21,6 +21,18 @@ def add_detect(detect: ProjectDetect, session: Session):
return detect
def get_detect_by_id(detect_id: int, session: Session):
"""
根据id查询训练集合
:param detect_id:
:param session:
:return:
"""
detect = session.query(ProjectDetect).filter_by(id=detect_id).first()
detect_out = ProjectDetectOut.from_orm(detect)
return detect_out
def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
"""
查询推理集合分页数据
@ -39,6 +51,26 @@ def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
return pager
def update_detect_status(detect_id: int, detect_status: int, session: Session):
"""
更新项目训练状态,如果是已完成的话train_version自动+1
:param detect_id:
:param detect_status: 0-未运行1-运行中2-已完成,-1-执行失败
:param session:
:return:
"""
if detect_status == 2:
session.query(ProjectDetect).filter_by(id=detect_id).update({
'detect_status': detect_status,
'detect_version': ProjectDetect.detect_version + 1
})
else:
session.query(ProjectDetect).filter_by(id=detect_id).update({
'detect_status': detect_status
})
session.commit()
def check_img_name(detect_id: int, file_name: str, session: Session):
"""
校验上传的图片名称是否重名
@ -65,6 +97,40 @@ def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session):
session.commit()
def get_img_list(detect_id: int, session: Session):
"""
获取训练集合中的图片列表
:param detect_id:
:param session:
:return:
"""
query = session.query(ProjectDetectImg).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectImg.id))
image_list = [ProjectDetectImageOut.from_orm(image) for image in query.all()]
return image_list
def get_img_pager(detect_img_pager: ProjectDetectImgPager, session: Session):
"""
获取训练集合中的图片列表,返回pager对象
:param detect_img_pager:
:param session:
:return:
"""
query = session.query(ProjectDetectImg).filter_by(detect_id=detect_img_pager.detect_id)\
.order_by(asc(ProjectDetectImg.id))
pager = get_pager(query, detect_img_pager.pagerNum, detect_img_pager.pagerSize)
pager.data = [ProjectDetectImageOut.from_orm(image) for image in pager.data]
pager.data = jsonable_encoder(pager.data)
return pager
def get_detect_img_url(image_id: int, session: Session):
result = session.query(ProjectDetectImg).filter_by(id=image_id).first()
sour_url = result.image_url
thumb_url = result.thumb_image_url
return sour_url, thumb_url
def add_detect_log(detect_log: ProjectDetectLog, session: Session):
"""
新增推理记录
@ -85,4 +151,23 @@ def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Ses
:return:
"""
session.add_all(detect_log_imgs)
session.commit()
session.commit()
def get_log_list(detect_id: int, session: Session):
"""
获取推理记录
:param detect_id:
:param session:
:return:
"""
query = session.query(ProjectDetectLog).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectLog.id))
result = [ProjectDetectLogOut.from_orm(log) for log in query.all()]
return result
def get_log_imgs(log_id: int, session: Session):
query = session.query(ProjectDetectLogImg).filter_by(log_id=log_id).order_by(asc(ProjectDetectLogImg.id))
result = [ProjectDetectLogImgOut.from_orm(img) for img in query.all()]
return result