#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version        : 1.0
# @Create Time    : 2025/04/03 10:25
# @File           : crud.py
# @IDE            : PyCharm
# @desc           : 数据访问层
from . import schemas, models, params
from apps.vadmin.auth.utils.validation.auth import Auth
from utils import os_utils as os, random_utils as ru
from utils.huawei_obs import MyObs
from utils import status
from core.exception import CustomException
from application.settings import datasets_url, runs_url, images_url

from typing import Any, List, Optional
from core.crud import DalBase
from fastapi import UploadFile
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, case


class ProjectInfoDal(DalBase):
    """
    Data-access layer for project information (``models.ProjectInfo``).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectInfoDal, self).__init__()
        self.db = db
        self.model = models.ProjectInfo
        self.schema = schemas.ProjectInfoOut

    async def get_project_pager(self, project: params.ProjectInfoParams, auth: Auth):
        """
        Return one page of non-deleted projects visible to the caller.

        Each project is enriched with ``mark_count`` / ``no_mark_count``, computed
        over the rows of an image LEFT JOIN ``ProjectImgLeafer``: rows with a
        leafer record count as marked, rows without one as not marked.

        :param project: paging/ordering parameters (``page``, ``limit``,
            ``v_order``, ``v_order_field``); ``limit == 0`` disables paging.
        :param auth: current auth context; ``'*'`` in ``auth.dept_ids`` lifts
            the department restriction.
        :return: ``(datas, count)`` -- a list of dumped ``ProjectInfoPagerOut``
            dicts and the number of matching projects before paging.
        """
        # Per-project annotation statistics. The outer join keeps images that
        # have no leafer row so they are counted in no_mark_count.
        subquery = (
            select(
                models.ProjectImage.project_id,
                func.sum(case((models.ProjectImgLeafer.id.is_(None), 1), else_=0)).label('no_mark_count'),
                func.sum(case((models.ProjectImgLeafer.id.isnot(None), 1), else_=0)).label('mark_count')
            )
            .outerjoin(models.ProjectImgLeafer, models.ProjectImage.id == models.ProjectImgLeafer.image_id)
            .group_by(models.ProjectImage.project_id)
            .subquery()
        )

        # Projects without any image get NULL counts from the outer join;
        # ifnull maps those to 0.
        full_query = select(
            models.ProjectInfo,
            func.ifnull(subquery.c.mark_count, 0).label("mark_count"),
            func.ifnull(subquery.c.no_mark_count, 0).label("no_mark_count")
        ).select_from(models.ProjectInfo).join(
            subquery, models.ProjectInfo.id == subquery.c.project_id, isouter=True
        )
        v_where = [models.ProjectInfo.is_delete.is_(False)]
        if '*' in auth.dept_ids:
            # Unrestricted: any project that is assigned to some department.
            v_where.append(models.ProjectInfo.dept_id.isnot(None))
        else:
            v_where.append(models.ProjectInfo.dept_id.in_(auth.dept_ids))
        sql = await self.filter_core(
            v_start_sql=full_query,
            v_where=v_where,
            v_return_sql=True,
            v_order=project.v_order,
            v_order_field=project.v_order_field
        )
        # The total is computed before OFFSET/LIMIT so it covers all matches.
        count = await self.get_count_sql(sql)
        if project.limit != 0:
            sql = sql.offset((project.page - 1) * project.limit).limit(project.limit)
        queryset = await self.db.execute(sql)
        datas = []
        for row in queryset.all():
            data = schemas.ProjectInfoPagerOut.model_validate(row[0])
            data.mark_count = int(row[1])
            data.no_mark_count = int(row[2])
            datas.append(data.model_dump())
        return datas, count

    async def check_name(self, project_name: str):
        """
        Check whether a non-deleted project already uses ``project_name``.

        :param project_name: the name to check.
        :return: True if at least one such project exists.
        """
        # ``.is_(False)`` builds a SQL condition. A Python ``is False`` test on
        # the column would evaluate to a plain ``False`` and make the WHERE
        # clause never match, so duplicates would never be reported.
        count = await self.get_count(v_where=[models.ProjectInfo.project_name == project_name,
                                              models.ProjectInfo.is_delete.is_(False)])
        return count > 0

    async def add_project(
            self,
            project: schemas.ProjectInfoIn,
            auth: Auth
    ) -> Any:
        """
        Create a new project owned by the current user.

        Besides inserting the row, this creates two folders named after the
        generated ``project_no``: one under ``datasets_url`` and one under
        ``runs_url``.

        :param project: validated creation payload.
        :param auth: current auth context; ``auth.user`` becomes the owner.
        :return: the created project serialised with ``schemas.ProjectInfoOut``.
        """
        obj = self.model(**project.model_dump())
        obj.user_id = auth.user.id
        # NOTE(review): 6 random chars; uniqueness is not checked here -- confirm
        # a DB constraint or retry exists.
        obj.project_no = ru.random_str(6)
        obj.project_status = "0"
        obj.train_version = 0
        # '*' lifts the department restriction (see get_project_pager); such
        # projects are stored under department 0.
        # NOTE(review): assumes auth.dept_ids is non-empty otherwise.
        obj.dept_id = 0 if '*' in auth.dept_ids else auth.dept_ids[0]
        # Dataset folder for the project.
        os.create_folder(datasets_url, obj.project_no)
        # Training-run folder for the project.
        os.create_folder(runs_url, obj.project_no)
        await self.flush(obj)
        return await self.out_dict(obj, None, False, schemas.ProjectInfoOut)

    async def update_version(self, data_id):
        """
        Increment the project's ``train_version`` by one.

        Nothing is updated when ``get_data`` yields a falsy value for ``data_id``.
        """
        proj = await self.get_data(data_id)
        if proj:
            proj.train_version = proj.train_version + 1
            # NOTE(review): passes the ORM object itself as ``data`` -- confirm
            # put_data accepts a model instance.
            await self.put_data(data_id=data_id, data=proj)


class ProjectImageDal(DalBase):
    """
    Data-access layer for project images (``models.ProjectImage``).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectImageDal, self).__init__()
        self.db = db
        self.model = models.ProjectImage
        self.schema = schemas.ProjectImageOut

    async def img_page(self, param: params.ProjectImageParams):
        """
        Return one page of a project's images of a given type.

        Each image is enriched with ``label_count``, the number of
        ``ProjectImgLabel`` rows attached to it (0 when it has none).

        :param param: filter and paging parameters (``project_id``, ``img_type``,
            ``page``, ``limit``, ``v_order``, ``v_order_field``);
            ``limit == 0`` disables paging.
        :return: ``(datas, count)`` -- a list of dumped ``ProjectImageOut`` dicts
            and the number of matching images before paging.
        """
        # 1. Number of label rows per image (COUNT never returns NULL).
        subquery = (
            select(
                models.ProjectImgLabel.image_id,
                func.count(models.ProjectImgLabel.id).label('label_count')
            )
            .group_by(models.ProjectImgLabel.image_id)
            .subquery()
        )
        # 2. Main query. Images without labels get NULL from the outer join,
        #    which ifnull maps to 0.
        query = (
            select(
                models.ProjectImage,
                func.ifnull(subquery.c.label_count, 0).label('label_count')
            )
            .outerjoin(subquery, models.ProjectImage.id == subquery.c.image_id)
        )
        v_where = [models.ProjectImage.project_id == param.project_id, models.ProjectImage.img_type == param.img_type]
        sql = await self.filter_core(
            v_start_sql=query,
            v_where=v_where,
            v_return_sql=True,
            v_order=param.v_order,
            v_order_field=param.v_order_field
        )
        count = await self.get_count_sql(sql)
        if param.limit != 0:
            sql = sql.offset((param.page - 1) * param.limit).limit(param.limit)
        queryset = await self.db.execute(sql)
        datas = []
        for row in queryset.all():
            data = schemas.ProjectImageOut.model_validate(row[0])
            data.label_count = int(row[1])
            datas.append(data.model_dump())
        return datas, count

    async def upload_imgs(self, files: List[UploadFile], pro: schemas.ProjectInfoOut, img_type: str) -> int:
        """
        Save uploaded images locally, push them to OBS and insert their rows.

        Each file is saved with ``os.save_images(images_url, pro.project_no, ...)``.
        It is then uploaded to OBS under the key
        ``<project_no>/<img_type>/<filename>``. The database rows are inserted
        only after every upload has succeeded.

        :param files: uploaded image files.
        :param pro: the owning project.
        :param img_type: image split, e.g. ``'train'`` or ``'val'``.
        :return: the number of images stored.
        :raises CustomException: if an OBS upload reports failure.
        """
        # NOTE(review): files saved / objects uploaded before a failure are not
        # cleaned up.
        image_models = []
        obs = MyObs()
        for file in files:
            image = models.ProjectImage()
            image.project_id = pro.id
            image.file_name = file.filename
            image.img_type = img_type
            # Save the original image locally.
            path = os.save_images(images_url, pro.project_no, file=file)
            image.image_url = path
            # Upload the image to OBS.
            object_key = pro.project_no + '/' + img_type + '/' + file.filename
            success, key, url = obs.put_file(object_key=object_key, file_path=path)
            if success:
                image.object_key = object_key
                image.thumb_image_url = url
            else:
                raise CustomException("obs上传失败", code=status.HTTP_ERROR)
            image_models.append(image)
        await self.create_models(datas=image_models)
        return len(image_models)

    async def check_img_name(self, file_name: str, project_id: int, img_type: str):
        """
        Check for an image with the same file name in the same project and type.

        :return: True if such an image already exists.
        """
        count = await self.get_count(v_where=[
            models.ProjectImage.file_name == file_name,
            models.ProjectImage.project_id == project_id,
            models.ProjectImage.img_type == img_type
        ])
        return count > 0

    async def del_img(self, ids: List[int]):
        """
        Delete images everywhere: local files, OBS objects and database rows.

        :param ids: primary keys of the images to delete.
        """
        file_urls = []
        object_keys = []
        for img_id in ids:
            image = await self.get_data(data_id=img_id)
            if image:
                # Skip rows that have no stored path / key.
                if image.image_url:
                    file_urls.append(image.image_url)
                if image.object_key:
                    object_keys.append(image.object_key)
        if file_urls:
            os.delete_file_if_exists(*file_urls)
        if object_keys:
            MyObs().del_objects(object_keys)
        await self.delete_datas(ids)

    async def get_img_count(
            self,
            proj_id: int) -> tuple[int, int]:
        """
        Count a project's images per split.

        :param proj_id: project id.
        :return: ``(train_count, val_count)``.
        """
        train_count = await self.get_count(
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'train'])
        val_count = await self.get_count(
            v_where=[models.ProjectImage.project_id == proj_id, models.ProjectImage.img_type == 'val'])
        return train_count, val_count

    async def check_image_label(
            self,
            proj_id: int) -> tuple[int, int]:
        """
        Count a project's images that have no ``ProjectImgLabel`` rows.

        :param proj_id: project id.
        :return: ``(train_count, val_count)`` -- the numbers of unlabelled
            ``'train'`` and ``'val'`` images.
        """
        # Ids of images that have at least one label.
        labelled = (
            select(models.ProjectImgLabel.image_id)
            .group_by(models.ProjectImgLabel.image_id)
            .subquery()
        )
        # An image without labels finds no partner in the outer join, so the
        # joined image_id is NULL. Comparing a label count with 0 would never
        # match, because the grouped subquery only holds labelled images.
        query = (
            select(models.ProjectImage.id)
            .outerjoin(labelled, models.ProjectImage.id == labelled.c.image_id)
        )

        async def count_unlabelled(img_type: str) -> int:
            # Count the unlabelled images of one split.
            sql = await self.filter_core(
                v_start_sql=query,
                v_where=[models.ProjectImage.project_id == proj_id,
                         models.ProjectImage.img_type == img_type,
                         labelled.c.image_id.is_(None)],
                v_return_sql=True)
            return await self.get_count_sql(sql)

        train_count = await count_unlabelled('train')
        val_count = await count_unlabelled('val')
        return train_count, val_count


