#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version        : 1.0
# @Create Time    : 2025/04/03 10:25
# @File           : crud.py
# @IDE            : PyCharm
# @desc           : Data access layer (DAL) for projects, project images, labels and image annotations.

import application.settings
from . import schemas, models, params
from apps.vadmin.auth.utils.validation.auth import Auth
# NOTE: utils.os_utils is aliased as `os`, which shadows the stdlib module name in this file.
from utils import os_utils as os, random_utils as ru

# Storage roots are taken from the config module matching the current DEBUG setting.
if application.settings.DEBUG:
    from application.config.development import datasets_url, runs_url, detect_url, yolo_url, images_url
else:
    from application.config.production import datasets_url, runs_url, detect_url, yolo_url, images_url
from typing import Any, List, Optional
from core.crud import DalBase
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, case, and_


class ProjectInfoDal(DalBase):
    """DAL for ``models.ProjectInfo`` rows (output schema: ``schemas.ProjectInfoOut``)."""

    def __init__(self, db: AsyncSession):
        super().__init__()
        self.db = db
        self.model = models.ProjectInfo
        self.schema = schemas.ProjectInfoOut

    async def get_project_pager(self, project: params.ProjectInfoParams, auth: Auth):
        """
        Paginated query of the project list.

        Each project is returned together with the number of its image rows that
        do / do not have a matching ``ProjectImgLeafer`` record.

        :param project: query parameters; ``v_order``, ``v_order_field``, ``page``
                        and ``limit`` are read here (``limit == 0`` disables paging).
        :param auth: current auth context; ``auth.dept_ids`` restricts which
                     projects are visible.
        :return: tuple ``(datas, count)``. ``datas`` is a list of dicts dumped from
                 ``schemas.ProjectInfoPagerOut``, and ``count`` is the number of
                 matching rows before pagination.
        """
        img = models.ProjectImage
        leafer = models.ProjectImgLeafer

        # Per-project annotation counts. The outer join keeps images without any
        # leafer row; those rows have a NULL leafer.id and are counted as "no mark".
        # NOTE(review): counts are per joined row. If one image can have several
        # ProjectImgLeafer rows, mark_count counts rows rather than images. Confirm.
        subquery = (
            select(
                img.project_id,
                func.sum(case((leafer.id.is_(None), 1), else_=0)).label('no_mark_count'),
                func.sum(case((leafer.id.isnot(None), 1), else_=0)).label('mark_count'),
            )
            .outerjoin(leafer, img.id == leafer.image_id)
            .group_by(img.project_id)
            .subquery()
        )

        # Left join so that projects without images still appear; ifnull turns
        # their missing counts into 0.
        full_query = (
            select(
                models.ProjectInfo,
                func.ifnull(subquery.c.mark_count, 0).label("mark_count"),
                func.ifnull(subquery.c.no_mark_count, 0).label("no_mark_count"),
            )
            .select_from(models.ProjectInfo)
            .join(subquery, models.ProjectInfo.id == subquery.c.project_id, isouter=True)
        )

        v_where = [models.ProjectInfo.is_delete.is_(False)]
        if '*' in auth.dept_ids:
            # '*' is treated as unrestricted: only rows with a NULL dept_id are excluded.
            v_where.append(models.ProjectInfo.dept_id.isnot(None))
        else:
            v_where.append(models.ProjectInfo.dept_id.in_(auth.dept_ids))

        sql = await self.filter_core(
            v_start_sql=full_query,
            v_where=v_where,
            v_return_sql=True,
            v_order=project.v_order,
            v_order_field=project.v_order_field,
        )
        # Total is computed before offset/limit are applied.
        count = await self.get_count_sql(sql)
        if project.limit != 0:
            sql = sql.offset((project.page - 1) * project.limit).limit(project.limit)
        queryset = await self.db.execute(sql)

        datas = []
        # Each row is (ProjectInfo, mark_count, no_mark_count), in the column order
        # selected above.
        for row in queryset.all():
            data = schemas.ProjectInfoPagerOut.model_validate(row[0])
            # int() normalises the aggregate values (drivers may return Decimal for SUM).
            data.mark_count = int(row[1])
            data.no_mark_count = int(row[2])
            datas.append(data.model_dump())
        return datas, count

    async def check_name(self, project_name: str) -> bool:
        """
        Check whether a non-deleted project with the given name already exists.

        :param project_name: project name to look up.
        :return: True if at least one matching, non-deleted project exists.
        """
        count = await self.get_count(
            v_where=[
                models.ProjectInfo.project_name == project_name,
                # Must be a SQL expression. `column is False` would be a Python
                # identity test evaluating to the plain boolean False, not a filter.
                models.ProjectInfo.is_delete.is_(False),
            ]
        )
        return count > 0

    async def add_project(
            self,
            project: schemas.ProjectInfoIn,
            auth: Auth
    ) -> Any:
        """
        Create a new project for the current user and its on-disk folders.

        The project gets a random 6-character ``project_no``, status ``"0"`` and
        ``train_version`` 0. A folder named after ``project_no`` is created under
        both ``datasets_url`` and ``runs_url`` before the row is flushed.

        :param project: input payload; all of its fields are copied onto the model.
        :param auth: current auth context; supplies ``user.id`` and ``dept_ids``.
        :return: the new project serialized via ``schemas.ProjectInfoOut``.
        """
        obj = self.model(**project.model_dump())
        obj.user_id = auth.user.id
        obj.project_no = ru.random_str(6)
        obj.project_status = "0"
        obj.train_version = 0
        # '*' in dept_ids maps to dept_id 0; otherwise the first listed dept is used.
        # NOTE(review): auth.dept_ids[0] raises IndexError when dept_ids is empty.
        # Confirm callers always supply at least one dept.
        obj.dept_id = 0 if '*' in auth.dept_ids else auth.dept_ids[0]
        os.create_folder(datasets_url, obj.project_no)
        os.create_folder(runs_url, obj.project_no)
        await self.flush(obj)
        return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)


