重构基本完成

This commit is contained in:
2025-02-21 11:35:53 +08:00
parent cdd05e95ba
commit 85d0a8fadc
16 changed files with 346 additions and 8 deletions

View File

@ -59,3 +59,12 @@ class ProjectImgLabel(DbCommon):
mark_center_y: Mapped[str] = mapped_column(String(64), nullable=False)
mark_width: Mapped[str] = mapped_column(String(64), nullable=False)
mark_height: Mapped[str] = mapped_column(String(64), nullable=False)
class ProjectTrain(DbCommon):
"""项目训练版本信息表"""
__tablename__ = "project_train"
project_id: Mapped[int] = mapped_column(Integer, nullable=False)
train_version: Mapped[str] = mapped_column(String(32), nullable=False)
best_pt: Mapped[str] = mapped_column(String(255), nullable=False)
last_pt: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -20,6 +20,11 @@ def get_image_list(project_id: int, session: Session):
return image_list
def get_images(project_id: int, session: Session):
query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
return query.all()
def add_image(image: ProjectImage, session: Session):
session.add(image)
session.commit()

View File

@ -16,6 +16,17 @@ def get_img_leafer(image_id: int, session: Session):
return img_leafer
def get_img_label_list(image_id, session: Session):
"""
根据图片id获取图片标签信息
:param image_id:
:param session:
:return:
"""
img_label_list = session.query(ProjectImgLabel).filter_by(image_id=image_id).all()
return img_label_list
def save_img_leafer(leafer: ProjectImgLeafer, session: Session):
leafer_saved = session.query(ProjectImgLeafer).filter_by(image_id=leafer.image_id).first()
if leafer_saved is not None:

View File

@ -1,5 +1,5 @@
from sqlalchemy.orm import Session
from sqlalchemy import desc
from sqlalchemy import desc, update
from app.model.bussiness_model import ProjectInfo
from app.model.schemas.project_info_schemas import ProjectInfoOut
@ -41,3 +41,24 @@ def check_project_name(project_name: str, session: Session):
else:
return False
def update_project_status(project_id: int, project_status: str, session: Session):
"""
更新项目训练状态,如果是已完成的话train_version自动+1
:param project_id:
:param project_status: 0-未运行1-运行中2-已完成,-1-执行失败
:param session:
:return:
"""
if project_status == '2':
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
'train_status': project_status,
'train_version': ProjectInfo.train_version + 1
})
session.execute(stmt)
else:
stmt = update(ProjectInfo).where(ProjectInfo.id == project_id).values({
'train_status': project_status
})
session.execute(stmt)
session.commit()

View File

@ -17,6 +17,16 @@ def get_label_list(project_id: int, session: Session):
return label_list
def get_label_for_train(project_id: int, session: Session):
id_list = []
name_list = []
label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
for label in label_list:
id_list.append(label.id)
name_list.append(label.label_name)
return id_list, name_list
def add_label(label: plModel, session: Session):
"""
新增标签

View File

@ -0,0 +1,28 @@
from sqlalchemy.orm import Session
from sqlalchemy import asc
from app.model.bussiness_model import ProjectTrain
from app.model.schemas.project_train_schemas import ProjectTrainOut
def add_train(train: ProjectTrain, session: Session):
"""
新增训练结果
:param train:
:param session:
:return:
"""
session.add(train)
session.commit()
def get_train_list(project_id: int, session: Session):
"""
根据项目id查询训练列表
:param project_id:
:param session:
:return:
"""
query = session.query(ProjectTrain).filter_by(project_id=project_id).order_by(asc(ProjectTrain.id))
train_list = [ProjectTrainOut.from_orm(train) for train in query.all()]
return train_list

View File

@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import Optional
class ProjectTrainOut(BaseModel):
"""项目训练版本信息表"""
id: Optional[int] = Field(None, description="训练id")
train_version: Optional[str] = Field(None, description="训练版本号")
best_pt: Optional[str] = Field(None, description="最好")
last_pt: Optional[str] = Field(None, description="最后")
class Config:
orm_mode = True