完成训练模块的转移
This commit is contained in:
@ -5,23 +5,19 @@
|
||||
# @File : crud.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 数据访问层
|
||||
import application.settings
|
||||
from . import schemas, models, params
|
||||
from apps.vadmin.auth.utils.validation.auth import Auth
|
||||
from utils import os_utils as os, random_utils as ru
|
||||
from utils.huawei_obs import ObsClient
|
||||
from utils import status
|
||||
from core.exception import CustomException
|
||||
if application.settings.DEBUG:
|
||||
from application.config.development import datasets_url, runs_url, images_url
|
||||
else:
|
||||
from application.config.production import datasets_url, runs_url, images_url
|
||||
from application.settings import datasets_url, runs_url, images_url
|
||||
|
||||
from typing import Any, List
|
||||
from core.crud import DalBase
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, case, and_
|
||||
from sqlalchemy import select, func, case
|
||||
|
||||
|
||||
class ProjectInfoDal(DalBase):
|
||||
@ -96,6 +92,9 @@ class ProjectInfoDal(DalBase):
|
||||
project: schemas.ProjectInfoIn,
|
||||
auth: Auth
|
||||
) -> Any:
|
||||
"""
|
||||
新建项目
|
||||
"""
|
||||
obj = self.model(**project.model_dump())
|
||||
obj.user_id = auth.user.id
|
||||
obj.project_no = ru.random_str(6)
|
||||
@ -106,7 +105,9 @@ class ProjectInfoDal(DalBase):
|
||||
obj.dept_id = 0
|
||||
else:
|
||||
obj.dept_id = auth.dept_ids[0]
|
||||
# 新建数据集文件夹
|
||||
os.create_folder(datasets_url, obj.project_no)
|
||||
# 新建训练文件夹
|
||||
os.create_folder(runs_url, obj.project_no)
|
||||
await self.flush(obj)
|
||||
return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)
|
||||
@ -214,6 +215,55 @@ class ProjectImageDal(DalBase):
|
||||
ObsClient.del_objects(object_keys)
|
||||
await self.delete_datas(ids)
|
||||
|
||||
async def get_img_count(
|
||||
self,
|
||||
proj_id: int) -> int:
|
||||
"""
|
||||
查询图片数量
|
||||
"""
|
||||
train_count = await self.get_count(
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'])
|
||||
val_count = await self.get_count(
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'])
|
||||
return train_count, val_count
|
||||
|
||||
async def check_image_label(
|
||||
self,
|
||||
proj_id: int) -> int:
|
||||
"""
|
||||
查询图片未标注数量
|
||||
"""
|
||||
# 1 子查询
|
||||
subquery = (
|
||||
select(
|
||||
models.ProjectImgLabel.image_id,
|
||||
func.ifnull(func.count(models.ProjectImgLabel.id), 0).label('label_count')
|
||||
)
|
||||
.group_by(models.ProjectImgLabel.image_id)
|
||||
.subquery()
|
||||
)
|
||||
# 2 主查询
|
||||
query = (
|
||||
select(
|
||||
models.ProjectImage,
|
||||
func.ifnull(subquery.c.label_count, 0).label('label_count')
|
||||
)
|
||||
.outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
|
||||
)
|
||||
train_count_sql = await self.filter_core(
|
||||
v_start_sql=query,
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'],
|
||||
v_return_sql=True)
|
||||
train_count = await self.get_count(train_count_sql)
|
||||
|
||||
val_count_sql = await self.filter_core(
|
||||
v_start_sql=query,
|
||||
v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'],
|
||||
v_return_sql=True)
|
||||
val_count = await self.get_count(val_count_sql)
|
||||
|
||||
return train_count, val_count
|
||||
|
||||
|
||||
class ProjectLabelDal(DalBase):
|
||||
"""
|
||||
@ -233,14 +283,27 @@ class ProjectLabelDal(DalBase):
|
||||
label_id: int = None
|
||||
):
|
||||
wheres = [
|
||||
models.ProjectLabel.project_id == pro_id,
|
||||
models.ProjectLabel.label_name == name
|
||||
self.model.project_id == pro_id,
|
||||
self.model.label_name == name
|
||||
]
|
||||
if label_id:
|
||||
wheres.append(models.ProjectLabel.id != label_id)
|
||||
wheres.append(self.model.id != label_id)
|
||||
count = await self.get_count(v_where=wheres)
|
||||
return count > 0
|
||||
|
||||
async def get_label_for_train(self, project_id: int):
|
||||
id_list = []
|
||||
name_list = []
|
||||
label_list = self.get_datas(
|
||||
v_where=[self.model.project_id == project_id],
|
||||
v_order='asc',
|
||||
v_order_field='id',
|
||||
v_return_count=False)
|
||||
for label in label_list:
|
||||
id_list.append(label.id)
|
||||
name_list.append(label.label_name)
|
||||
return id_list, name_list
|
||||
|
||||
|
||||
class ProjectImgLabelDal(DalBase):
|
||||
"""
|
||||
@ -260,6 +323,13 @@ class ProjectImgLabelDal(DalBase):
|
||||
img.image_id = image_id
|
||||
await self.create_datas(img_labels)
|
||||
|
||||
async def get_img_label_list(self, image_id: int):
|
||||
return await self.get_datas(
|
||||
v_return_count=False,
|
||||
v_where=[self.model.image_id == image_id],
|
||||
v_order="asc",
|
||||
v_order_field="id")
|
||||
|
||||
|
||||
class ProjectImgLeaferDal(DalBase):
|
||||
"""
|
||||
|
@ -15,14 +15,10 @@ class ProjectInfoParams(QueryParams):
|
||||
self,
|
||||
project_name: str | None = Query(None, title="项目名称"),
|
||||
type_code: str | None = Query(None, title="项目类别"),
|
||||
dept_id: str | None = Query(None, title="部门id"),
|
||||
user_id: str | None = Query(None, title="用户id"),
|
||||
params: Paging = Depends()
|
||||
):
|
||||
super().__init__(params)
|
||||
self.project_name = ("like", project_name)
|
||||
self.type_code = type_code
|
||||
self.dept_id = dept_id
|
||||
self.user_id = user_id
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user