Files
aicheckv2/app/api/business/project_train_api.py
2025-02-26 10:04:10 +08:00

223 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from app.model.crud import project_type_crud as ptc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_train_crud as ptnc
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel
from app.model.bussiness_model import ProjectLabel as pl
from app.common.jwt_check import get_user_id
from app.common import reponse_code as rc
from app.service import project_service as ps
from app.db.db_session import get_db
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
"""项目管理API"""
# NOTE: the string above follows the imports, so it is a no-op expression,
# not the module docstring. It labels this module as the
# "project management API".
# Router for the project-management endpoints; every handler below is
# registered on it through its @project.get / @project.post decorator.
project = APIRouter()
@project.get("/types")
def get_type_list(session: Session = Depends(get_db)):
    """Return the list of project types (categories) from ``ptc.get_list``."""
    return rc.response_success(data=ptc.get_list(session))
@project.post("/list")
def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
    """
    Return one page of projects.

    :param info: paging request body (see ``ProjectInfoPager``)
    :param session: database session injected by FastAPI
    :return: paged success response built by ``rc.response_success_pager``
    """
    return rc.response_success_pager(pic.get_project_pager(info, session))
@project.post("/add")
def add_project(request: Request, info: ProjectInfoIn, session: Session = Depends(get_db)):
    """
    Create a new project.

    The request is rejected when a project with the same name already exists.
    Otherwise the project is created via ``ps.add_project`` together with the
    user id extracted from the request.

    :param request: incoming request, used to resolve the current user id
    :param info: new project data from the request body
    :param session: database session injected by FastAPI
    :return: success response carrying the new project's id, or an error
             response when the name is already taken
    """
    name_taken = pic.check_project_name(info.project_name, session)
    if name_taken:
        return rc.response_error("已经存在相同名称的项目")
    new_project_id = ps.add_project(info, session, get_user_id(request))
    return rc.response_success(msg="新建成功", data=new_project_id)
@project.get("/info/{project_id}")
def get_project(project_id: int, session: Session = Depends(get_db)):
    """
    Fetch the details of a single project by its id.

    :param project_id: project id (path parameter)
    :param session: database session injected by FastAPI
    :return: success response whose data is the project's ``dict()``, or an
             error response when no project was found for the id
    """
    project_info = pic.get_project_by_id(project_id, session)
    # get_project_by_id can return None (the upload and training handlers in
    # this module check for it). Calling .dict() on None would raise
    # AttributeError and surface as a 500 instead of a normal error response.
    if project_info is None:
        return rc.response_error("项目查询错误")
    return rc.response_success(data=project_info.dict())
@project.get("/label_list/{project_id}")
def get_label_list(project_id: int, session: Session = Depends(get_db)):
    """
    List the labels defined for a project.

    :param project_id: project id (path parameter)
    :param session: database session injected by FastAPI
    :return: success response whose data is the result of ``plc.get_label_list``
    """
    return rc.response_success(msg="查询成功", data=plc.get_label_list(project_id, session))
@project.post("/add_label")
def add_label(label: ProjectLabel, session: Session = Depends(get_db)):
    """
    Add a new label to a project.

    Labels are rejected when ``plc.check_label_name`` reports that the name
    already exists in the same project.

    :param label: label data from the request body
    :param session: database session injected by FastAPI
    :return: success response, or an error response on a duplicate name
    """
    duplicate = plc.check_label_name(label.project_id, label.label_name, session)
    if duplicate:
        return rc.response_error("标签名称已经存在,不能重复")
    # Convert the request schema into the persistence model before saving.
    plc.add_label(pl(**label.dict()), session)
    return rc.response_success(msg="保存成功")
@project.post("/up_label")
def up_label(label: ProjectLabel, session: Session = Depends(get_db)):
    """
    Update an existing project label.

    The duplicate-name check receives ``label.id`` as well, so the label being
    edited is presumably excluded from the comparison (confirm in
    ``plc.check_label_name``).

    :param label: updated label data from the request body
    :param session: database session injected by FastAPI
    :return: success response, or an error response on a duplicate name
    """
    duplicate = plc.check_label_name(label.project_id, label.label_name, session, label.id)
    if duplicate:
        return rc.response_error("修改的标签名称已经存在,不能重复")
    plc.update_label(pl(**label.dict()), session)
    return rc.response_success(msg="修改成功")
