223 lines
6.8 KiB
Python
223 lines
6.8 KiB
Python
from app.model.crud import project_type_crud as ptc
|
||
from app.model.crud import project_label_crud as plc
|
||
from app.model.crud import project_info_crud as pic
|
||
from app.model.crud import project_image_crud as pimc
|
||
from app.model.crud import project_train_crud as ptnc
|
||
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
|
||
from app.model.schemas.project_label_schemas import ProjectLabel
|
||
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel
|
||
from app.model.bussiness_model import ProjectLabel as pl
|
||
from app.common.jwt_check import get_user_id
|
||
from app.common import reponse_code as rc
|
||
from app.service import project_service as ps
|
||
from app.db.db_session import get_db
|
||
|
||
from typing import List
|
||
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
|
||
from fastapi.responses import StreamingResponse
|
||
from fastapi.encoders import jsonable_encoder
|
||
from sqlalchemy.orm import Session
|
||
|
||
"""项目管理API"""
|
||
project = APIRouter()
|
||
|
||
|
||
@project.get("/types")
def get_type_list(session: Session = Depends(get_db)):
    """
    Return all project types (categories).

    :param session: database session injected by ``get_db``
    :return: success response whose ``data`` is the result of ``ptc.get_list``
    """
    return rc.response_success(data=ptc.get_list(session))
|
||
|
||
|
||
@project.post("/list")
def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
    """
    Return one page of projects.

    :param info: paging / filter criteria taken from the request body
    :param session: database session injected by ``get_db``
    :return: paged success response built by ``rc.response_success_pager``
    """
    return rc.response_success_pager(pic.get_project_pager(info, session))
|
||
|
||
|
||
@project.post("/add")
def add_project(request: Request, info: ProjectInfoIn, session: Session = Depends(get_db)):
    """
    Create a new project.

    The name is checked first; if a project with the same name already exists
    an error response is returned and nothing is created. Otherwise the
    current user id (resolved from the request via ``get_user_id``) is passed
    to ``ps.add_project`` along with the project details.

    :param request: incoming request, used to resolve the current user id
    :param info: new project details from the request body
    :param session: database session injected by ``get_db``
    :return: success response carrying the id returned by ``ps.add_project``,
        or an error response when the name is already taken
    """
    name_taken = pic.check_project_name(info.project_name, session)
    if name_taken:
        return rc.response_error("已经存在相同名称的项目")
    new_project_id = ps.add_project(info, session, get_user_id(request))
    return rc.response_success(msg="新建成功", data=new_project_id)
|
||
|
||
|
||
@project.get("/info/{project_id}")
def get_project(project_id: int, session: Session = Depends(get_db)):
    """
    Return the details of a single project.

    :param project_id: id of the project to look up
    :param session: database session injected by ``get_db``
    :return: success response with the project serialized via ``.dict()``,
        or an error response when no project with that id exists
    """
    project_info = pic.get_project_by_id(project_id, session)
    # get_project_by_id can return None (the upload and training endpoints in
    # this router check for it); calling .dict() on None would raise
    # AttributeError and surface as an HTTP 500 instead of an error response.
    if project_info is None:
        return rc.response_error("项目查询错误")
    return rc.response_success(data=project_info.dict())
|
||
|
||
|
||
@project.get("/label_list/{project_id}")
def get_label_list(project_id: int, session: Session = Depends(get_db)):
    """
    Return the label list of a project.

    :param project_id: id of the project whose labels are listed
    :param session: database session injected by ``get_db``
    :return: success response whose ``data`` is the result of ``plc.get_label_list``
    """
    labels = plc.get_label_list(project_id, session)
    return rc.response_success(data=labels, msg="查询成功")
|
||
|
||
|
||
@project.post("/add_label")
def add_label(label: ProjectLabel, session: Session = Depends(get_db)):
    """
    Add a new label to a project.

    Label names must be unique within the project; a duplicate name yields an
    error response and nothing is saved.

    :param label: label data from the request body
    :param session: database session injected by ``get_db``
    :return: success response, or an error response for a duplicate name
    """
    duplicate = plc.check_label_name(label.project_id, label.label_name, session)
    if duplicate:
        return rc.response_error("标签名称已经存在,不能重复")
    # Convert the request schema into the business model expected by the CRUD layer.
    plc.add_label(pl(**label.dict()), session)
    return rc.response_success(msg="保存成功")
|
||
|
||
|
||
@project.post("/up_label")
def up_label(label: ProjectLabel, session: Session = Depends(get_db)):
    """
    Update an existing label.

    :param label: label data from the request body (including its ``id``)
    :param session: database session injected by ``get_db``
    :return: success response, or an error response when the new name clashes
    """
    # NOTE(review): label.id is passed as a 4th argument, presumably so the
    # label being edited is excluded from the duplicate check — confirm.
    name_clash = plc.check_label_name(label.project_id, label.label_name, session, label.id)
    if name_clash:
        return rc.response_error("修改的标签名称已经存在,不能重复")
    updated = pl(**label.dict())
    plc.update_label(updated, session)
    return rc.response_success(msg="修改成功")
