重构基本完成
This commit is contained in:
@ -59,3 +59,12 @@ class ProjectImgLabel(DbCommon):
|
||||
mark_center_y: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
mark_width: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
mark_height: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
|
||||
|
||||
class ProjectTrain(DbCommon):
|
||||
"""项目训练版本信息表"""
|
||||
__tablename__ = "project_train"
|
||||
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
last_pt: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
@ -20,6 +20,11 @@ def get_image_list(project_id: int, session: Session):
|
||||
return image_list
|
||||
|
||||
|
||||
def get_images(project_id: int, session: Session):
|
||||
query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
|
||||
return query.all()
|
||||
|
||||
|
||||
def add_image(image: ProjectImage, session: Session):
|
||||
session.add(image)
|
||||
session.commit()
|
||||
|
@ -16,6 +16,17 @@ def get_img_leafer(image_id: int, session: Session):
|
||||
return img_leafer
|
||||
|
||||
|
||||
def get_img_label_list(image_id, session: Session):
|
||||
"""
|
||||
根据图片id获取图片标签信息
|
||||
:param image_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).all()
|
||||
return img_label_list
|
||||
|
||||
|
||||
def save_img_leafer(leafer: ProjectImgLeafer, session: Session):
|
||||
leafer_saved = session.query(ProjectImgLeafer).filter_by(image_id=leafer.image_id).first()
|
||||
if leafer_saved is not None:
|
||||
|
@ -1,5 +1,5 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import desc, update
|
||||
|
||||
from app.model.bussiness_model import ProjectInfo
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoOut
|
||||
@ -41,3 +41,24 @@ def check_project_name(project_name: str, session: Session):
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def update_project_status(project_id: int, project_status: str, session: Session):
|
||||
"""
|
||||
更新项目训练状态,如果是已完成的话,train_version自动+1
|
||||
:param project_id:
|
||||
:param project_status: 0-未运行,1-运行中,2-已完成,-1-执行失败
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
if project_status == '2':
|
||||
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
|
||||
'train_status': project_status,
|
||||
'train_version': ProjectInfo.train_version + 1
|
||||
})
|
||||
session.execute(stmt)
|
||||
else:
|
||||
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
|
||||
'train_status': project_status
|
||||
})
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
@ -17,6 +17,16 @@ def get_label_list(project_id: int, session: Session):
|
||||
return label_list
|
||||
|
||||
|
||||
def get_label_for_train(project_id: int, session: Session):
|
||||
id_list = []
|
||||
name_list = []
|
||||
label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
|
||||
for label in label_list:
|
||||
id_list.append(label.id)
|
||||
name_list.append(label.label_name)
|
||||
return id_list, name_list
|
||||
|
||||
|
||||
def add_label(label: plModel, session: Session):
|
||||
"""
|
||||
新增标签
|
||||
|
28
app/model/crud/project_train_crud.py
Normal file
28
app/model/crud/project_train_crud.py
Normal file
@ -0,0 +1,28 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import asc
|
||||
|
||||
from app.model.bussiness_model import ProjectTrain
|
||||
from app.model.schemas.project_train_schemas import ProjectTrainOut
|
||||
|
||||
|
||||
def add_train(train: ProjectTrain, session: Session):
|
||||
"""
|
||||
新增训练结果
|
||||
:param train:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
session.add(train)
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_train_list(project_id: int, session: Session):
|
||||
"""
|
||||
根据项目id查询训练列表
|
||||
:param project_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id))
|
||||
train_list = [ProjectTrainOut.from_orm(train) for train in query.all()]
|
||||
return train_list
|
13
app/model/schemas/project_train_schemas.py
Normal file
13
app/model/schemas/project_train_schemas.py
Normal file
@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ProjectTrainOut(BaseModel):
|
||||
"""项目训练版本信息表"""
|
||||
id: Optional[int] = Field(None, description="训练id")
|
||||
train_version: Optional[str] = Field(None, description="训练版本号")
|
||||
best_pt: Optional[str] = Field(None, description="最好")
|
||||
last_pt: Optional[str] = Field(None, description="最后")
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
Reference in New Issue
Block a user