Files
aicheckv2/app/api/business/project_train_api.py

295 lines
9.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from app.model.crud import project_type_crud as ptc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_train_crud as ptnc
from app.model.bussiness_model import ProjectLabel as pl
from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
from app.model.schemas.project_train_schemas import ProjectTrainIn
from app.common.jwt_check import get_user_id
from app.common import reponse_code as rc
from app.service import project_train_service as ps
from app.db.db_session import get_db
import threading
import asyncio
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
"""项目管理API"""
# Router that collects every project-management endpoint defined in this module.
project: APIRouter = APIRouter()
@project.get("/types")
def get_type_list(session: Session = Depends(get_db)):
    """
    Return the list of project types.

    :param session: database session
    :return: success response whose data is the result of ``ptc.get_list``
    """
    return rc.response_success(data=ptc.get_list(session))
@project.post("/list")
def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
    """
    Return one page of the project list.

    :param info: paging/filter criteria, forwarded to ``pic.get_project_pager2``
    :param session: database session
    :return: paged success response built by ``rc.response_success_pager``
    """
    return rc.response_success_pager(pic.get_project_pager2(info, session))
@project.post("/add")
def add_project(request: Request, info: ProjectInfoIn, session: Session = Depends(get_db)):
    """
    Create a new project.

    The request is rejected when ``pic.check_project_name`` reports the name
    as already used. Otherwise the project is created via ``ps.add_project``
    for the user resolved from the request.

    :param request: incoming request, used to resolve the current user id
    :param info: data for the new project
    :param session: database session
    :return: success response carrying the new project id, or an error response
    """
    name_taken = pic.check_project_name(info.project_name, session)
    if name_taken:
        return rc.response_error("已经存在相同名称的任务")
    new_project_id = ps.add_project(info, session, get_user_id(request))
    return rc.response_success(msg="新建成功", data=new_project_id)
@project.get("/info/{project_id}")
def get_project(project_id: int, session: Session = Depends(get_db)):
    """
    Return the details of a single project.

    :param project_id: id of the project to look up
    :param session: database session
    :return: success response whose data is the project as a dict, or an
        error response when no project is found for ``project_id``
    """
    project_info = pic.get_project_by_id(project_id, session)
    # get_project_by_id can return None (other handlers in this module check
    # for it); calling .dict() on None would raise and produce a 500.
    if project_info is None:
        return rc.response_error("项目查询错误")
    return rc.response_success(data=project_info.dict())
@project.get("/del/{project_id}")
def del_project(project_id: int, session: Session = Depends(get_db)):
    """
    Delete a project.

    The original documentation describes this as a soft ("fake") delete; the
    actual mechanics live in ``pic.del_project``.

    :param project_id: id of the project to delete
    :param session: database session
    :return: success response
    """
    pic.del_project(project_id, session)
    return rc.response_success(msg="删除成功")
@project.get("/label_list/{project_id}")
def get_label_list(project_id: int, session: Session = Depends(get_db)):
    """
    Return the labels defined for a project.

    :param project_id: id of the project whose labels are requested
    :param session: database session
    :return: success response whose data is the label list
    """
    labels = plc.get_label_list(project_id, session)
    return rc.response_success(msg="查询成功", data=labels)
@project.post("/add_label")
def add_label(label: ProjectLabel, session: Session = Depends(get_db)):
    """
    Add a label to a project.

    The label is saved only when ``plc.check_label_name`` finds no existing
    label with the same name in the same project.

    :param label: label payload (project id, name, ...)
    :param session: database session
    :return: success response, or an error response for a duplicate name
    """
    duplicate = plc.check_label_name(label.project_id, label.label_name, session)
    if not duplicate:
        plc.add_label(pl(**label.dict()), session)
        return rc.response_success(msg="保存成功")
    return rc.response_error("标签名称已经存在,不能重复")
