完成推理模块的主体功能
This commit is contained in:
@ -116,9 +116,11 @@ class ProjectDetectLog(DbCommon):
|
||||
__tablename__ = "project_detect_log"
|
||||
detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
detect_version: Mapped[str] = mapped_column(String(10))
|
||||
detect_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
train_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
train_version: Mapped[str] = mapped_column(String(10))
|
||||
pt_type: Mapped[str] = mapped_column(String(32))
|
||||
pt_type: Mapped[str] = mapped_column(String(10))
|
||||
pt_url: Mapped[str] = mapped_column(String(255))
|
||||
folder_url: Mapped[str] = mapped_column(String(255))
|
||||
detect_folder_url: Mapped[str] = mapped_column(String(255))
|
||||
|
||||
|
@ -2,10 +2,10 @@ from typing import List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import asc, and_
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.model.bussiness_model import ProjectDetect, ProjectDetectImg, ProjectDetectLog, ProjectDetectLogImg
|
||||
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager
|
||||
from app.model.schemas.project_detect_schemas import ProjectDetectOut, ProjectDetectPager, \
|
||||
ProjectDetectImageOut, ProjectDetectImgPager, ProjectDetectLogOut, ProjectDetectLogImgOut
|
||||
from app.db.page_util import get_pager
|
||||
|
||||
|
||||
@ -21,6 +21,18 @@ def add_detect(detect: ProjectDetect, session: Session):
|
||||
return detect
|
||||
|
||||
|
||||
def get_detect_by_id(detect_id: int, session: Session):
|
||||
"""
|
||||
根据id查询训练集合
|
||||
:param detect_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
detect = session.query(ProjectDetect).filter_by(id=detect_id).first()
|
||||
detect_out = ProjectDetectOut.from_orm(detect)
|
||||
return detect_out
|
||||
|
||||
|
||||
def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
|
||||
"""
|
||||
查询推理集合分页数据
|
||||
@ -39,6 +51,26 @@ def get_detect_pager(detect_pager: ProjectDetectPager, session: Session):
|
||||
return pager
|
||||
|
||||
|
||||
def update_detect_status(detect_id: int, detect_status: int, session: Session):
|
||||
"""
|
||||
更新项目训练状态,如果是已完成的话,train_version自动+1
|
||||
:param detect_id:
|
||||
:param detect_status: 0-未运行,1-运行中,2-已完成,-1-执行失败
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
if detect_status == 2:
|
||||
session.query(ProjectDetect).filter_by(id=detect_id).update({
|
||||
'detect_status': detect_status,
|
||||
'detect_version': ProjectDetect.detect_version + 1
|
||||
})
|
||||
else:
|
||||
session.query(ProjectDetect).filter_by(id=detect_id).update({
|
||||
'detect_status': detect_status
|
||||
})
|
||||
session.commit()
|
||||
|
||||
|
||||
def check_img_name(detect_id: int, file_name: str, session: Session):
|
||||
"""
|
||||
校验上传的图片名称是否重名
|
||||
@ -65,6 +97,40 @@ def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session):
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_img_list(detect_id: int, session: Session):
|
||||
"""
|
||||
获取训练集合中的图片列表
|
||||
:param detect_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
query = session.query(ProjectDetectImg).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectImg.id))
|
||||
image_list = [ProjectDetectImageOut.from_orm(image) for image in query.all()]
|
||||
return image_list
|
||||
|
||||
|
||||
def get_img_pager(detect_img_pager: ProjectDetectImgPager, session: Session):
|
||||
"""
|
||||
获取训练集合中的图片列表,返回pager对象
|
||||
:param detect_img_pager:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
query = session.query(ProjectDetectImg).filter_by(detect_id=detect_img_pager.detect_id)\
|
||||
.order_by(asc(ProjectDetectImg.id))
|
||||
pager = get_pager(query, detect_img_pager.pagerNum, detect_img_pager.pagerSize)
|
||||
pager.data = [ProjectDetectImageOut.from_orm(image) for image in pager.data]
|
||||
pager.data = jsonable_encoder(pager.data)
|
||||
return pager
|
||||
|
||||
|
||||
def get_detect_img_url(image_id: int, session: Session):
|
||||
result = session.query(ProjectDetectImg).filter_by(id=image_id).first()
|
||||
sour_url = result.image_url
|
||||
thumb_url = result.thumb_image_url
|
||||
return sour_url, thumb_url
|
||||
|
||||
|
||||
def add_detect_log(detect_log: ProjectDetectLog, session: Session):
|
||||
"""
|
||||
新增推理记录
|
||||
@ -85,4 +151,23 @@ def add_detect_log_imgs(detect_log_imgs: List[ProjectDetectLogImg], session: Ses
|
||||
:return:
|
||||
"""
|
||||
session.add_all(detect_log_imgs)
|
||||
session.commit()
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_log_list(detect_id: int, session: Session):
|
||||
"""
|
||||
获取推理记录
|
||||
:param detect_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
query = session.query(ProjectDetectLog).filter_by(detect_id=detect_id).order_by(asc(ProjectDetectLog.id))
|
||||
result = [ProjectDetectLogOut.from_orm(log) for log in query.all()]
|
||||
return result
|
||||
|
||||
|
||||
def get_log_imgs(log_id: int, session: Session):
|
||||
query = session.query(ProjectDetectLogImg).filter_by(log_id=log_id).order_by(asc(ProjectDetectLogImg.id))
|
||||
result = [ProjectDetectLogImgOut.from_orm(img) for img in query.all()]
|
||||
return result
|
||||
|
||||
|
@ -19,7 +19,7 @@ class ProjectDetectPager(BaseModel):
|
||||
class ProjectDetectOut(BaseModel):
|
||||
id: Optional[int]
|
||||
project_id: Optional[int]
|
||||
detect_name: Optional[int]
|
||||
detect_name: Optional[str]
|
||||
detect_no: Optional[str]
|
||||
detect_version: Optional[int]
|
||||
file_type: Optional[str]
|
||||
@ -33,6 +33,25 @@ class ProjectDetectOut(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class ProjectDetectImgPager(BaseModel):
|
||||
detect_id: Optional[int] = Field(..., description="训练集合id")
|
||||
pagerNum: Optional[int] = Field(None, description="当前页码")
|
||||
pagerSize: Optional[int] = Field(None, description="每页数量")
|
||||
|
||||
|
||||
class ProjectDetectImageOut(BaseModel):
|
||||
id: Optional[int] = Field(None, description="id")
|
||||
detect_id: Optional[int] = Field(..., description="训练集合id")
|
||||
file_name: Optional[str] = Field(None, description="文件名称")
|
||||
create_time: Optional[datetime] = Field(None, description="上传时间")
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
|
||||
class ProjectDetectLogIn(BaseModel):
|
||||
detect_id: Optional[int] = Field(..., description="推理集合id")
|
||||
train_id: Optional[int] = Field(..., description="训练结果id")
|
||||
@ -43,6 +62,7 @@ class ProjectDetectLogOut(BaseModel):
|
||||
id: Optional[int]
|
||||
detect_id: Optional[int]
|
||||
detect_version: Optional[str]
|
||||
detect_name: Optional[str]
|
||||
train_id: Optional[int]
|
||||
train_version: Optional[int]
|
||||
pt_type: Optional[str]
|
||||
|
Reference in New Issue
Block a user