From 76f4a5ecd965c78122ffb9f861824adb2d6c35e4 Mon Sep 17 00:00:00 2001 From: sunyugang Date: Wed, 12 Mar 2025 09:39:50 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AF=B9=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E5=92=8C=E6=8E=A8=E7=90=86=E4=B8=AD=E5=9B=BE=E7=89=87=E6=83=85?= =?UTF-8?q?=E5=86=B5=E7=9A=84=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/business/project_detect_api.py | 3 +++ app/api/business/project_train_api.py | 6 ++++++ app/model/crud/project_detect_crud.py | 11 ++++++++++ app/model/crud/project_image_crud.py | 30 ++++++++++++++++++++++++++ app/service/project_detect_service.py | 2 +- 5 files changed, 51 insertions(+), 1 deletion(-) diff --git a/app/api/business/project_detect_api.py b/app/api/business/project_detect_api.py index 203c098..8cf3b6e 100644 --- a/app/api/business/project_detect_api.py +++ b/app/api/business/project_detect_api.py @@ -117,6 +117,9 @@ def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depend train = get_train(detect_log_in.train_id, session) if train is None: return rc.response_error("训练权重不存在") + detect_img_count = pdc.check_detect_img(detect_log_in.detect_id, session) + if detect_img_count == 0: + return rc.response_error("推理集合中没有图片,请先到推理集合中上传图片") detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session) thread_train = threading.Thread(target=run_event_loop, args=(detect_log.pt_url, detect_log.folder_url, diff --git a/app/api/business/project_train_api.py b/app/api/business/project_train_api.py index be3d89b..b651ecd 100644 --- a/app/api/business/project_train_api.py +++ b/app/api/business/project_train_api.py @@ -243,6 +243,12 @@ async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db) return rc.response_error("请先上传验证图片") if val_img_count < 5: return rc.response_error("验证图片少于5张,请继续上传验证图片") + train_label_count = pimc.check_image_label(project_id, 'train', session) + if train_label_count > 0: + return rc.response_error("训练图片中存在未标注的图片") + val_label_count = pimc.check_image_label(project_id, 'val', session) + if val_label_count > 0: + return rc.response_error("验证图片中存在未标注的图片") data, project, name = ps.run_train_yolo(project_info, train_in, session) thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in, project_id, session,)) diff --git a/app/model/crud/project_detect_crud.py b/app/model/crud/project_detect_crud.py index 27af674..d38d6c0 100644 --- a/app/model/crud/project_detect_crud.py +++ b/app/model/crud/project_detect_crud.py @@ -110,6 +110,17 @@ def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session): session.commit() +def check_detect_img(detect_id: int, session: Session): + """ + 查询推理集合中图片的数量 + :param detect_id: + :param session: + :return: + """ + query = session.query(ProjectDetectImg).filter_by(detect_id=detect_id) + return query.count() + + def get_img_list(detect_id: int, session: Session): """ 获取训练集合中的图片列表 diff --git a/app/model/crud/project_image_crud.py b/app/model/crud/project_image_crud.py index e8ffbfd..ae115df 100644 --- a/app/model/crud/project_image_crud.py +++ b/app/model/crud/project_image_crud.py @@ -45,6 +45,36 @@ def get_image_pager2(image: ProjectImage, session: Session): return pager +def check_image_label(project_id: int, img_type: str, session: Session): + """ + 检验是否所有图片都标注 + :param project_id: + :param img_type: + :param session: + :return: + """ + # 1 子查询 + subquery = ( + session.query( + ProjectImgLabel.image_id, + func.ifnull(func.count(ProjectImgLabel.id), 0).label('label_count') + ) + .group_by(ProjectImgLabel.image_id) + .subquery() + ) + # 2 主查询 + query = ( + session.query( + piModel, + func.ifnull(subquery.c.label_count, 0).label('label_count') + ) + .outerjoin(subquery, piModel.id == subquery.c.image_id) + ) + query = query.filter(piModel.project_id == project_id)\ + .filter(piModel.img_type == img_type).filter(subquery.c.label_count.is_(None)) + return query.count() + + def check_img_name(project_id: int, img_type: str, file_name: str, session: Session): """ 根据项目id和文件名称进行查重 diff --git a/app/service/project_detect_service.py b/app/service/project_detect_service.py index e6cf170..6177263 100644 --- a/app/service/project_detect_service.py +++ b/app/service/project_detect_service.py @@ -120,7 +120,7 @@ def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train: return detect_log -def run_commend(weights: str, source: str, project: str, name: str, +async def run_commend(weights: str, source: str, project: str, name: str, log_id: int, detect_id: int, session: Session): """ 执行yolov5的推理