@project.post("/up_label")
def up_label(label: ProjectLabel, session: Session = Depends(get_db)):
    """
    Update an existing label.

    The duplicate-name check passes ``label.id`` to ``plc.check_label_name``,
    presumably so the label being edited is excluded from the comparison.

    :param label: label payload including its id
    :param session: database session
    :return: success response, or an error response for a duplicate name
    """
    duplicate = plc.check_label_name(label.project_id, label.label_name, session, label.id)
    if not duplicate:
        plc.update_label(pl(**label.dict()), session)
        return rc.response_success(msg="修改成功")
    return rc.response_error("修改的标签名称已经存在,不能重复")
@project.post("/del_label/{label_id}")
def del_label(label_id: int, session: Session = Depends(get_db)):
    """
    Delete a label.

    :param label_id: id of the label to delete
    :param session: database session
    :return: success response when ``plc.del_label`` reports at least one
        deleted row, otherwise an error response
    """
    deleted_rows = plc.del_label(label_id, session)
    return (rc.response_success(msg="删除成功")
            if deleted_rows > 0
            else rc.response_error("删除失败"))
@project.post("/up_proj_img")
def upload_project_image(project_id: int = Form(...),
                         files: List[UploadFile] = File(...),
                         img_type: str = Form(...),
                         session: Session = Depends(get_db)):
    """
    Upload one or more images to a project.

    The upload is rejected when the project cannot be found, or when
    ``ps.check_image_name`` reports a file name clash (the clashing name is
    included in the error message).

    :param project_id: id of the target project (form field)
    :param files: uploaded image files
    :param img_type: image category (form field), forwarded to the service
    :param session: database session
    :return: success response, or an error response
    """
    target_project = pic.get_project_by_id(project_id, session)
    if target_project is None:
        return rc.response_error("项目查询错误,请刷新页面后再试")
    names_ok, clashing_name = ps.check_image_name(project_id, img_type, files, session)
    if not names_ok:
        return rc.response_error(msg="存在重名的图片文件:" + clashing_name)
    ps.upload_project_image(target_project, img_type, files, session)
    return rc.response_success(msg="上传成功")
@project.get("/del_img/{image_id}")
def del_image(image_id: int, session: Session = Depends(get_db)):
    """
    Delete a project image.

    :param image_id: id of the image to delete
    :param session: database session
    :return: success response
    """
    ps.del_img(image_id, session)
    # Pass the message by keyword, as every other handler in this module does.
    # Passed positionally it may land in another parameter (e.g. ``data``),
    # depending on rc.response_success's signature.
    return rc.response_success(msg="删除成功")
@project.post("/img_list")
def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db)):
    """
    Return the images of a project, either as a full list or paged.

    When both ``pagerNum`` and ``pagerSize`` are None, the full list for
    ``project_id``/``img_type`` is returned. If either one is set, a paged
    result from ``pimc.get_image_pager2`` is returned instead.

    :param image: query (project id, image type, optional paging fields)
    :param session: database session
    :return: success response (plain or paged)
    """
    paging_requested = not (image.pagerNum is None and image.pagerSize is None)
    if paging_requested:
        return rc.response_success_pager(pimc.get_image_pager2(image, session))
    images = pimc.get_image_list(image.project_id, image.img_type, session)
    return rc.response_success(data=jsonable_encoder(images))
@project.post("/save_img_label")
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session = Depends(get_db)):
    """
    Save an image's label bounding-box information.

    Persistence is delegated entirely to ``ps.save_img_label``.

    :param img_leafer_label: label/box payload for one image
    :param session: database session
    :return: success response
    """
    ps.save_img_label(img_leafer_label, session)
    return rc.response_success(msg="保存成功")
@project.get("/get_img_leafer/{image_id}")
def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
    """
    Return the stored leafer information for an image.

    :param image_id: id of the image
    :param session: database session
    :return: success response whose data is the ``'leafer'`` entry of the
        service result, or an empty success response when nothing is stored
    """
    leafer_info = ps.get_img_leafer(image_id, session)
    if leafer_info is not None:
        return rc.response_success(data=leafer_info['leafer'])
    return rc.response_success()