@project.post("/del_label/{label_id}")
def del_label(label_id: int, session: Session = Depends(get_db)):
    """
    Delete a label by its id.

    :param label_id: label id (path parameter)
    :param session: database session injected by FastAPI
    :return: success response when ``plc.del_label`` reports more than zero
             rows removed, otherwise an error response
    """
    deleted_rows = plc.del_label(label_id, session)
    if deleted_rows <= 0:
        return rc.response_error("删除失败")
    return rc.response_success(msg="删除成功")
@project.post("/up_proj_img")
def upload_project_image(project_id: int = Form(...),
                         files: List[UploadFile] = File(...),
                         session: Session = Depends(get_db)):
    """
    Upload one or more images into a project.

    Fails when the project cannot be found, or when ``ps.check_image_name``
    reports a clashing file name.

    :param project_id: target project id (multipart form field)
    :param files: uploaded image files (multipart form field)
    :param session: database session injected by FastAPI
    :return: success response, or an error response describing the failure
    """
    target_project = pic.get_project_by_id(project_id, session)
    if target_project is None:
        return rc.response_error("项目查询错误,请刷新页面后再试")
    names_ok, clashing_name = ps.check_image_name(project_id, files, session)
    if not names_ok:
        return rc.response_error(msg="存在重名的图片文件:" + clashing_name)
    ps.upload_project_image(target_project, files, session)
    return rc.response_success(msg="上传成功")
@project.get("/img_list/{project_id}")
def get_image_list(project_id: int, session: Session = Depends(get_db)):
    """
    List the images belonging to a project.

    :param project_id: project id (path parameter)
    :param session: database session injected by FastAPI
    :return: success response with the image list passed through
             ``jsonable_encoder``
    """
    images = pimc.get_image_list(project_id, session)
    return rc.response_success(data=jsonable_encoder(images))
@project.post("/save_img_label")
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session = Depends(get_db)):
    """
    Persist the label boxes drawn on an image.

    All of the work is delegated to ``ps.save_img_label``.

    :param img_leafer_label: image labelling payload from the request body
    :param session: database session injected by FastAPI
    :return: success response
    """
    ps.save_img_label(img_leafer_label, session)
    saved_msg = "保存成功"
    return rc.response_success(msg=saved_msg)
@project.get("/get_img_leafer/{image_id}")
def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
    """
    Look up the stored "leafer" annotation data for an image.

    :param image_id: image id (path parameter)
    :param session: database session injected by FastAPI
    :return: success response whose data is the record's ``'leafer'`` entry,
             or an empty success response when nothing is stored
    """
    leafer_record = ps.get_img_leafer(image_id, session)
    if leafer_record is not None:
        return rc.response_success(data=leafer_record['leafer'])
    return rc.response_success()
@project.get("/run_train/{project_id}")
def run_train(project_id: int, session: Session = Depends(get_db)):
    """
    Start a training run for a project and stream its output as plain text.

    This is a plain ``def`` on purpose. Every call made here
    (``pic.get_project_by_id``, ``ps.run_train_yolo``) is synchronous and
    blocking. Inside an ``async def`` those calls would run on the event loop
    and stall all other requests. FastAPI runs plain ``def`` endpoints in its
    threadpool instead.

    :param project_id: project id (path parameter)
    :param session: database session injected by FastAPI
    :return: error response when the project is missing or already training,
             otherwise a ``text/plain`` StreamingResponse over the output of
             ``ps.run_commend``
    """
    project_info = pic.get_project_by_id(project_id, session)
    if project_info is None:
        return rc.response_error("项目查询错误")
    # Status '1' is treated as "a training process is already running".
    if project_info.project_status == '1':
        return rc.response_error("项目当前存在训练进程,请稍后再试")
    data, project_name, name = ps.run_train_yolo(project_info, session)
    # NOTE(review): the session is still used by the streaming generator after
    # this function returns. Confirm that get_db's cleanup does not close it
    # before streaming finishes.
    return StreamingResponse(
        ps.run_commend(data, project_name, name, 10, project_id, session),
        media_type="text/plain")
@project.get("/get_train_list/{project_id}")
def get_train_list(project_id: int, session: Session = Depends(get_db)):
    """
    List the training runs recorded for a project.

    :param project_id: project id (path parameter)
    :param session: database session injected by FastAPI
    :return: success response with the training list passed through
             ``jsonable_encoder``
    """
    return rc.response_success(data=jsonable_encoder(ptnc.get_train_list(project_id, session)))