import datetime
import json
import re
from typing import Any

from bson import ObjectId
from bson.errors import InvalidId
from bson.json_util import dumps
from fastapi.encoders import jsonable_encoder
from motor.motor_asyncio import AsyncIOMotorDatabase
from pymongo.results import InsertOneResult, UpdateResult

from core.exception import CustomException
from utils import status


class MongoManage:
    """
    MongoDB data-access manager operating on a single Motor collection.

    References:
        Blog: https://www.cnblogs.com/aduner/p/13532504.html
        MongoDB (Motor driver): https://www.mongodb.com/docs/drivers/motor/
        Motor docs: https://motor.readthedocs.io/en/stable/
    """

    # Values of ``v_order`` (compared case-insensitively) that select descending order.
    ORDER_FIELD = ["desc", "descending"]

    def __init__(
            self,
            db: AsyncIOMotorDatabase = None,
            collection: str = None,
            schema: Any = None,
            is_object_id: bool = True
    ):
        """
        Initialise the manager.

        :param db: Motor database handle.
        :param collection: Name of the collection to operate on; when falsy,
            ``self.collection`` is None and the data methods cannot be used.
        :param schema: Default class used by ``get_datas`` to serialise rows
            when no ``v_schema`` is passed.
        :param is_object_id: Whether the ``_id`` field is stored as ``ObjectId``
            (True) or as a plain value (False).
        """
        self.db = db
        self.collection = db[collection] if collection else None
        self.schema = schema
        self.is_object_id = is_object_id

    def _to_id(self, _id: Any) -> Any:
        """
        Convert an ``_id`` argument into the type stored in the collection.

        :param _id: Raw ID value.
        :return: ``ObjectId(_id)`` when ``is_object_id`` is True, otherwise ``_id`` unchanged.
        :raises CustomException: if ``_id`` cannot be converted to an ObjectId.
        """
        if not self.is_object_id:
            return _id
        try:
            return ObjectId(_id)
        except (InvalidId, TypeError) as err:
            # Malformed client input should surface as a handled error, not a 500.
            raise CustomException("数据编号格式不正确!") from err

    async def get_data(
            self,
            _id: str = None,
            v_return_none: bool = False,
            v_schema: Any = None,
            **kwargs
    ) -> dict | None:
        """
        Fetch a single document. Filter by ``_id`` when given, plus any keyword filters.

        :param _id: Document ID. Converted to ObjectId when ``is_object_id`` is True
            and used as-is otherwise. It is combined with any ``kwargs`` filters.
        :param v_return_none: Return None instead of raising when nothing matches
            (default: raise).
        :param v_schema: Optional class used to serialise the result via ``jsonable_encoder``.
        :param kwargs: Extra filters, interpreted by ``filter_condition``.
        :return: The raw document, or the serialised document when ``v_schema`` is
            given, or None (only when ``v_return_none`` is True).
        :raises CustomException: 404 when nothing matches and ``v_return_none`` is False,
            or when ``_id`` is not a valid ObjectId.
        """
        if _id:
            # Previously the ID was silently dropped when is_object_id was False.
            kwargs["_id"] = self._to_id(_id)
        params = self.filter_condition(**kwargs)
        data = await self.collection.find_one(params)
        if not data:
            if v_return_none:
                return None
            raise CustomException("查找失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
        if v_schema:
            return jsonable_encoder(v_schema(**data))
        return data

    async def create_data(self, data: dict | Any) -> InsertOneResult:
        """
        Insert one document, stamping ``create_datetime`` and ``update_datetime``.

        :param data: A dict, or an object that ``jsonable_encoder`` converts to one.
            A dict argument is modified in place: the timestamps are added, and
            ``insert_one`` adds ``_id``.
        :return: The ``InsertOneResult`` of the insert.
        :raises CustomException: if the write was not acknowledged.
        """
        if not isinstance(data, dict):
            data = jsonable_encoder(data)
        # One timestamp, so a fresh document has create_datetime == update_datetime.
        now = datetime.datetime.now()
        data['create_datetime'] = now
        data['update_datetime'] = now
        result = await self.collection.insert_one(data)
        if not result.acknowledged:
            raise CustomException("创建新数据失败", code=status.HTTP_ERROR)
        return result

    async def put_data(self, _id: str, data: dict | Any) -> UpdateResult:
        """
        Apply ``$set`` with ``data`` to the document with the given ``_id``.

        ``update_datetime`` is refreshed in the same way ``create_data`` sets it.

        :param _id: Document ID (converted per ``is_object_id``).
        :param data: A dict, or an object that ``jsonable_encoder`` converts to one.
            The caller's dict is not modified.
        :return: The ``UpdateResult`` of the update.
        :raises CustomException: 404 if no document matches, or if ``_id`` is malformed.
        """
        if not isinstance(data, dict):
            data = jsonable_encoder(data)
        fields = {**data, 'update_datetime': datetime.datetime.now()}
        result = await self.collection.update_one({'_id': self._to_id(_id)}, {'$set': fields})
        if result.matched_count == 0:
            raise CustomException("更新失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
        return result

    async def delete_data(self, _id: str) -> bool:
        """
        Delete the document with the given ``_id``.

        :param _id: Document ID (converted per ``is_object_id``).
        :return: True when a document was deleted.
        :raises CustomException: 404 if no document matches, or if ``_id`` is malformed.
        """
        result = await self.collection.delete_one({'_id': self._to_id(_id)})
        if result.deleted_count == 0:
            raise CustomException("删除失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
        return True

    async def get_datas(
            self,
            page: int = 1,
            limit: int = 10,
            v_schema: Any = None,
            v_order: str = None,
            v_order_field: str = None,
            v_return_objs: bool = False,
            **kwargs
    ):
        """
        Query a page of documents.

        ``find()`` performs no I/O and needs no ``await``; it only creates an
        AsyncIOMotorCursor. The query runs on the server when the cursor is
        iterated (``async for``) or ``to_list()`` is called.

        :param page: 1-based page number; values below 1 are treated as 1.
        :param limit: Page size. 0 disables pagination.
        :param v_schema: Class used to serialise each row. Falls back to ``self.schema``.
        :param v_order: Sort direction. Any value in ``ORDER_FIELD`` (case-insensitive)
            sorts descending; anything else sorts ascending.
        :param v_order_field: Field to sort on; defaults to ``create_datetime`` when only
            ``v_order`` is given. No sort is applied when both are falsy.
        :param v_return_objs: Return the rows without schema serialisation.
        :param kwargs: Filters, interpreted by ``filter_condition``.
        :return: List of rows.
        """
        params = self.filter_condition(**kwargs)
        cursor = self.collection.find(params)

        if v_order or v_order_field:
            field = v_order_field or 'create_datetime'
            direction = -1 if str(v_order).lower() in self.ORDER_FIELD else 1
            cursor.sort(field, direction)

        if limit != 0:
            # Clamp so page <= 0 does not produce a negative skip, which the driver rejects.
            cursor.skip(max(page - 1, 0) * limit).limit(limit)

        # Round-trip through bson.json_util so BSON types become JSON-friendly
        # Extended JSON values (e.g. ObjectId -> {"$oid": ...}).
        datas = [json.loads(dumps(row)) async for row in cursor]

        if not datas or v_return_objs:
            return datas
        schema = v_schema or self.schema
        if schema:
            datas = [jsonable_encoder(schema(**data)) for data in datas]
        return datas

    async def get_count(self, **kwargs) -> int:
        """
        Count documents matching the filters built by ``filter_condition``.
        """
        params = self.filter_condition(**kwargs)
        return await self.collection.count_documents(params)

    @classmethod
    def filter_condition(cls, **kwargs):
        """
        Build a MongoDB filter document from keyword arguments.

        Each keyword is a field name. Its value is interpreted as follows:

        * ``None``, or an empty string/bytes/list/tuple/dict/set: skipped (no filter).
        * ``("like", text)``: substring match. ``text`` is regex-escaped, so it
          matches literally.
        * ``("between", (start, end))``: ``$gte "start 00:00:00"`` and
          ``$lt "end 23:59:59"`` (string bounds).
        * ``("ObjectId", value)``: equality on ``ObjectId(value)``.
        * Anything else, including ``0`` and ``False``: plain equality.

        A tuple with an unrecognised operator or an empty operand is ignored.

        :raises CustomException: if an ``("ObjectId", value)`` operand is not a valid ObjectId.
        """
        params = {}
        for k, v in kwargs.items():
            # Only skip "no value" inputs; 0 / False are legitimate filter values.
            if v is None or (isinstance(v, (str, bytes, list, tuple, dict, set, frozenset)) and not v):
                continue
            if not isinstance(v, tuple):
                params[k] = v
                continue
            op = v[0]
            if op == "like" and v[1]:
                # "like" means a literal substring search; escape so user input
                # is not treated as a regex (invalid patterns, ReDoS).
                params[k] = {'$regex': re.escape(str(v[1]))}
            elif op == "between" and v[1] and len(v[1]) == 2:
                # NOTE(review): string bounds only compare against string-typed fields;
                # create_data stores create/update_datetime as datetime objects,
                # which these bounds will not match. Confirm the field type.
                params[k] = {'$gte': f"{v[1][0]} 00:00:00", '$lt': f"{v[1][1]} 23:59:59"}
            elif op == "ObjectId" and v[1]:
                try:
                    params[k] = ObjectId(v[1])
                except InvalidId as err:
                    raise CustomException("任务编号格式不正确!") from err
        return params