diff --git a/app/api/business/project_detect_api.py b/app/api/business/project_detect_api.py index 203c098..8cf3b6e 100644 --- a/app/api/business/project_detect_api.py +++ b/app/api/business/project_detect_api.py @@ -117,6 +117,9 @@ def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depend train = get_train(detect_log_in.train_id, session) if train is None: return rc.response_error("训练权重不存在") + detect_img_count = pdc.check_detect_img(detect_log_in.detect_id, session) + if detect_img_count == 0: + return rc.response_error("推理集合中没有图片,请先到推理集合中上传图片") detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session) thread_train = threading.Thread(target=run_event_loop, args=(detect_log.pt_url, detect_log.folder_url, diff --git a/app/api/business/project_train_api.py b/app/api/business/project_train_api.py index be3d89b..b651ecd 100644 --- a/app/api/business/project_train_api.py +++ b/app/api/business/project_train_api.py @@ -243,6 +243,12 @@ async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db) return rc.response_error("请先上传验证图片") if val_img_count < 5: return rc.response_error("验证图片少于5张,请继续上传验证图片") + train_label_count = pimc.check_image_label(project_id, 'train', session) + if train_label_count > 0: + return rc.response_error("训练图片中存在未标注的图片") + val_label_count = pimc.check_image_label(project_id, 'val', session) + if val_label_count > 0: + return rc.response_error("验证图片中存在未标注的图片") data, project, name = ps.run_train_yolo(project_info, train_in, session) thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in, project_id, session,)) diff --git a/app/model/crud/project_detect_crud.py b/app/model/crud/project_detect_crud.py index 27af674..d38d6c0 100644 --- a/app/model/crud/project_detect_crud.py +++ b/app/model/crud/project_detect_crud.py @@ -110,6 +110,17 @@ def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session): session.commit() +def check_detect_img(detect_id: int, session: Session): + """ + 查询推理集合中图片的数量 + :param detect_id: + :param session: + :return: + """ + query = session.query(ProjectDetectImg).filter_by(detect_id=detect_id) + return query.count() + + def get_img_list(detect_id: int, session: Session): """ 获取训练集合中的图片列表 diff --git a/app/model/crud/project_image_crud.py b/app/model/crud/project_image_crud.py index e8ffbfd..ae115df 100644 --- a/app/model/crud/project_image_crud.py +++ b/app/model/crud/project_image_crud.py @@ -45,6 +45,36 @@ def get_image_pager2(image: ProjectImage, session: Session): return pager +def check_image_label(project_id: int, img_type: str, session: Session): + """ + 检验是否所有图片都标注 + :param project_id: + :param img_type: + :param session: + :return: + """ + # 1 子查询 + subquery = ( + session.query( + ProjectImgLabel.image_id, + func.ifnull(func.count(ProjectImgLabel.id), 0).label('label_count') + ) + .group_by(ProjectImgLabel.image_id) + .subquery() + ) + # 2 主查询 + query = ( + session.query( + piModel, + func.ifnull(subquery.c.label_count, 0).label('label_count') + ) + .outerjoin(subquery, piModel.id == subquery.c.image_id) + ) + query = query.filter(piModel.project_id == project_id)\ + .filter(piModel.img_type == img_type).filter(subquery.c.label_count.is_(None)) + return query.count() + + def check_img_name(project_id: int, img_type: str, file_name: str, session: Session): """ 根据项目id和文件名称进行查重 diff --git a/app/service/project_detect_service.py b/app/service/project_detect_service.py index e6cf170..6177263 100644 --- a/app/service/project_detect_service.py +++ b/app/service/project_detect_service.py @@ -120,7 +120,7 @@ def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: return detect_log -def run_commend(weights: str, source: str, project: str, name: str, +async def run_commend(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, session: Session): """ 执行yolov5的推理