"""Project management API."""
import asyncio
import threading
from typing import List

from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session

from app.model.crud import project_type_crud as ptc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.model.crud import project_train_crud as ptnc
from app.model.bussiness_model import ProjectLabel as pl
from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_image_schemas import ProjectImgLeaferLabel, ProjectImagePager
from app.model.schemas.project_train_schemas import ProjectTrainIn
from app.common.jwt_check import get_user_id
from app.common import reponse_code as rc
from app.service import project_train_service as ps
from app.db.db_session import get_db

project = APIRouter()


@project.get("/types")
def get_type_list(session: Session = Depends(get_db)):
    """Return the list of project types."""
    type_list = ptc.get_list(session)
    return rc.response_success(data=type_list)


@project.post("/list")
def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
    """
    Return one page of the project list.

    :param info: paging/filter parameters, forwarded to ``pic.get_project_pager2``
    :param session: database session
    :return: paged success response
    """
    pager = pic.get_project_pager2(info, session)
    return rc.response_success_pager(pager)


@project.post("/add")
def add_project(request: Request, info: ProjectInfoIn, session: Session = Depends(get_db)):
    """
    Create a new project.

    Rejects the request when a project with the same name already exists.

    :param request: incoming request; the current user id is read from it via ``get_user_id``
    :param info: project creation payload
    :param session: database session
    :return: success response whose ``data`` is the new project id
    """
    if pic.check_project_name(info.project_name, session):
        return rc.response_error("已经存在相同名称的任务")
    user_id = get_user_id(request)
    project_id = ps.add_project(info, session, user_id)
    return rc.response_success(msg="新建成功", data=project_id)


@project.get("/info/{project_id}")
def get_project(project_id: int, session: Session = Depends(get_db)):
    """
    Return the details of a project by id.

    :param project_id: project id
    :param session: database session
    :return: success response with the project as a dict, or an error
             response when no such project exists
    """
    project_info = pic.get_project_by_id(project_id, session)
    # get_project_by_id returns None for unknown ids (other endpoints check
    # for it); without this guard ``.dict()`` would raise and yield a 500.
    if project_info is None:
        return rc.response_error("项目查询错误")
    return rc.response_success(data=project_info.dict())


@project.get("/del/{project_id}")
def del_project(project_id: int, session: Session = Depends(get_db)):
    """
    Delete a project.

    Intended as a soft (logical) delete; the actual behaviour lives in
    ``pic.del_project``.

    :param project_id: project id
    :param session: database session
    :return: success response
    """
    pic.del_project(project_id, session)
    return rc.response_success(msg="删除成功")


@project.get("/label_list/{project_id}")
def get_label_list(project_id: int, session: Session = Depends(get_db)):
    """
    Return the label list of a project.

    :param project_id: project id
    :param session: database session
    :return: success response with the labels
    """
    label_list = plc.get_label_list(project_id, session)
    return rc.response_success(msg="查询成功", data=label_list)


@project.post("/add_label")
def add_label(label: ProjectLabel, session: Session = Depends(get_db)):
    """
    Add a label to a project.

    Rejects the request when the project already has a label with the same name.

    :param label: label payload (schema), converted to the business model before saving
    :param session: database session
    :return: success or error response
    """
    if plc.check_label_name(label.project_id, label.label_name, session):
        return rc.response_error("标签名称已经存在,不能重复")
    label_save = pl(**label.dict())
    plc.add_label(label_save, session)
    return rc.response_success(msg="保存成功")


@project.post("/up_label")
def up_label(label: ProjectLabel, session: Session = Depends(get_db)):
    """
    Update an existing label.

    :param label: label payload; ``label.id`` is passed to the duplicate-name
                  check, presumably so the label being edited is excluded
    :param session: database session
    :return: success or error response
    """
    if plc.check_label_name(label.project_id, label.label_name, session, label.id):
        return rc.response_error("修改的标签名称已经存在,不能重复")
    label_save = pl(**label.dict())
    plc.update_label(label_save, session)
    return rc.response_success(msg="修改成功")


@project.post("/del_label/{label_id}")
def del_label(label_id: int, session: Session = Depends(get_db)):
    """
    Delete a label.

    :param label_id: label id
    :param session: database session
    :return: success response when at least one row was deleted, otherwise an error
    """
    row_del = plc.del_label(label_id, session)
    if row_del > 0:
        return rc.response_success(msg="删除成功")
    return rc.response_error("删除失败")


@project.post("/up_proj_img")
def upload_project_image(project_id: int = Form(...),
                         files: List[UploadFile] = File(...),
                         img_type: str = Form(...),
                         session: Session = Depends(get_db)):
    """
    Upload images to a project.

    :param project_id: project id (form field)
    :param files: uploaded image files
    :param img_type: image category (form field); elsewhere in this module the
                     values ``'train'`` and ``'val'`` are used
    :param session: database session
    :return: success response, or an error when the project is missing or a
             file name clashes with an existing image
    """
    project_info = pic.get_project_by_id(project_id, session)
    if project_info is None:
        return rc.response_error("项目查询错误,请刷新页面后再试")
    is_check, file_name = ps.check_image_name(project_id, img_type, files, session)
    if not is_check:
        return rc.response_error(msg="存在重名的图片文件:" + file_name)
    ps.upload_project_image(project_info, img_type, files, session)
    return rc.response_success(msg="上传成功")


@project.get("/del_img/{image_id}")
def del_image(image_id: int, session: Session = Depends(get_db)):
    """
    Delete an image.

    :param image_id: image id
    :param session: database session
    :return: success response
    """
    ps.del_img(image_id, session)
    # Pass the message by keyword, like every other success response here,
    # so it cannot end up in the ``data`` slot.
    return rc.response_success(msg="删除成功")