class ProjectLabelDal(DalBase):
    """
    Data-access layer for project labels (``models.ProjectLabel``).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectLabelDal, self).__init__()
        self.db = db
        self.model = models.ProjectLabel
        self.schema = schemas.ProjectLabel

    async def check_label_name(
            self,
            name: str,
            pro_id: int,
            label_id: Optional[int] = None
    ):
        """
        Check whether project ``pro_id`` already has a label named ``name``.

        :param name: label name to check.
        :param pro_id: project id.
        :param label_id: when given, this label is excluded from the check
            (presumably the label being edited -- verify against callers).
        :return: True if a conflicting label exists.
        """
        wheres = [
            self.model.project_id == pro_id,
            self.model.label_name == name
        ]
        if label_id is not None:
            wheres.append(self.model.id != label_id)
        count = await self.get_count(v_where=wheres)
        return count > 0

    async def get_label_for_train(self, project_id: int):
        """
        Return all labels of a project, ordered by id ascending.

        :param project_id: project id.
        :return: ``(id_list, name_list)`` -- parallel lists of label ids and
            label names.
        """
        label_list = await self.get_datas(
            v_where=[self.model.project_id == project_id],
            limit=0,
            v_order='asc',
            v_order_field='id',
            v_return_count=False)
        id_list = [label['id'] for label in label_list]
        name_list = [label['label_name'] for label in label_list]
        return id_list, name_list


class ProjectImgLabelDal(DalBase):
    """
    Data-access layer for per-image label rows (``models.ProjectImgLabel``).
    """
    def __init__(self, db: AsyncSession):
        super(ProjectImgLabelDal, self).__init__()
        self.db = db
        self.model = models.ProjectImgLabel

    async def add_img_label(self, img_label_in: schemas.ProjectImgLeaferLabel):
        """
        Replace all label rows of one image with ``img_label_in.label_infos``.

        :param img_label_in: payload carrying ``image_id`` and ``label_infos``.
        """
        # Delete the image's existing rows first, then insert the new set.
        image_id = img_label_in.image_id
        # NOTE(review): elsewhere delete_datas is given a list of ids; confirm it
        # supports filtering by an ``image_id`` keyword.
        await self.delete_datas(image_id=image_id)
        img_labels = [self.model(**i.model_dump()) for i in img_label_in.label_infos]
        for img in img_labels:
            img.image_id = image_id
        # NOTE(review): ProjectImageDal passes model instances to create_models;
        # confirm create_datas also accepts model instances.
        await self.create_datas(img_labels)

    async def get_img_label_list(self, image_id: int):
        """
        Return all label rows of an image as ORM objects, ordered by id.

        :param image_id: image id.
        """
        return await self.get_datas(
            limit=0,
            v_return_count=False,
            v_where=[self.model.image_id == image_id],
            v_order="asc",
            v_order_field="id",
            v_return_objs=True)

    async def del_img_label(self, label_ids: list[int]):
        """
        Delete every image-label row that references one of ``label_ids``.

        :param label_ids: project label ids whose image assignments are removed.
        """
        if not label_ids:
            return
        # Fetch every matching row (limit=0: no paging) as ORM objects. This
        # class defines no output schema, and the code below reads ``.id``.
        img_labels = await self.get_datas(
            limit=0,
            v_return_count=False,
            v_where=[self.model.label_id.in_(label_ids)],
            v_return_objs=True)
        img_label_ids = [i.id for i in img_labels]
        if img_label_ids:
            await self.delete_datas(ids=img_label_ids)


class ProjectImgLeaferDal(DalBase):
    """
    Data-access layer for image annotation data produced by leafer.js
    (``models.ProjectImgLeafer``).
    """

    def __init__(self, db: AsyncSession):
        super(ProjectImgLeaferDal, self).__init__()
        self.db = db
        self.model = models.ProjectImgLeafer
        self.schema = schemas.ProjectImgLeaferOut

    async def get_leafer(self, image_id: int):
        """
        Return the stored ``leafer`` payload of an image.

        :param image_id: image id.
        :return: the record's ``leafer`` value, or None if ``get_data`` yields
            no record.
        """
        # get_data is a coroutine and must be awaited before its result is used.
        img_label = await self.get_data(v_where=[self.model.image_id == image_id])
        return img_label.leafer if img_label else None

    async def add_leafer(self, img_label_in: schemas.ProjectImgLeaferLabel):
        """
        Replace an image's leafer record with one built from ``img_label_in``.

        :param img_label_in: payload whose dumped fields populate the new record.
        """
        # Delete the existing data first, then save the new record.
        image_id = img_label_in.image_id
        await self.delete_datas(image_id=image_id)
        # NOTE(review): model_dump() includes every schema field (e.g.
        # label_infos); confirm they are all accepted by ProjectImgLeafer.
        await self.create_data(data=self.model(**img_label_in.model_dump()))