完成项目训练模块的接口测试
This commit is contained in:
@ -112,6 +112,12 @@ class ProjectInfoDal(DalBase):
|
||||
await self.flush(obj)
|
||||
return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)
|
||||
|
||||
async def update_version(self, data_id):
|
||||
proj = await self.get_data(data_id)
|
||||
if proj:
|
||||
proj.train_version = proj.train_version + 1
|
||||
await self.put_data(data_id=data_id, data=proj)
|
||||
|
||||
|
||||
class ProjectImageDal(DalBase):
|
||||
"""
|
||||
@ -246,22 +252,26 @@ class ProjectImageDal(DalBase):
|
||||
# 2 主查询
|
||||
query = (
|
||||
select(
|
||||
models.ProjectImage,
|
||||
models.ProjectImage.id,
|
||||
func.ifnull(subquery.c.label_count, 0).label('label_count')
|
||||
)
|
||||
.outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
|
||||
)
|
||||
train_count_sql = await self.filter_core(
|
||||
v_start_sql=query,
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'],
|
||||
v_where=[models.ProjectImage.project_id == proj_id,
|
||||
models.ProjectImage.img_type == 'train',
|
||||
subquery.c.label_count == 0],
|
||||
v_return_sql=True)
|
||||
train_count = await self.get_count(train_count_sql)
|
||||
train_count = await self.get_count_sql(train_count_sql)
|
||||
|
||||
val_count_sql = await self.filter_core(
|
||||
v_start_sql=query,
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'],
|
||||
v_where=[models.ProjectImage.project_id == proj_id,
|
||||
models.ProjectImage.img_type == 'val',
|
||||
subquery.c.label_count == 0],
|
||||
v_return_sql=True)
|
||||
val_count = await self.get_count(val_count_sql)
|
||||
val_count = await self.get_count_sql(val_count_sql)
|
||||
|
||||
return train_count, val_count
|
||||
|
||||
@ -295,14 +305,15 @@ class ProjectLabelDal(DalBase):
|
||||
async def get_label_for_train(self, project_id: int):
|
||||
id_list = []
|
||||
name_list = []
|
||||
label_list = self.get_datas(
|
||||
label_list = await self.get_datas(
|
||||
v_where=[self.model.project_id == project_id],
|
||||
limit=0,
|
||||
v_order='asc',
|
||||
v_order_field='id',
|
||||
v_return_count=False)
|
||||
for label in label_list:
|
||||
id_list.append(label.id)
|
||||
name_list.append(label.label_name)
|
||||
id_list.append(label['id'])
|
||||
name_list.append(label['label_name'])
|
||||
return id_list, name_list
|
||||
|
||||
|
||||
@ -326,10 +337,12 @@ class ProjectImgLabelDal(DalBase):
|
||||
|
||||
async def get_img_label_list(self, image_id: int):
|
||||
return await self.get_datas(
|
||||
limit=0,
|
||||
v_return_count=False,
|
||||
v_where=[self.model.image_id == image_id],
|
||||
v_order="asc",
|
||||
v_order_field="id")
|
||||
v_order_field="id",
|
||||
v_return_objs=True)
|
||||
|
||||
async def del_img_label(self, label_ids: list[int]):
|
||||
img_labels = self.get_datas(v_where=[self.model.label_id.in_(label_ids)])
|
||||
|
Reference in New Issue
Block a user