@project.post("/img_list")
def get_image_list(image: ProjectImagePager, session: Session = Depends(get_db)):
    """
    Return a project's images.

    When both ``pagerNum`` and ``pagerSize`` are None the full list for
    ``image.project_id`` / ``image.img_type`` is returned; otherwise a page.

    :param image: query/paging parameters
    :param session: database session
    :return: success response (list) or paged success response
    """
    if image.pagerNum is None and image.pagerSize is None:
        image_list = pimc.get_image_list(image.project_id, image.img_type, session)
        result = jsonable_encoder(image_list)
        return rc.response_success(data=result)
    pager = pimc.get_image_pager2(image, session)
    return rc.response_success_pager(pager)


@project.post("/save_img_label")
def save_img_label(img_leafer_label: ProjectImgLeaferLabel, session: Session = Depends(get_db)):
    """
    Save the labelled selection boxes of an image.

    :param img_leafer_label: annotation payload
    :param session: database session
    :return: success response
    """
    ps.save_img_label(img_leafer_label, session)
    return rc.response_success(msg="保存成功")


@project.get("/get_img_leafer/{image_id}")
def get_img_leafer(image_id: int, session: Session = Depends(get_db)):
    """
    Return the leafer (annotation) data of an image.

    :param image_id: image id
    :param session: database session
    :return: success response; ``data`` is the ``'leafer'`` entry when present,
             otherwise empty
    """
    img_leafer_out = ps.get_img_leafer(image_id, session)
    if img_leafer_out is None:
        return rc.response_success()
    return rc.response_success(data=img_leafer_out['leafer'])


# Plain ``def`` (not ``async def``): every call below is synchronous and
# blocking, so FastAPI should run this in its threadpool instead of on the
# event loop.
@project.post("/run_train")
def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)):
    """
    Validate a project and start training it in a background thread.

    Training is refused when the project is missing or already training
    (``project_status == '1'``). It is also refused when there are fewer than
    10 training images or fewer than 5 validation images, or when any image
    is unlabelled.

    :param train_in: training parameters (project id, epochs, patience, weights id, ...)
    :param session: database session
    :return: success response once the training thread has been started,
             otherwise an error response
    """
    project_id = train_in.project_id
    project_info = pic.get_project_by_id(project_id, session)
    if project_info is None:
        return rc.response_error("项目查询错误")
    if project_info.project_status == '1':
        return rc.response_error("项目当前存在训练进程,请稍后再试")
    train_img_count = pimc.get_image_count(project_id, 'train', session)
    if train_img_count == 0:
        return rc.response_error("请先上传训练图片")
    if train_img_count < 10:
        return rc.response_error("训练图片少于10张,请继续上传训练图片")
    val_img_count = pimc.get_image_count(project_id, 'val', session)
    if val_img_count == 0:
        return rc.response_error("请先上传验证图片")
    if val_img_count < 5:
        return rc.response_error("验证图片少于5张,请继续上传验证图片")
    train_label_count = pimc.check_image_label(project_id, 'train', session)
    if train_label_count > 0:
        return rc.response_error("训练图片中存在未标注的图片")
    val_label_count = pimc.check_image_label(project_id, 'val', session)
    if val_label_count > 0:
        return rc.response_error("验证图片中存在未标注的图片")
    # ``project_dir`` avoids shadowing the module-level ``project`` router.
    data, project_dir, name = ps.run_train_yolo(project_info, train_in, session)
    # NOTE(review): ``session`` is the request-scoped session from
    # ``Depends(get_db)``. If get_db closes it after the response is sent,
    # the training thread keeps using a closed session, and SQLAlchemy
    # sessions are not thread-safe. Consider opening a dedicated session
    # inside the thread — confirm against get_db.
    thread_train = threading.Thread(
        target=run_event_loop,
        args=(data, project_dir, name, train_in, project_id, session),
    )
    thread_train.start()
    return rc.response_success(msg="执行成功")


def run_event_loop(data: str, project: str, name: str, train_in: ProjectTrainIn,
                   project_id: int, session: Session):
    """
    Run ``ps.run_commend`` to completion on a fresh event loop.

    Intended as a thread target. It creates a new loop, installs it for the
    current thread and runs the training coroutine. The loop is always
    closed, even if the coroutine raises.

    :param data: first value returned by ``ps.run_train_yolo``
    :param project: second value returned by ``ps.run_train_yolo``
    :param name: third value returned by ``ps.run_train_yolo``
    :param train_in: training parameters (epochs, patience and weights_id are read)
    :param project_id: project id
    :param session: database session forwarded to ``ps.run_commend``
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            ps.run_commend(data, project, name, train_in.epochs, train_in.patience,
                           train_in.weights_id, project_id, session)
        )
    finally:
        # Close even on failure so the loop's resources are not leaked.
        loop.close()


@project.get("/get_train_list/{project_id}")
def get_train_list(project_id: int, session: Session = Depends(get_db)):
    """
    Return the training runs of a project.

    :param project_id: project id
    :param session: database session
    :return: success response with the JSON-encoded training list
    """
    train_list = ptnc.get_train_list(project_id, session)
    result = jsonable_encoder(train_list)
    return rc.response_success(data=result)


@project.get("/get_train_report/{train_id}")
def get_train_report(train_id: int, session: Session = Depends(get_db)):
    """
    Return the report of one training run.

    :param train_id: training run id
    :param session: database session
    :return: success response with the report, or an error when none is found
    """
    result_row = ps.get_train_result(train_id, session)
    if result_row is None:
        return rc.response_error("查询失败")
    return rc.response_success(data=result_row)