优化一些细节
This commit is contained in:
@ -2,14 +2,15 @@ from app.model.crud import project_type_crud as ptc
|
||||
from app.model.crud import project_label_crud as plc
|
||||
from app.model.crud import project_info_crud as pic
|
||||
from app.model.crud import project_image_crud as pimc
|
||||
from app.service import project_service as ps
|
||||
from app.model.crud import project_train_crud as ptnc
|
||||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
|
||||
from app.model.schemas.project_label_schemas import ProjectLabel
|
||||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel
|
||||
from app.model.bussiness_model import ProjectLabel as pl
|
||||
from app.db.db_session import get_db
|
||||
from app.common.jwt_check import get_user_id
|
||||
from app.common import reponse_code as rc
|
||||
from app.service import project_service as ps
|
||||
from app.db.db_session import get_db
|
||||
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
|
||||
@ -35,12 +36,30 @@ def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
|
||||
|
||||
@project.post("/add")
|
||||
def add_project(request: Request, info: ProjectInfoIn, session: Session = Depends(get_db)):
|
||||
"""新建项目"""
|
||||
"""
|
||||
新建项目
|
||||
:param request:
|
||||
:param info:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
if pic.check_project_name(info.project_name, session):
|
||||
return rc.response_error("已经存在相同名称的项目")
|
||||
user_id = get_user_id(request)
|
||||
ps.add_project(info, session, user_id)
|
||||
return rc.response_success(msg="新建成功")
|
||||
project_id = ps.add_project(info, session, user_id)
|
||||
return rc.response_success(msg="新建成功", data=project_id)
|
||||
|
||||
|
||||
@project.get("/info/{project_id}")
|
||||
def get_project(project_id: int, session: Session = Depends(get_db)):
|
||||
"""
|
||||
根据项目id获取详情
|
||||
:param project_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
project_info = pic.get_project_by_id(project_id, session)
|
||||
return rc.response_success(data=project_info.dict())
|
||||
|
||||
|
||||
@project.get("/label_list/{project_id}")
|
||||
@ -169,3 +188,17 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
return StreamingResponse(
|
||||
ps.run_commend(data, project_name, name, 10, project_id, session),
|
||||
media_type="text/plain")
|
||||
|
||||
|
||||
def get_train_list(project_id: int, session: Session = Depends(get_db)):
|
||||
"""
|
||||
根据项目id,获取训练列表
|
||||
:param project_id:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
train_list = ptnc.get_train_list(project_id, session)
|
||||
return rc.response_success(data=train_list)
|
||||
|
||||
|
||||
|
||||
|
@ -17,7 +17,6 @@ class ProjectInfoOut(BaseModel):
|
||||
project_name: Optional[str] = Field(..., description="项目名称")
|
||||
type_code: Optional[str] = Field(..., description="项目类型编码")
|
||||
description: Optional[str] = Field(None, description="项目描述")
|
||||
user_name: Optional[str] = Field(None, description="创建人")
|
||||
train_version: Optional[int] = Field(None, description="训练版本号")
|
||||
project_status: Optional[str] = Field(None, description="项目状态")
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
@ -6,8 +8,7 @@ class ProjectTrainOut(BaseModel):
|
||||
"""项目训练版本信息表"""
|
||||
id: Optional[int] = Field(None, description="训练id")
|
||||
train_version: Optional[str] = Field(None, description="训练版本号")
|
||||
best_pt: Optional[str] = Field(None, description="最好")
|
||||
last_pt: Optional[str] = Field(None, description="最后")
|
||||
create_time: Optional[datetime.datetime] = Field(None, description="训练时间")
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
@ -14,7 +14,6 @@ from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import UploadFile
|
||||
import yaml
|
||||
import select
|
||||
import subprocess
|
||||
|
||||
|
||||
@ -33,8 +32,8 @@ def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||
project_info.train_version = 0
|
||||
os.create_folder(datasets_url, project_info.project_no)
|
||||
os.create_folder(runs_url, project_info.project_no)
|
||||
pic.add_project(project_info, session)
|
||||
return project_info
|
||||
project_info = pic.add_project(project_info, session)
|
||||
return project_info.id
|
||||
|
||||
|
||||
def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session):
|
||||
|
Reference in New Issue
Block a user