from typing import List

from fastapi import APIRouter, Depends, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session

from app.common import reponse_code as rc
from app.model.crud import project_detect_crud as pdc
from app.service import project_detect_service as pds
from app.model.crud.project_train_crud import get_train
from app.model.schemas.project_detect_schemas import (
    ProjectDetectPager,
    ProjectDetectIn,
    ProjectDetectImgPager,
    ProjectDetectLogIn,
    ProjectDetectLogPager,
)
from app.db.db_session import get_db

# Router for inference ("detect") set endpoints. Handlers below avoid binding a
# local variable named ``detect`` so this module-level name is never shadowed.
detect = APIRouter()


@detect.post("/detect_pager")
def get_pager(detect_pager: ProjectDetectPager, session: Session = Depends(get_db)):
    """
    Return one page of inference sets.

    :param detect_pager: paging/filter parameters forwarded to ``pdc.get_detect_pager``
    :param session: database session injected via ``get_db``
    :return: paged success response built by ``rc.response_success_pager``
    """
    pager = pdc.get_detect_pager(detect_pager, session)
    return rc.response_success_pager(pager)


@detect.get("/detect_list/{project_id}")
def get_detect_list(project_id: int, session: Session = Depends(get_db)):
    """
    Return every inference set belonging to a project.

    :param project_id: id of the project whose inference sets are listed
    :param session: database session injected via ``get_db``
    :return: success response whose data is the result of ``pdc.get_detect_list``
    """
    return rc.response_success(data=pdc.get_detect_list(project_id, session))


@detect.post("/add_detect")
def add_detect(detect_in: ProjectDetectIn, session: Session = Depends(get_db)):
    """
    Create a new inference set.

    :param detect_in: payload describing the inference set to create
    :param session: database session injected via ``get_db``
    :return: success response whose data is the id of the created record
    """
    new_detect = pds.add_detect(detect_in, session)
    return rc.response_success(msg="新增成功", data=new_detect.id)


@detect.post("/get_img_list")
def get_img_list(detect_img_pager: ProjectDetectImgPager, session: Session = Depends(get_db)):
    """
    List the images stored in an inference set.

    When both ``pagerNum`` and ``pagerSize`` are None the full list is returned;
    otherwise a single page is returned.

    :param detect_img_pager: inference-set id plus optional paging parameters
    :param session: database session injected via ``get_db``
    :return: success response with the full image list, or a paged success response
    """
    if detect_img_pager.pagerNum is None and detect_img_pager.pagerSize is None:
        img_list = jsonable_encoder(pdc.get_img_list(detect_img_pager.detect_id, session))
        return rc.response_success(data=img_list)
    pager = pdc.get_img_pager(detect_img_pager, session)
    return rc.response_success_pager(pager)


@detect.post("/upload_detect_img")
def upload_detect_img(
    detect_id: int = Form(...),
    files: List[UploadFile] = File(...),
    session: Session = Depends(get_db),
):
    """
    Upload images into an existing inference set.

    The upload is rejected if the set does not exist or if
    ``pds.check_image_name`` reports a file name that already exists.

    :param detect_id: id of the target inference set (multipart form field)
    :param files: uploaded image files (multipart form field)
    :param session: database session injected via ``get_db``
    :return: success response, or an error response naming the problem
    """
    detect_out = pdc.get_detect_by_id(detect_id, session)
    if detect_out is None:
        return rc.response_error(msg="训练集合查询失败,请刷新后再试")
    is_check, file_name = pds.check_image_name(detect_id, files, session)
    if not is_check:
        return rc.response_error(msg="存在重名的图片文件:" + file_name)
    pds.upload_detect_imgs(detect_out, files, session)
    # Pass the message by keyword, like every other success message in this
    # module; a positional argument could land in a field other than ``msg``.
    return rc.response_success(msg="上传成功")


@detect.get("/del_detect_img/{detect_img_id}")
def del_detect_img(detect_img_id: int, session: Session = Depends(get_db)):
    """
    Delete a single image from an inference set.

    NOTE(review): deletion is exposed via GET, which caches/prefetchers may
    trigger; kept as-is because changing the method would break existing callers.

    :param detect_img_id: id of the image record to delete
    :param session: database session injected via ``get_db``
    :return: success response if ``pds.del_detect_img`` returned > 0, else an error response
    """
    result = pds.del_detect_img(detect_img_id, session)
    if result > 0:
        return rc.response_success(msg="删除成功")
    return rc.response_error(msg="删除失败")


@detect.post("/run_detect_yolo")
def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depends(get_db)):
    """
    Start a YOLO inference run on an inference set using a training run's weights.

    The run's output, produced by ``pds.run_commend``, is streamed back as plain text.

    :param detect_log_in: ids of the inference set and the training run to use
    :param session: database session injected via ``get_db``
    :return: a ``text/plain`` StreamingResponse, or an error response if the
             inference set or the training weights cannot be found
    """
    detect_out = pdc.get_detect_by_id(detect_log_in.detect_id, session)
    if detect_out is None:
        return rc.response_error(msg="训练集合不存在")
    train = get_train(detect_log_in.train_id, session)
    if train is None:
        return rc.response_error(msg="训练权重不存在")
    detect_log = pds.run_detect_yolo(detect_log_in, detect_out, train, session)
    # NOTE(review): ``session`` is request-scoped (from ``get_db``) but is used
    # by the generator after this function returns. If ``get_db`` is a yield
    # dependency that closes the session, some FastAPI versions run that
    # teardown before a StreamingResponse body is fully consumed. Confirm the
    # session stays usable inside ``pds.run_commend`` for the installed version.
    stream = pds.run_commend(
        detect_log.pt_url,
        detect_log.folder_url,
        detect_log.detect_folder_url,
        detect_log.detect_version,
        detect_log.id,
        detect_log.detect_id,
        session,
    )
    return StreamingResponse(stream, media_type="text/plain")


@detect.post("/get_log_pager")
def get_log_pager(detect_log_pager: ProjectDetectLogPager, session: Session = Depends(get_db)):
    """
    Return inference-run records for an inference set.

    :param detect_log_pager: paging/filter parameters forwarded to ``pdc.get_log_pager``
    :param session: database session injected via ``get_db``
    :return: success response whose data is the JSON-encoded query result
    """
    result = jsonable_encoder(pdc.get_log_pager(detect_log_pager, session))
    return rc.response_success(data=result)


@detect.get("/get_log_imgs/{log_id}")
def get_log_imgs(log_id: int, session: Session = Depends(get_db)):
    """
    Return the result images produced by a single inference run.

    :param log_id: id of the inference-run record
    :param session: database session injected via ``get_db``
    :return: success response whose data is the JSON-encoded image list
    """
    result = jsonable_encoder(pdc.get_log_imgs(log_id, session))
    return rc.response_success(data=result)