class ProjectImageDal(DalBase):
    """DAL for ``models.ProjectImage`` rows (output schema: ``schemas.ProjectImageSimpleOut``)."""

    def __init__(self, db: AsyncSession):
        super().__init__()
        self.db = db
        self.model = models.ProjectImage
        self.schema = schemas.ProjectImageSimpleOut


class ProjectLabelDal(DalBase):
    """DAL for ``models.ProjectLabel`` rows."""

    def __init__(self, db: AsyncSession):
        super().__init__()
        self.db = db
        self.model = models.ProjectLabel
        # NOTE(review): the other DALs use a *SimpleOut schema here. Confirm that
        # schemas.ProjectLabel is intended.
        self.schema = schemas.ProjectLabel

    async def check_label_name(
            self,
            name: str,
            pro_id: int,
            label_id: Optional[int] = None
    ) -> bool:
        """
        Check whether a label with ``name`` already exists in project ``pro_id``.

        :param name: label name to look up.
        :param pro_id: id of the project the label belongs to.
        :param label_id: when given, the label with this id is excluded. This lets
                         an update keep its own name.
        :return: True if a conflicting label exists.
        """
        wheres = [
            models.ProjectLabel.project_id == pro_id,
            models.ProjectLabel.label_name == name,
        ]
        if label_id is not None:
            wheres.append(models.ProjectLabel.id != label_id)
        count = await self.get_count(v_where=wheres)
        return count > 0


class ProjectImgLabelDal(DalBase):
    """DAL for ``models.ProjectImgLabel`` rows (output schema: ``schemas.ProjectImgLabelSimpleOut``)."""

    def __init__(self, db: AsyncSession):
        super().__init__()
        self.db = db
        self.model = models.ProjectImgLabel
        self.schema = schemas.ProjectImgLabelSimpleOut


class ProjectImgLeaferDal(DalBase):
    """DAL for ``models.ProjectImgLeafer`` rows (output schema: ``schemas.ProjectImgLeaferSimpleOut``)."""

    def __init__(self, db: AsyncSession):
        super().__init__()
        self.db = db
        self.model = models.ProjectImgLeafer
        self.schema = schemas.ProjectImgLeaferSimpleOut