@project.post("/run_train")
def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)):
    """
    Validate a project and start its training in a background thread.

    The request is rejected with an error response when:
    - the project does not exist;
    - ``project_status == '1'`` (the error message says a training process
      already exists);
    - there are no, or fewer than 10, 'train' images;
    - there are no, or fewer than 5, 'val' images;
    - ``pimc.check_image_label`` reports unlabeled 'train' or 'val' images.

    Otherwise the training inputs are prepared by ``ps.run_train_yolo`` and
    ``run_event_loop`` is started on a new thread. The response returns
    without waiting for training to finish.

    This is a plain ``def`` rather than ``async def``: everything below is
    synchronous work (DB queries, training preparation). FastAPI runs sync
    handlers in its threadpool, so these calls no longer block the event loop.

    :param train_in: training request (project id and training parameters)
    :param session: database session
    :return: success response once the training thread has started, or an
        error response describing why training cannot start
    """
    project_id = train_in.project_id
    project_info = pic.get_project_by_id(project_id, session)
    if project_info is None:
        return rc.response_error("项目查询错误")
    if project_info.project_status == '1':
        return rc.response_error("项目当前存在训练进程,请稍后再试")

    train_img_count = pimc.get_image_count(project_id, 'train', session)
    if train_img_count == 0:
        return rc.response_error("请先上传训练图片")
    if train_img_count < 10:
        return rc.response_error("训练图片少于10张请继续上传训练图片")

    val_img_count = pimc.get_image_count(project_id, 'val', session)
    if val_img_count == 0:
        return rc.response_error("请先上传验证图片")
    if val_img_count < 5:
        return rc.response_error("验证图片少于5张请继续上传验证图片")

    train_label_count = pimc.check_image_label(project_id, 'train', session)
    if train_label_count > 0:
        return rc.response_error("训练图片中存在未标注的图片")
    val_label_count = pimc.check_image_label(project_id, 'val', session)
    if val_label_count > 0:
        return rc.response_error("验证图片中存在未标注的图片")

    # Named train_project (not "project") so it does not shadow the
    # module-level router.
    data, train_project, name = ps.run_train_yolo(project_info, train_in, session)
    # NOTE(review): ``session`` comes from the request-scoped ``get_db``
    # dependency, yet it is handed to a thread that outlives this request.
    # SQLAlchemy sessions are not thread-safe; consider giving the training
    # thread its own session. Confirm how get_db manages session lifetime.
    thread_train = threading.Thread(target=run_event_loop,
                                    args=(data, train_project, name, train_in,
                                          project_id, session))
    thread_train.start()
    return rc.response_success(msg="执行成功")
def run_event_loop(data: str, project: str, name: str, train_in: ProjectTrainIn,
                   project_id: int, session: Session):
    """
    Thread entry point: run the async training command on a private event loop.

    Creates a new asyncio event loop for the current thread and runs
    ``ps.run_commend`` on it to completion. The loop is always closed
    afterwards, even if the training command raises.

    :param data: first value returned by ``ps.run_train_yolo``; forwarded as-is
    :param project: second value returned by ``ps.run_train_yolo``; forwarded as-is
    :param name: third value returned by ``ps.run_train_yolo``; forwarded as-is
    :param train_in: training request; its epochs, patience and weights_id
        are forwarded to ``ps.run_commend``
    :param project_id: id of the project being trained
    :param session: database session forwarded to ``ps.run_commend``
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(ps.run_commend(data, project, name, train_in.epochs,
                                               train_in.patience, train_in.weights_id,
                                               project_id, session))
    finally:
        # Previously close() was skipped when run_commend raised, leaking the loop.
        loop.close()
@project.get("/get_train_list/{project_id}")
def get_train_list(project_id: int, session: Session = Depends(get_db)):
    """
    Return the training records of a project.

    :param project_id: id of the project
    :param session: database session
    :return: success response whose data is the JSON-encoded training list
    """
    trainings = ptnc.get_train_list(project_id, session)
    return rc.response_success(data=jsonable_encoder(trainings))
@project.get("/get_train_report/{train_id}")
def get_train_report(train_id: int, session: Session = Depends(get_db)):
    """
    Return the report of one training run.

    :param train_id: id of the training record
    :param session: database session
    :return: success response whose data is the service result, or an error
        response when ``ps.get_train_result`` returns None
    """
    report = ps.get_train_result(train_id, session)
    if report is not None:
        return rc.response_success(data=report)
    return rc.response_error("查询失败")