项目初次提交
This commit is contained in:
177
core/mongo_manage.py
Normal file
177
core/mongo_manage.py
Normal file
@ -0,0 +1,177 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any
|
||||
from bson import ObjectId
|
||||
from bson.errors import InvalidId
|
||||
from bson.json_util import dumps
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from motor.motor_asyncio import AsyncIOMotorDatabase
|
||||
from pymongo.results import InsertOneResult, UpdateResult
|
||||
from core.exception import CustomException
|
||||
from utils import status
|
||||
|
||||
|
||||
class MongoManage:
|
||||
"""
|
||||
mongodb 数据库管理器
|
||||
博客:https://www.cnblogs.com/aduner/p/13532504.html
|
||||
mongodb 官网:https://www.mongodb.com/docs/drivers/motor/
|
||||
motor 文档:https://motor.readthedocs.io/en/stable/
|
||||
"""
|
||||
|
||||
# 倒叙
|
||||
ORDER_FIELD = ["desc", "descending"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: AsyncIOMotorDatabase = None,
|
||||
collection: str = None,
|
||||
schema: Any = None,
|
||||
is_object_id: bool = True
|
||||
):
|
||||
"""
|
||||
初始化
|
||||
:param db:
|
||||
:param collection: 集合
|
||||
:param schema:
|
||||
:param is_object_id: _id 列是否为 ObjectId 格式
|
||||
"""
|
||||
self.db = db
|
||||
self.collection = db[collection] if collection else None
|
||||
self.schema = schema
|
||||
self.is_object_id = is_object_id
|
||||
|
||||
async def get_data(
|
||||
self,
|
||||
_id: str = None,
|
||||
v_return_none: bool = False,
|
||||
v_schema: Any = None,
|
||||
**kwargs
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取单个数据,默认使用 ID 查询,否则使用关键词查询
|
||||
:param _id: 数据 ID
|
||||
:param v_return_none: 是否返回空 None,否则抛出异常,默认抛出异常
|
||||
:param v_schema: 指定使用的序列化对象
|
||||
"""
|
||||
if _id and self.is_object_id:
|
||||
kwargs["_id"] = ObjectId(_id)
|
||||
params = self.filter_condition(**kwargs)
|
||||
data = await self.collection.find_one(params)
|
||||
if not data and v_return_none:
|
||||
return None
|
||||
elif not data:
|
||||
raise CustomException("查找失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
|
||||
elif data and v_schema:
|
||||
return jsonable_encoder(v_schema(**data))
|
||||
return data
|
||||
|
||||
async def create_data(self, data: dict | Any) -> InsertOneResult:
|
||||
"""
|
||||
创建数据
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
data = jsonable_encoder(data)
|
||||
data['create_datetime'] = datetime.datetime.now()
|
||||
data['update_datetime'] = datetime.datetime.now()
|
||||
result = await self.collection.insert_one(data)
|
||||
# 判断插入是否成功
|
||||
if result.acknowledged:
|
||||
return result
|
||||
else:
|
||||
raise CustomException("创建新数据失败", code=status.HTTP_ERROR)
|
||||
|
||||
async def put_data(self, _id: str, data: dict | Any) -> UpdateResult:
|
||||
"""
|
||||
更新数据
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
data = jsonable_encoder(data)
|
||||
new_data = {'$set': data}
|
||||
result = await self.collection.update_one({'_id': ObjectId(_id) if self.is_object_id else _id}, new_data)
|
||||
|
||||
if result.matched_count > 0:
|
||||
return result
|
||||
else:
|
||||
raise CustomException("更新失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
async def delete_data(self, _id: str):
|
||||
"""
|
||||
删除数据
|
||||
"""
|
||||
result = await self.collection.delete_one({'_id': ObjectId(_id) if self.is_object_id else _id})
|
||||
|
||||
if result.deleted_count > 0:
|
||||
return True
|
||||
else:
|
||||
raise CustomException("删除失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
async def get_datas(
|
||||
self,
|
||||
page: int = 1,
|
||||
limit: int = 10,
|
||||
v_schema: Any = None,
|
||||
v_order: str = None,
|
||||
v_order_field: str = None,
|
||||
v_return_objs: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
使用 find() 要查询的一组文档。 find() 没有I / O,也不需要 await 表达式。它只是创建一个 AsyncIOMotorCursor 实例
|
||||
当您调用 to_list() 或为循环执行异步时 (async for) ,查询实际上是在服务器上执行的。
|
||||
"""
|
||||
|
||||
params = self.filter_condition(**kwargs)
|
||||
cursor = self.collection.find(params)
|
||||
|
||||
if v_order or v_order_field:
|
||||
v_order_field = v_order_field if v_order_field else 'create_datetime'
|
||||
v_order = -1 if v_order in self.ORDER_FIELD else 1
|
||||
cursor.sort(v_order_field, v_order)
|
||||
|
||||
if limit != 0:
|
||||
# 对查询应用排序(sort),跳过(skip)或限制(limit)
|
||||
cursor.skip((page - 1) * limit).limit(limit)
|
||||
|
||||
datas = []
|
||||
async for row in cursor:
|
||||
data = json.loads(dumps(row))
|
||||
datas.append(data)
|
||||
|
||||
if not datas or v_return_objs:
|
||||
return datas
|
||||
elif v_schema:
|
||||
datas = [jsonable_encoder(v_schema(**data)) for data in datas]
|
||||
elif self.schema:
|
||||
datas = [jsonable_encoder(self.schema(**data)) for data in datas]
|
||||
return datas
|
||||
|
||||
async def get_count(self, **kwargs) -> int:
|
||||
"""
|
||||
获取统计数据
|
||||
"""
|
||||
params = self.filter_condition(**kwargs)
|
||||
return await self.collection.count_documents(params)
|
||||
|
||||
@classmethod
|
||||
def filter_condition(cls, **kwargs):
|
||||
"""
|
||||
过滤条件
|
||||
"""
|
||||
params = {}
|
||||
for k, v in kwargs.items():
|
||||
if not v:
|
||||
continue
|
||||
elif isinstance(v, tuple):
|
||||
if v[0] == "like" and v[1]:
|
||||
params[k] = {'$regex': v[1]}
|
||||
elif v[0] == "between" and len(v[1]) == 2:
|
||||
params[k] = {'$gte': f"{v[1][0]} 00:00:00", '$lt': f"{v[1][1]} 23:59:59"}
|
||||
elif v[0] == "ObjectId" and v[1]:
|
||||
try:
|
||||
params[k] = ObjectId(v[1])
|
||||
except InvalidId:
|
||||
raise CustomException("任务编号格式不正确!")
|
||||
else:
|
||||
params[k] = v
|
||||
return params
|
Reference in New Issue
Block a user