From ef95c99f44f310a7b0e545ee099963e45ef555d6 Mon Sep 17 00:00:00 2001 From: sunyugang Date: Tue, 25 Feb 2025 09:22:37 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=B8=80=E4=BA=9B=E7=BB=86?= =?UTF-8?q?=E8=8A=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/business/project_api.py | 43 +++++++++++++++++++--- app/model/schemas/project_info_schemas.py | 1 - app/model/schemas/project_train_schemas.py | 5 ++- app/service/project_service.py | 5 +-- 4 files changed, 43 insertions(+), 11 deletions(-) diff --git a/app/api/business/project_api.py b/app/api/business/project_api.py index 670b849..a3da4ff 100644 --- a/app/api/business/project_api.py +++ b/app/api/business/project_api.py @@ -2,14 +2,15 @@ from app.model.crud import project_type_crud as ptc from app.model.crud import project_label_crud as plc from app.model.crud import project_info_crud as pic from app.model.crud import project_image_crud as pimc -from app.service import project_service as ps +from app.model.crud import project_train_crud as ptnc from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager from app.model.schemas.project_label_schemas import ProjectLabel from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel from app.model.bussiness_model import ProjectLabel as pl -from app.db.db_session import get_db from app.common.jwt_check import get_user_id from app.common import reponse_code as rc +from app.service import project_service as ps +from app.db.db_session import get_db from typing import List from fastapi import APIRouter, Depends, Request, UploadFile, File, Form @@ -35,12 +36,30 @@ def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)): @project.post("/add") def add_project(request: Request, info: ProjectInfoIn, session: Session = Depends(get_db)): - """新建项目""" + """ + 新建项目 + :param request: + :param info: + :param session: + :return: + """ if pic.check_project_name(info.project_name, session): return rc.response_error("已经存在相同名称的项目") user_id = get_user_id(request) - ps.add_project(info, session, user_id) - return rc.response_success(msg="新建成功") + project_id = ps.add_project(info, session, user_id) + return rc.response_success(msg="新建成功", data=project_id) + + +@project.get("/info/{project_id}") +def get_project(project_id: int, session: Session = Depends(get_db)): + """ + 根据项目id获取详情 + :param project_id: + :param session: + :return: + """ + project_info = pic.get_project_by_id(project_id, session) + return rc.response_success(data=project_info.dict()) @project.get("/label_list/{project_id}") @@ -169,3 +188,17 @@ async def run_train(project_id: int, session: Session = Depends(get_db)): return StreamingResponse( ps.run_commend(data, project_name, name, 10, project_id, session), media_type="text/plain") + + +def get_train_list(project_id: int, session: Session = Depends(get_db)): + """ + 根据项目id,获取训练列表 + :param project_id: + :param session: + :return: + """ + train_list = ptnc.get_train_list(project_id, session) + return rc.response_success(data=train_list) + + + diff --git a/app/model/schemas/project_info_schemas.py b/app/model/schemas/project_info_schemas.py index 9392ad7..7c489b5 100644 --- a/app/model/schemas/project_info_schemas.py +++ b/app/model/schemas/project_info_schemas.py @@ -17,7 +17,6 @@ class ProjectInfoOut(BaseModel): project_name: Optional[str] = Field(..., description="项目名称") type_code: Optional[str] = Field(..., description="项目类型编码") description: Optional[str] = Field(None, description="项目描述") - user_name: Optional[str] = Field(None, description="创建人") train_version: Optional[int] = Field(None, description="训练版本号") project_status: Optional[str] = Field(None, description="项目状态") diff --git a/app/model/schemas/project_train_schemas.py b/app/model/schemas/project_train_schemas.py index 1ca73b6..43ccf27 100644 --- a/app/model/schemas/project_train_schemas.py +++ b/app/model/schemas/project_train_schemas.py @@ -1,3 +1,5 @@ +import datetime + from pydantic import BaseModel, Field from typing import Optional @@ -6,8 +8,7 @@ class ProjectTrainOut(BaseModel): """项目训练版本信息表""" id: Optional[int] = Field(None, description="训练id") train_version: Optional[str] = Field(None, description="训练版本号") - best_pt: Optional[str] = Field(None, description="最好") - last_pt: Optional[str] = Field(None, description="最后") + create_time: Optional[datetime.datetime] = Field(None, description="训练时间") class Config: orm_mode = True diff --git a/app/service/project_service.py b/app/service/project_service.py index d65bfcf..1249b9b 100644 --- a/app/service/project_service.py +++ b/app/service/project_service.py @@ -14,7 +14,6 @@ from sqlalchemy.orm import Session from typing import List from fastapi import UploadFile import yaml -import select import subprocess @@ -33,8 +32,8 @@ def add_project(info: ProjectInfoIn, session: Session, user_id: int): project_info.train_version = 0 os.create_folder(datasets_url, project_info.project_no) os.create_folder(runs_url, project_info.project_no) - pic.add_project(project_info, session) - return project_info + project_info = pic.add_project(project_info, session) + return project_info.id def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session):