增加对训练和推理中图片情况的判断

This commit is contained in:
2025-03-12 09:39:50 +08:00
parent f49a6caf10
commit 76f4a5ecd9
5 changed files with 51 additions and 1 deletions

View File

@ -117,6 +117,9 @@ def run_detect_yolo(detect_log_in: ProjectDetectLogIn, session: Session = Depend
train = get_train(detect_log_in.train_id, session)
if train is None:
return rc.response_error("训练权重不存在")
detect_img_count = pdc.check_detect_img(detect_log_in.detect_id, session)
if detect_img_count == 0:
return rc.response_error("推理集合中没有图片,请先到推理集合中上传图片")
detect_log = pds.run_detect_yolo(detect_log_in, detect, train, session)
thread_train = threading.Thread(target=run_event_loop, args=(detect_log.pt_url,
detect_log.folder_url,

View File

@ -243,6 +243,12 @@ async def run_train(train_in: ProjectTrainIn, session: Session = Depends(get_db)
return rc.response_error("请先上传验证图片")
if val_img_count < 5:
return rc.response_error("验证图片少于5张请继续上传验证图片")
train_label_count = pimc.check_image_label(project_id, 'train', session)
if train_label_count > 0:
return rc.response_error("训练图片中存在未标注的图片")
val_label_count = pimc.check_image_label(project_id, 'val', session)
if val_label_count > 0:
return rc.response_error("验证图片中存在未标注的图片")
data, project, name = ps.run_train_yolo(project_info, train_in, session)
thread_train = threading.Thread(target=run_event_loop, args=(data, project, name, train_in,
project_id, session,))

View File

@ -110,6 +110,17 @@ def add_detect_imgs(detect_imgs: List[ProjectDetectImg], session: Session):
session.commit()
def check_detect_img(detect_id: int, session: Session):
"""
查询推理集合中图片的数量
:param detect_id:
:param session:
:return:
"""
query = session.query(ProjectDetectImg).filter_by(detect_id=detect_id)
return query.count()
def get_img_list(detect_id: int, session: Session):
"""
获取训练集合中的图片列表

View File

@ -45,6 +45,36 @@ def get_image_pager2(image: ProjectImage, session: Session):
return pager
def check_image_label(project_id: int, img_type: str, session: Session):
"""
检验是否所有图片都标注
:param project_id:
:param img_type:
:param session:
:return:
"""
# 1 子查询
subquery = (
session.query(
ProjectImgLabel.image_id,
func.ifnull(func.count(ProjectImgLabel.id), 0).label('label_count')
)
.group_by(ProjectImgLabel.image_id)
.subquery()
)
# 2 主查询
query = (
session.query(
piModel,
func.ifnull(subquery.c.label_count, 0).label('label_count')
)
.outerjoin(subquery, piModel.id == subquery.c.image_id)
)
query = query.filter(piModel.project_id == project_id)\
.filter(piModel.img_type == img_type).filter(subquery.c.label_count.is_(None))
return query.count()
def check_img_name(project_id: int, img_type: str, file_name: str, session: Session):
"""
根据项目id和文件名称进行查重

View File

@ -120,7 +120,7 @@ def run_detect_yolo(detect_in: ProjectDetectLogIn, detect: ProjectDetect, train:
return detect_log
def run_commend(weights: str, source: str, project: str, name: str,
async def run_commend(weights: str, source: str, project: str, name: str,
log_id: int, detect_id: int, session: Session):
"""
执行yolov5的推理