|
||
|
||
|
||
@project.post("/del_label/{label_id}")
def del_label(label_id: int, session: Session = Depends(get_db)):
    """
    Delete a label.

    :param label_id: id of the label to delete
    :param session: database session injected by ``get_db``
    :return: success response when ``plc.del_label`` reports more than zero
        affected rows, otherwise an error response
    """
    deleted_rows = plc.del_label(label_id, session)
    return rc.response_success(msg="删除成功") if deleted_rows > 0 else rc.response_error("删除失败")
|
||
|
||
|
||
@project.post("/up_proj_img")
def upload_project_image(project_id: int = Form(...),
                         files: List[UploadFile] = File(...),
                         session: Session = Depends(get_db)):
    """
    Upload images to a project.

    Fails with an error response when the project cannot be found or when
    ``ps.check_image_name`` reports a file name that clashes with an
    existing image.

    :param project_id: id of the target project (multipart form field)
    :param files: image files to upload
    :param session: database session injected by ``get_db``
    :return: success response, or an error response as described above
    """
    target_project = pic.get_project_by_id(project_id, session)
    if target_project is None:
        return rc.response_error("项目查询错误,请刷新页面后再试")
    names_ok, clashing_name = ps.check_image_name(project_id, files, session)
    if not names_ok:
        return rc.response_error(msg="存在重名的图片文件:" + clashing_name)
    ps.upload_project_image(target_project, files, session)
    return rc.response_success(msg="上传成功")
|
||
|
||
|
||
@project.get("/img_list/{project_id}")
def get_image_list(project_id: int, session: Session = Depends(get_db)):
    """
    Return the image list of a project.

    :param project_id: id of the project
    :param session: database session injected by ``get_db``
    :return: success response whose ``data`` is the image list converted to
        JSON-compatible types with ``jsonable_encoder``
    """
    images = pimc.get_image_list(project_id, session)
    return rc.response_success(data=jsonable_encoder(images))
|
||
|
||
|
||
@project.post("/save_img_label")
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session = Depends(get_db)):
    """
    Save the label bounding-box annotations drawn on an image.

    Persistence is delegated entirely to ``ps.save_img_label``.

    :param img_leafer_label: annotation payload from the request body
    :param session: database session injected by ``get_db``
    :return: success response
    """
    ps.save_img_label(img_leafer_label, session)
    return rc.response_success(msg="保存成功")
|
||
|
||
|
||
@project.get("/get_img_leafer/{image_id}")
def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
    """
    Return the stored leafer (annotation canvas) data of an image.

    :param image_id: id of the image
    :param session: database session injected by ``get_db``
    :return: success response whose ``data`` is the ``'leafer'`` entry of the
        record returned by ``ps.get_img_leafer``; a success response without
        data when that call returns None
    """
    leafer_record = ps.get_img_leafer(image_id, session)
    if leafer_record is not None:
        return rc.response_success(data=leafer_record['leafer'])
    return rc.response_success()
|
||
|
||
|
||
@project.get("/run_train/{project_id}")
def run_train(project_id: int, session: Session = Depends(get_db)):
    """
    Start training for a project and stream the command output back.

    Declared as a plain ``def`` rather than ``async def``: every call made
    here is synchronous (none is awaited). Inside an ``async def`` they would
    block the event loop for all other requests. FastAPI runs ``def``
    endpoints in its threadpool instead. ``StreamingResponse`` accepts both
    sync and async iterators, so the streamed body is unaffected.

    :param project_id: id of the project to train
    :param session: database session injected by ``get_db``
    :return: ``StreamingResponse`` (``text/plain``) over ``ps.run_commend``,
        or an error response when the project is missing or already training
    """
    project_info = pic.get_project_by_id(project_id, session)
    if project_info is None:
        return rc.response_error("项目查询错误")
    # NOTE(review): judging by the error message, status '1' means a training
    # process is already running for this project — confirm against the model.
    if project_info.project_status == '1':
        return rc.response_error("项目当前存在训练进程,请稍后再试")
    data, project_name, name = ps.run_train_yolo(project_info, session)
    # NOTE(review): `session` comes from a yield dependency but is also used
    # by the generator while the response streams. Depending on the FastAPI
    # version, get_db's cleanup may run before streaming finishes — confirm
    # the session is still usable inside run_commend.
    # The literal 10 is passed straight through to ps.run_commend; its meaning
    # is not visible here — TODO document.
    return StreamingResponse(
        ps.run_commend(data, project_name, name, 10, project_id, session),
        media_type="text/plain")
|
||
|
||
|
||
@project.get("/get_train_list/{project_id}")
def get_train_list(project_id: int, session: Session = Depends(get_db)):
    """
    Return the training records of a project.

    :param project_id: id of the project
    :param session: database session injected by ``get_db``
    :return: success response whose ``data`` is the training list converted to
        JSON-compatible types with ``jsonable_encoder``
    """
    trainings = ptnc.get_train_list(project_id, session)
    return rc.response_success(data=jsonable_encoder(trainings))
|
||
|
||
|
||
|