项目初次提交
This commit is contained in:
0
core/__init__.py
Normal file
0
core/__init__.py
Normal file
498
core/crud.py
Normal file
498
core/crud.py
Normal file
@ -0,0 +1,498 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Update Time : 2023/8/21 22:18
|
||||
# @File : crud.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 数据库 增删改查操作
|
||||
|
||||
# sqlalchemy 官方文档:https://docs.sqlalchemy.org/en/20/index.html
|
||||
# sqlalchemy 查询操作(官方文档): https://docs.sqlalchemy.org/en/20/orm/queryguide/select.html
|
||||
# sqlalchemy 增删改操作:https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html
|
||||
# sqlalchemy 1.x 语法迁移到 2.x :https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-query-usage
|
||||
|
||||
import datetime
|
||||
from fastapi import HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy import func, delete, update, BinaryExpression, ScalarResult, select, false, insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.strategy_options import _AbstractLoad
|
||||
from starlette import status
|
||||
from core.exception import CustomException
|
||||
from sqlalchemy.sql.selectable import Select as SelectType
|
||||
from typing import Any, Union
|
||||
|
||||
|
||||
class DalBase:
|
||||
# 倒叙
|
||||
ORDER_FIELD = ["desc", "descending"]
|
||||
|
||||
def __init__(self, db: AsyncSession = None, model: Any = None, schema: Any = None):
|
||||
self.db = db
|
||||
self.model = model
|
||||
self.schema = schema
|
||||
|
||||
async def get_data(
|
||||
self,
|
||||
data_id: int = None,
|
||||
v_start_sql: SelectType = None,
|
||||
v_select_from: list[Any] = None,
|
||||
v_join: list[Any] = None,
|
||||
v_outer_join: list[Any] = None,
|
||||
v_options: list[_AbstractLoad] = None,
|
||||
v_where: list[BinaryExpression] = None,
|
||||
v_order: str = None,
|
||||
v_order_field: str = None,
|
||||
v_return_none: bool = False,
|
||||
v_schema: Any = None,
|
||||
v_expire_all: bool = False,
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
获取单个数据,默认使用 ID 查询,否则使用关键词查询
|
||||
:param data_id: 数据 ID
|
||||
:param v_start_sql: 初始 sql
|
||||
:param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。
|
||||
:param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。
|
||||
:param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。
|
||||
:param v_options: 用于为查询添加附加选项,如预加载、延迟加载等。
|
||||
:param v_where: 当前表查询条件,原始表达式
|
||||
:param v_order: 排序,默认正序,为 desc 是倒叙
|
||||
:param v_order_field: 排序字段
|
||||
:param v_return_none: 是否返回空 None,否认 抛出异常,默认抛出异常
|
||||
:param v_schema: 指定使用的序列化对象
|
||||
:param v_expire_all: 使当前会话(Session)中所有已加载的对象过期,确保您获取的是数据库中的最新数据,但可能会有性能损耗,博客:https://blog.csdn.net/k_genius/article/details/135490378。
|
||||
:param kwargs: 查询参数
|
||||
:return: 默认返回 ORM 对象,如果存在 v_schema 则会返回 v_schema 结果
|
||||
"""
|
||||
if v_expire_all:
|
||||
self.db.expire_all()
|
||||
|
||||
if not isinstance(v_start_sql, SelectType):
|
||||
v_start_sql = select(self.model).where(self.model.is_delete == false())
|
||||
|
||||
if data_id is not None:
|
||||
v_start_sql = v_start_sql.where(self.model.id == data_id)
|
||||
|
||||
queryset: ScalarResult = await self.filter_core(
|
||||
v_start_sql=v_start_sql,
|
||||
v_select_from=v_select_from,
|
||||
v_join=v_join,
|
||||
v_outer_join=v_outer_join,
|
||||
v_options=v_options,
|
||||
v_where=v_where,
|
||||
v_order=v_order,
|
||||
v_order_field=v_order_field,
|
||||
v_return_sql=False,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if v_options:
|
||||
data = queryset.unique().first()
|
||||
else:
|
||||
data = queryset.first()
|
||||
|
||||
if not data and v_return_none:
|
||||
return None
|
||||
|
||||
if data and v_schema:
|
||||
return v_schema.model_validate(data).model_dump()
|
||||
|
||||
if data:
|
||||
return data
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="未找到此数据")
|
||||
|
||||
async def get_datas(
|
||||
self,
|
||||
page: int = 1,
|
||||
limit: int = 10,
|
||||
v_start_sql: SelectType = None,
|
||||
v_select_from: list[Any] = None,
|
||||
v_join: list[Any] = None,
|
||||
v_outer_join: list[Any] = None,
|
||||
v_options: list[_AbstractLoad] = None,
|
||||
v_where: list[BinaryExpression] = None,
|
||||
v_order: str = None,
|
||||
v_order_field: str = None,
|
||||
v_return_count: bool = False,
|
||||
v_return_scalars: bool = False,
|
||||
v_return_objs: bool = False,
|
||||
v_schema: Any = None,
|
||||
v_distinct: bool = False,
|
||||
v_expire_all: bool = False,
|
||||
**kwargs
|
||||
) -> Union[list[Any], ScalarResult, tuple]:
|
||||
"""
|
||||
获取数据列表
|
||||
:param page: 页码
|
||||
:param limit: 当前页数据量
|
||||
:param v_start_sql: 初始 sql
|
||||
:param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。
|
||||
:param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。
|
||||
:param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。
|
||||
:param v_options: 用于为查询添加附加选项,如预加载、延迟加载等。
|
||||
:param v_where: 当前表查询条件,原始表达式
|
||||
:param v_order: 排序,默认正序,为 desc 是倒叙
|
||||
:param v_order_field: 排序字段
|
||||
:param v_return_count: 默认为 False,是否返回 count 过滤后的数据总数,不会影响其他返回结果,会一起返回为一个数组
|
||||
:param v_return_scalars: 返回scalars后的结果
|
||||
:param v_return_objs: 是否返回对象
|
||||
:param v_schema: 指定使用的序列化对象
|
||||
:param v_distinct: 是否结果去重
|
||||
:param v_expire_all: 使当前会话(Session)中所有已加载的对象过期,确保您获取的是数据库中的最新数据,但可能会有性能损耗,博客:https://blog.csdn.net/k_genius/article/details/135490378。
|
||||
:param kwargs: 查询参数,使用的是自定义表达式
|
||||
:return: 返回值优先级:v_return_scalars > v_return_objs > v_schema
|
||||
"""
|
||||
if v_expire_all:
|
||||
self.db.expire_all()
|
||||
|
||||
sql: SelectType = await self.filter_core(
|
||||
v_start_sql=v_start_sql,
|
||||
v_select_from=v_select_from,
|
||||
v_join=v_join,
|
||||
v_outer_join=v_outer_join,
|
||||
v_options=v_options,
|
||||
v_where=v_where,
|
||||
v_order=v_order,
|
||||
v_order_field=v_order_field,
|
||||
v_return_sql=True,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if v_distinct:
|
||||
sql = sql.distinct()
|
||||
|
||||
count = 0
|
||||
if v_return_count:
|
||||
count_sql = select(func.count()).select_from(sql.alias())
|
||||
count_queryset = await self.db.execute(count_sql)
|
||||
count = count_queryset.one()[0]
|
||||
|
||||
if limit != 0:
|
||||
sql = sql.offset((page - 1) * limit).limit(limit)
|
||||
|
||||
queryset = await self.db.scalars(sql)
|
||||
|
||||
if v_return_scalars:
|
||||
if v_return_count:
|
||||
return queryset, count
|
||||
return queryset
|
||||
|
||||
if v_options:
|
||||
result = queryset.unique().all()
|
||||
else:
|
||||
result = queryset.all()
|
||||
|
||||
if v_return_objs:
|
||||
if v_return_count:
|
||||
return list(result), count
|
||||
return list(result)
|
||||
|
||||
datas = [await self.out_dict(i, v_schema=v_schema) for i in result]
|
||||
if v_return_count:
|
||||
return datas, count
|
||||
return datas
|
||||
|
||||
async def get_count_sql(
|
||||
self,
|
||||
v_start_sql: SelectType
|
||||
) -> int:
|
||||
count_sql = select(func.count()).select_from(v_start_sql.alias())
|
||||
count_queryset = await self.db.execute(count_sql)
|
||||
count = count_queryset.one()[0]
|
||||
return count
|
||||
|
||||
async def get_count(
|
||||
self,
|
||||
v_select_from: list[Any] = None,
|
||||
v_join: list[Any] = None,
|
||||
v_outer_join: list[Any] = None,
|
||||
v_where: list[BinaryExpression] = None,
|
||||
**kwargs
|
||||
) -> int:
|
||||
"""
|
||||
获取数据总数
|
||||
:param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。
|
||||
:param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。
|
||||
:param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。
|
||||
:param v_where: 当前表查询条件,原始表达式
|
||||
:param kwargs: 查询参数
|
||||
"""
|
||||
v_start_sql = select(func.count(self.model.id))
|
||||
sql = await self.filter_core(
|
||||
v_start_sql=v_start_sql,
|
||||
v_select_from=v_select_from,
|
||||
v_join=v_join,
|
||||
v_outer_join=v_outer_join,
|
||||
v_where=v_where,
|
||||
v_return_sql=True,
|
||||
**kwargs
|
||||
)
|
||||
queryset = await self.db.execute(sql)
|
||||
return queryset.one()[0]
|
||||
|
||||
async def create_data(
|
||||
self,
|
||||
data,
|
||||
v_options: list[_AbstractLoad] = None,
|
||||
v_return_obj: bool = False,
|
||||
v_schema: Any = None
|
||||
) -> Any:
|
||||
"""
|
||||
创建单个数据
|
||||
:param data: 创建数据
|
||||
:param v_options: 指示应使用select在预加载中加载给定的属性。
|
||||
:param v_schema: ,指定使用的序列化对象
|
||||
:param v_return_obj: ,是否返回对象
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
obj = self.model(**data)
|
||||
else:
|
||||
obj = self.model(**data.model_dump())
|
||||
await self.flush(obj)
|
||||
return await self.out_dict(obj, v_options, v_return_obj, v_schema)
|
||||
|
||||
async def create_datas(self, datas: list[dict]) -> None:
|
||||
"""
|
||||
批量创建数据
|
||||
SQLAlchemy 2.0 批量插入不支持 MySQL 返回值:
|
||||
https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#getting-new-objects-with-returning
|
||||
:param datas: 字典数据列表
|
||||
"""
|
||||
await self.db.execute(insert(self.model), datas)
|
||||
await self.db.flush()
|
||||
|
||||
async def put_data(
|
||||
self,
|
||||
data_id: int,
|
||||
data: Any,
|
||||
v_options: list[_AbstractLoad] = None,
|
||||
v_return_obj: bool = False,
|
||||
v_schema: Any = None
|
||||
) -> Any:
|
||||
"""
|
||||
更新单个数据
|
||||
:param data_id: 修改行数据的 ID
|
||||
:param data: 数据内容
|
||||
:param v_options: 指示应使用select在预加载中加载给定的属性。
|
||||
:param v_return_obj: ,是否返回对象
|
||||
:param v_schema: ,指定使用的序列化对象
|
||||
"""
|
||||
obj = await self.get_data(data_id, v_options=v_options)
|
||||
obj_dict = jsonable_encoder(data)
|
||||
for key, value in obj_dict.items():
|
||||
setattr(obj, key, value)
|
||||
await self.flush(obj)
|
||||
return await self.out_dict(obj, None, v_return_obj, v_schema)
|
||||
|
||||
async def delete_datas(self, ids: list[int], v_soft: bool = False, **kwargs) -> None:
|
||||
"""
|
||||
删除多条数据
|
||||
:param ids: 数据集
|
||||
:param v_soft: 是否执行软删除
|
||||
:param kwargs: 其他更新字段
|
||||
"""
|
||||
if v_soft:
|
||||
await self.db.execute(
|
||||
update(self.model).where(self.model.id.in_(ids)).values(
|
||||
delete_datetime=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
is_delete=True,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.db.execute(delete(self.model).where(self.model.id.in_(ids)))
|
||||
await self.flush()
|
||||
|
||||
async def flush(self, obj: Any = None) -> Any:
|
||||
"""
|
||||
刷新到数据库
|
||||
:param obj:
|
||||
:return:
|
||||
"""
|
||||
if obj:
|
||||
self.db.add(obj)
|
||||
await self.db.flush()
|
||||
if obj:
|
||||
# 使用 get_data 或者 get_datas 获取到实例后如果更新了实例,并需要序列化实例,那么需要执行 refresh 刷新才能正常序列化
|
||||
await self.db.refresh(obj)
|
||||
return obj
|
||||
|
||||
async def out_dict(
|
||||
self,
|
||||
obj: Any,
|
||||
v_options: list[_AbstractLoad] = None,
|
||||
v_return_obj: bool = False,
|
||||
v_schema: Any = None
|
||||
) -> Any:
|
||||
"""
|
||||
序列化
|
||||
:param obj:
|
||||
:param v_options: 指示应使用select在预加载中加载给定的属性。
|
||||
:param v_return_obj: ,是否返回对象
|
||||
:param v_schema: ,指定使用的序列化对象
|
||||
:return:
|
||||
"""
|
||||
if v_options:
|
||||
obj = await self.get_data(obj.id, v_options=v_options)
|
||||
if v_return_obj:
|
||||
return obj
|
||||
if v_schema:
|
||||
return v_schema.model_validate(obj).model_dump()
|
||||
return self.schema.model_validate(obj).model_dump()
|
||||
|
||||
async def filter_core(
|
||||
self,
|
||||
v_start_sql: SelectType = None,
|
||||
v_select_from: list[Any] = None,
|
||||
v_join: list[Any] = None,
|
||||
v_outer_join: list[Any] = None,
|
||||
v_options: list[_AbstractLoad] = None,
|
||||
v_where: list[BinaryExpression] = None,
|
||||
v_order: str = None,
|
||||
v_order_field: str = None,
|
||||
v_return_sql: bool = False,
|
||||
**kwargs
|
||||
) -> Union[ScalarResult, SelectType]:
|
||||
"""
|
||||
数据过滤核心功能
|
||||
:param v_start_sql: 初始 sql
|
||||
:param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。
|
||||
:param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。
|
||||
:param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。
|
||||
:param v_options: 用于为查询添加附加选项,如预加载、延迟加载等。
|
||||
:param v_where: 当前表查询条件,原始表达式
|
||||
:param v_order: 排序,默认正序,为 desc 是倒叙
|
||||
:param v_order_field: 排序字段
|
||||
:param v_return_sql: 是否直接返回 sql
|
||||
:return: 返回过滤后的总数居 或 sql
|
||||
"""
|
||||
if not isinstance(v_start_sql, SelectType):
|
||||
v_start_sql = select(self.model).where(self.model.is_delete == false())
|
||||
|
||||
sql = self.add_relation(
|
||||
v_start_sql=v_start_sql,
|
||||
v_select_from=v_select_from,
|
||||
v_join=v_join,
|
||||
v_outer_join=v_outer_join,
|
||||
v_options=v_options
|
||||
)
|
||||
|
||||
if v_where:
|
||||
sql = sql.where(*v_where)
|
||||
|
||||
sql = self.add_filter_condition(sql, **kwargs)
|
||||
|
||||
if v_order_field and (v_order in self.ORDER_FIELD):
|
||||
sql = sql.order_by(getattr(self.model, v_order_field).desc(), self.model.id.desc())
|
||||
elif v_order_field:
|
||||
sql = sql.order_by(getattr(self.model, v_order_field), self.model.id)
|
||||
elif v_order in self.ORDER_FIELD:
|
||||
sql = sql.order_by(self.model.id.desc())
|
||||
|
||||
if v_return_sql:
|
||||
return sql
|
||||
|
||||
queryset = await self.db.scalars(sql)
|
||||
|
||||
return queryset
|
||||
|
||||
def add_relation(
|
||||
self,
|
||||
v_start_sql: SelectType,
|
||||
v_select_from: list[Any] = None,
|
||||
v_join: list[Any] = None,
|
||||
v_outer_join: list[Any] = None,
|
||||
v_options: list[_AbstractLoad] = None,
|
||||
) -> SelectType:
|
||||
"""
|
||||
关系查询,关系加载
|
||||
:param v_start_sql: 初始 sql
|
||||
:param v_select_from: 用于指定查询从哪个表开始,通常与 .join() 等方法一起使用。
|
||||
:param v_join: 创建内连接(INNER JOIN)操作,返回两个表中满足连接条件的交集。
|
||||
:param v_outer_join: 用于创建外连接(OUTER JOIN)操作,返回两个表中满足连接条件的并集,包括未匹配的行,并用 NULL 值填充。
|
||||
:param v_options: 用于为查询添加附加选项,如预加载、延迟加载等。
|
||||
"""
|
||||
if v_select_from:
|
||||
v_start_sql = v_start_sql.select_from(*v_select_from)
|
||||
|
||||
if v_join:
|
||||
for relation in v_join:
|
||||
table = relation[0]
|
||||
if isinstance(table, str):
|
||||
table = getattr(self.model, table)
|
||||
if len(relation) == 2:
|
||||
v_start_sql = v_start_sql.join(table, relation[1])
|
||||
else:
|
||||
v_start_sql = v_start_sql.join(table)
|
||||
|
||||
if v_outer_join:
|
||||
for relation in v_outer_join:
|
||||
table = relation[0]
|
||||
if isinstance(table, str):
|
||||
table = getattr(self.model, table)
|
||||
if len(relation) == 2:
|
||||
v_start_sql = v_start_sql.outerjoin(table, relation[1])
|
||||
else:
|
||||
v_start_sql = v_start_sql.outerjoin(table)
|
||||
|
||||
if v_options:
|
||||
v_start_sql = v_start_sql.options(*v_options)
|
||||
|
||||
return v_start_sql
|
||||
|
||||
def add_filter_condition(self, sql: SelectType, **kwargs) -> SelectType:
|
||||
"""
|
||||
添加过滤条件
|
||||
:param sql:
|
||||
:param kwargs: 关键词参数
|
||||
"""
|
||||
conditions = self.__dict_filter(**kwargs)
|
||||
if conditions:
|
||||
sql = sql.where(*conditions)
|
||||
return sql
|
||||
|
||||
def __dict_filter(self, **kwargs) -> list[BinaryExpression]:
|
||||
"""
|
||||
字典过滤
|
||||
:param model:
|
||||
:param kwargs:
|
||||
"""
|
||||
conditions = []
|
||||
for field, value in kwargs.items():
|
||||
if value is not None and value != "":
|
||||
attr = getattr(self.model, field)
|
||||
if isinstance(value, tuple):
|
||||
if len(value) == 1:
|
||||
if value[0] == "None":
|
||||
conditions.append(attr.is_(None))
|
||||
elif value[0] == "not None":
|
||||
conditions.append(attr.isnot(None))
|
||||
else:
|
||||
raise CustomException("SQL查询语法错误")
|
||||
elif len(value) == 2 and value[1] not in [None, [], ""]:
|
||||
if value[0] == "date":
|
||||
# 根据日期查询, 关键函数是:func.time_format和func.date_format
|
||||
conditions.append(func.date_format(attr, "%Y-%m-%d") == value[1])
|
||||
elif value[0] == "like":
|
||||
conditions.append(attr.like(f"%{value[1]}%"))
|
||||
elif value[0] == "in":
|
||||
conditions.append(attr.in_(value[1]))
|
||||
elif value[0] == "between" and len(value[1]) == 2:
|
||||
conditions.append(attr.between(value[1][0], value[1][1]))
|
||||
elif value[0] == "month":
|
||||
conditions.append(func.date_format(attr, "%Y-%m") == value[1])
|
||||
elif value[0] == "!=":
|
||||
conditions.append(attr != value[1])
|
||||
elif value[0] == ">":
|
||||
conditions.append(attr > value[1])
|
||||
elif value[0] == ">=":
|
||||
conditions.append(attr >= value[1])
|
||||
elif value[0] == "<=":
|
||||
conditions.append(attr <= value[1])
|
||||
else:
|
||||
raise CustomException("SQL查询语法错误")
|
||||
else:
|
||||
conditions.append(attr == value)
|
||||
return conditions
|
121
core/data_types.py
Normal file
121
core/data_types.py
Normal file
@ -0,0 +1,121 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2023/7/16 12:42
|
||||
# @File : data_types.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 自定义数据类型
|
||||
|
||||
"""
|
||||
自定义数据类型 - 官方文档:https://docs.pydantic.dev/dev-v2/usage/types/custom/#adding-validation-and-serialization
|
||||
"""
|
||||
import datetime
|
||||
from typing import Annotated, Any
|
||||
from bson import ObjectId
|
||||
from pydantic import AfterValidator, PlainSerializer, WithJsonSchema
|
||||
from .validator import *
|
||||
|
||||
|
||||
def datetime_str_vali(value: str | datetime.datetime | int | float | dict):
|
||||
"""
|
||||
日期时间字符串验证
|
||||
如果我传入的是字符串,那么直接返回,如果我传入的是一个日期类型,那么会转为字符串格式后返回
|
||||
因为在 pydantic 2.0 中是支持 int 或 float 自动转换类型的,所以我这里添加进去,但是在处理时会使这两种类型报错
|
||||
|
||||
官方文档:https://docs.pydantic.dev/dev-v2/usage/types/datetime/
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
pattern = "%Y-%m-%d %H:%M:%S"
|
||||
try:
|
||||
datetime.datetime.strptime(value, pattern)
|
||||
return value
|
||||
except ValueError:
|
||||
pass
|
||||
elif isinstance(value, datetime.datetime):
|
||||
return value.strftime("%Y-%m-%d %H:%M:%S")
|
||||
elif isinstance(value, dict):
|
||||
# 用于处理 mongodb 日期时间数据类型
|
||||
date_str = value.get("$date")
|
||||
date_format = '%Y-%m-%dT%H:%M:%S.%fZ'
|
||||
# 将字符串转换为datetime.datetime类型
|
||||
datetime_obj = datetime.datetime.strptime(date_str, date_format)
|
||||
# 将datetime.datetime对象转换为指定的字符串格式
|
||||
return datetime_obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||
raise ValueError("无效的日期时间或字符串数据")
|
||||
|
||||
|
||||
# 实现自定义一个日期时间字符串的数据类型
|
||||
DatetimeStr = Annotated[
|
||||
str | datetime.datetime | int | float | dict,
|
||||
AfterValidator(datetime_str_vali),
|
||||
PlainSerializer(lambda x: x, return_type=str),
|
||||
WithJsonSchema({'type': 'string'}, mode='serialization')
|
||||
]
|
||||
|
||||
|
||||
# 实现自定义一个手机号类型
|
||||
Telephone = Annotated[
|
||||
str,
|
||||
AfterValidator(lambda x: vali_telephone(x)),
|
||||
PlainSerializer(lambda x: x, return_type=str),
|
||||
WithJsonSchema({'type': 'string'}, mode='serialization')
|
||||
]
|
||||
|
||||
|
||||
# 实现自定义一个邮箱类型
|
||||
Email = Annotated[
|
||||
str,
|
||||
AfterValidator(lambda x: vali_email(x)),
|
||||
PlainSerializer(lambda x: x, return_type=str),
|
||||
WithJsonSchema({'type': 'string'}, mode='serialization')
|
||||
]
|
||||
|
||||
|
||||
def date_str_vali(value: str | datetime.date | int | float):
|
||||
"""
|
||||
日期字符串验证
|
||||
如果我传入的是字符串,那么直接返回,如果我传入的是一个日期类型,那么会转为字符串格式后返回
|
||||
因为在 pydantic 2.0 中是支持 int 或 float 自动转换类型的,所以我这里添加进去,但是在处理时会使这两种类型报错
|
||||
|
||||
官方文档:https://docs.pydantic.dev/dev-v2/usage/types/datetime/
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
pattern = "%Y-%m-%d"
|
||||
try:
|
||||
datetime.datetime.strptime(value, pattern)
|
||||
return value
|
||||
except ValueError:
|
||||
pass
|
||||
elif isinstance(value, datetime.date):
|
||||
return value.strftime("%Y-%m-%d")
|
||||
raise ValueError("无效的日期时间或字符串数据")
|
||||
|
||||
|
||||
# 实现自定义一个日期字符串的数据类型
|
||||
DateStr = Annotated[
|
||||
str | datetime.date | int | float,
|
||||
AfterValidator(date_str_vali),
|
||||
PlainSerializer(lambda x: x, return_type=str),
|
||||
WithJsonSchema({'type': 'string'}, mode='serialization')
|
||||
]
|
||||
|
||||
|
||||
def object_id_str_vali(value: str | dict | ObjectId):
|
||||
"""
|
||||
官方文档:https://docs.pydantic.dev/dev-v2/usage/types/datetime/
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return value.get("$oid")
|
||||
elif isinstance(value, ObjectId):
|
||||
return str(value)
|
||||
raise ValueError("无效的 ObjectId 数据类型")
|
||||
|
||||
|
||||
ObjectIdStr = Annotated[
|
||||
Any, # 这里不能直接使用 any,需要使用 typing.Any
|
||||
AfterValidator(object_id_str_vali),
|
||||
PlainSerializer(lambda x: x, return_type=str),
|
||||
WithJsonSchema({'type': 'string'}, mode='serialization')
|
||||
]
|
127
core/database.py
Normal file
127
core/database.py
Normal file
@ -0,0 +1,127 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Update Time : 2023/8/18 9:00
|
||||
# @File : database.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : SQLAlchemy 部分
|
||||
|
||||
"""
|
||||
导入 SQLAlchemy 部分
|
||||
安装: pip install sqlalchemy[asyncio]
|
||||
官方文档:https://docs.sqlalchemy.org/en/20/intro.html#installation
|
||||
"""
|
||||
from typing import AsyncGenerator
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker, AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase, declared_attr
|
||||
from application.settings import SQLALCHEMY_DATABASE_URL, REDIS_DB_ENABLE, MONGO_DB_ENABLE
|
||||
from fastapi import Request
|
||||
from core.exception import CustomException
|
||||
from motor.motor_asyncio import AsyncIOMotorDatabase
|
||||
|
||||
# 官方文档:https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.create_async_engine
|
||||
|
||||
# database_url dialect+driver://username:password@host:port/database
|
||||
|
||||
# echo:如果为True,引擎将记录所有语句以及它们的参数列表的repr()到默认的日志处理程序,该处理程序默认为sys.stdout。如果设置为字符串"debug",
|
||||
# 结果行也将打印到标准输出。Engine的echo属性可以随时修改以打开和关闭日志记录;也可以使用标准的Python logging模块来直接控制日志记录。
|
||||
|
||||
# echo_pool=False:如果为True,连接池将记录信息性输出,如何时使连接失效以及何时将连接回收到默认的日志处理程序,该处理程序默认为sys.stdout。
|
||||
# 如果设置为字符串"debug",记录将包括池的检出和检入。也可以使用标准的Python logging模块来直接控制日志记录。
|
||||
|
||||
# pool_pre_ping:布尔值,如果为True,将启用连接池的"pre-ping"功能,该功能在每次检出时测试连接的活动性。
|
||||
|
||||
# pool_recycle=-1:此设置导致池在给定的秒数后重新使用连接。默认为-1,即没有超时。例如,将其设置为3600意味着在一小时后重新使用连接。
|
||||
# 请注意,特别是MySQL会在检测到连接8小时内没有活动时自动断开连接(尽管可以通过MySQLDB连接自身和服务器配置进行配置)。
|
||||
|
||||
# pool_size=5:在连接池内保持打开的连接数。与QueuePool以及SingletonThreadPool一起使用。
|
||||
# 对于QueuePool,pool_size设置为0表示没有限制;要禁用连接池,请将poolclass设置为NullPool。
|
||||
|
||||
# pool_timeout=30:在从池中获取连接之前等待的秒数。仅在QueuePool中使用。这可以是一个浮点数,但受Python时间函数的限制,可能在几十毫秒内不可靠
|
||||
|
||||
# max_overflow 参数用于配置连接池中允许的连接 "溢出" 数量。这个参数用于在高负载情况下处理连接请求的峰值。
|
||||
# 当连接池的所有连接都在使用中时,如果有新的连接请求到达,连接池可以创建额外的连接来满足这些请求,最多创建的数量由 max_overflow 参数决定。
|
||||
|
||||
# 创建数据库连接
|
||||
async_engine = create_async_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
echo=False,
|
||||
echo_pool=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
pool_size=5,
|
||||
max_overflow=5,
|
||||
connect_args={}
|
||||
)
|
||||
|
||||
# 创建数据库会话
|
||||
session_factory = async_sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=async_engine,
|
||||
expire_on_commit=True,
|
||||
class_=AsyncSession
|
||||
)
|
||||
|
||||
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
"""
|
||||
创建基本映射类
|
||||
稍后,我们将继承该类,创建每个 ORM 模型
|
||||
"""
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(cls) -> str:
|
||||
"""
|
||||
将表名改为小写
|
||||
如果有自定义表名就取自定义,没有就取小写类名
|
||||
"""
|
||||
table_name = cls.__tablename__
|
||||
if not table_name:
|
||||
model_name = cls.__name__
|
||||
ls = []
|
||||
for index, char in enumerate(model_name):
|
||||
if char.isupper() and index != 0:
|
||||
ls.append("_")
|
||||
ls.append(char)
|
||||
table_name = "".join(ls).lower()
|
||||
return table_name
|
||||
|
||||
|
||||
async def db_getter() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
获取主数据库会话
|
||||
|
||||
数据库依赖项,它将在单个请求中使用,然后在请求完成后将其关闭。
|
||||
|
||||
函数的返回类型被注解为 AsyncGenerator[int, None],其中 AsyncSession 是生成的值的类型,而 None 表示异步生成器没有终止条件。
|
||||
"""
|
||||
async with session_factory() as session:
|
||||
# 创建一个新的事务,半自动 commit
|
||||
async with session.begin():
|
||||
yield session
|
||||
|
||||
|
||||
def redis_getter(request: Request) -> Redis:
|
||||
"""
|
||||
获取 redis 数据库对象
|
||||
|
||||
全局挂载,使用一个数据库对象
|
||||
"""
|
||||
if not REDIS_DB_ENABLE:
|
||||
raise CustomException("请先配置Redis数据库链接并启用!", desc="请启用 application/settings.py: REDIS_DB_ENABLE")
|
||||
return request.app.state.redis
|
||||
|
||||
|
||||
def mongo_getter(request: Request) -> AsyncIOMotorDatabase:
|
||||
"""
|
||||
获取 mongo 数据库对象
|
||||
|
||||
全局挂载,使用一个数据库对象
|
||||
"""
|
||||
if not MONGO_DB_ENABLE:
|
||||
raise CustomException(
|
||||
msg="请先开启 MongoDB 数据库连接!",
|
||||
desc="请启用 application/settings.py: MONGO_DB_ENABLE"
|
||||
)
|
||||
return request.app.state.mongo
|
62
core/dependencies.py
Normal file
62
core/dependencies.py
Normal file
@ -0,0 +1,62 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2022/8/8 14:18
|
||||
# @File : dependencies.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 常用依赖项
|
||||
|
||||
"""
|
||||
类依赖项-官方文档:https://fastapi.tiangolo.com/zh/tutorial/dependencies/classes-as-dependencies/
|
||||
"""
|
||||
|
||||
from fastapi import Body
|
||||
import copy
|
||||
|
||||
|
||||
class QueryParams:
|
||||
|
||||
def __init__(self, params=None):
|
||||
if params:
|
||||
self.page = params.page
|
||||
self.limit = params.limit
|
||||
self.v_order = params.v_order
|
||||
self.v_order_field = params.v_order_field
|
||||
|
||||
def dict(self, exclude: list[str] = None) -> dict:
|
||||
result = copy.deepcopy(self.__dict__)
|
||||
if exclude:
|
||||
for item in exclude:
|
||||
try:
|
||||
del result[item]
|
||||
except KeyError:
|
||||
pass
|
||||
return result
|
||||
|
||||
def to_count(self, exclude: list[str] = None) -> dict:
|
||||
params = self.dict(exclude=exclude)
|
||||
del params["page"]
|
||||
del params["limit"]
|
||||
del params["v_order"]
|
||||
del params["v_order_field"]
|
||||
return params
|
||||
|
||||
|
||||
class Paging(QueryParams):
|
||||
"""
|
||||
列表分页
|
||||
"""
|
||||
def __init__(self, page: int = 1, limit: int = 10, v_order_field: str = None, v_order: str = None):
|
||||
super().__init__()
|
||||
self.page = page
|
||||
self.limit = limit
|
||||
self.v_order = v_order
|
||||
self.v_order_field = v_order_field
|
||||
|
||||
|
||||
class IdList:
|
||||
"""
|
||||
id 列表
|
||||
"""
|
||||
def __init__(self, ids: list[int] = Body(..., title="ID 列表")):
|
||||
self.ids = ids
|
44
core/docs.py
Normal file
44
core/docs.py
Normal file
@ -0,0 +1,44 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2023/11/16 16:44
|
||||
# @File : views.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 项目文档
|
||||
|
||||
|
||||
# 自定义接口文档静态文件:https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.docs import (
|
||||
get_redoc_html,
|
||||
get_swagger_ui_html,
|
||||
get_swagger_ui_oauth2_redirect_html,
|
||||
)
|
||||
|
||||
|
||||
def custom_api_docs(app: FastAPI):
|
||||
"""
|
||||
自定义配置接口本地静态文档
|
||||
"""
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
async def custom_swagger_ui_html():
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=app.openapi_url,
|
||||
title=app.title + " - Swagger UI",
|
||||
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
|
||||
swagger_js_url="/media/swagger_ui/swagger-ui-bundle.js",
|
||||
swagger_css_url="/media/swagger_ui/swagger-ui.css",
|
||||
)
|
||||
|
||||
@app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False)
|
||||
async def swagger_ui_redirect():
|
||||
return get_swagger_ui_oauth2_redirect_html()
|
||||
|
||||
@app.get("/redoc", include_in_schema=False)
|
||||
async def custom_redoc_html():
|
||||
return get_redoc_html(
|
||||
openapi_url=app.openapi_url,
|
||||
title=app.title + " - ReDoc",
|
||||
redoc_js_url="/media/redoc_ui/redoc.standalone.js",
|
||||
)
|
27
core/enum.py
Normal file
27
core/enum.py
Normal file
@ -0,0 +1,27 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2023/02/12 22:18
|
||||
# @File : enum.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 增加枚举类方法
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SuperEnum(Enum):
|
||||
|
||||
@classmethod
|
||||
def to_dict(cls):
|
||||
"""Returns a dictionary representation of the enum."""
|
||||
return {e.name: e.value for e in cls}
|
||||
|
||||
@classmethod
|
||||
def keys(cls):
|
||||
"""Returns a list of all the enum keys."""
|
||||
return cls._member_names_
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
"""Returns a list of all the enum values."""
|
||||
return list(cls._value2member_map_.keys())
|
121
core/event.py
Normal file
121
core/event.py
Normal file
@ -0,0 +1,121 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2022/3/21 11:03
|
||||
# @File : event.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 全局事件
|
||||
|
||||
|
||||
from fastapi import FastAPI
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from application.settings import REDIS_DB_URL, MONGO_DB_URL, MONGO_DB_NAME, EVENTS
|
||||
from utils.cache import Cache
|
||||
from redis import asyncio as aioredis
|
||||
from redis.exceptions import AuthenticationError, TimeoutError, RedisError
|
||||
from contextlib import asynccontextmanager
|
||||
from utils.tools import import_modules_async
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from core.logger import logger
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
await import_modules_async(EVENTS, "全局事件", app=app, status=True)
|
||||
|
||||
yield
|
||||
|
||||
await import_modules_async(EVENTS, "全局事件", app=app, status=False)
|
||||
|
||||
|
||||
async def connect_redis(app: FastAPI, status: bool):
|
||||
"""
|
||||
把 redis 挂载到 app 对象上面
|
||||
|
||||
博客:https://blog.csdn.net/wgPython/article/details/107668521
|
||||
博客:https://www.cnblogs.com/emunshe/p/15761597.html
|
||||
官网:https://aioredis.readthedocs.io/en/latest/getting-started/
|
||||
Github: https://github.com/aio-libs/aioredis-py
|
||||
|
||||
aioredis.from_url(url, *, encoding=None, parser=None, decode_responses=False, db=None, password=None, ssl=None,
|
||||
connection_cls=None, loop=None, **kwargs) 方法是 aioredis 库中用于从 Redis 连接 URL 创建 Redis 连接对象的方法。
|
||||
|
||||
以下是该方法的参数说明:
|
||||
url:Redis 连接 URL。例如 redis://localhost:6379/0。
|
||||
encoding:可选参数,Redis 编码格式。默认为 utf-8。
|
||||
parser:可选参数,Redis 数据解析器。默认为 None,表示使用默认解析器。
|
||||
decode_responses:可选参数,是否将 Redis 响应解码为 Python 字符串。默认为 False。
|
||||
db:可选参数,Redis 数据库编号。默认为 None。
|
||||
password:可选参数,Redis 认证密码。默认为 None,表示无需认证。
|
||||
ssl:可选参数,是否使用 SSL/TLS 加密连接。默认为 None。
|
||||
connection_cls:可选参数,Redis 连接类。默认为 None,表示使用默认连接类。
|
||||
loop:可选参数,用于创建连接对象的事件循环。默认为 None,表示使用默认事件循环。
|
||||
**kwargs:可选参数,其他连接参数,用于传递给 Redis 连接类的构造函数。
|
||||
|
||||
aioredis.from_url() 方法的主要作用是将 Redis 连接 URL 转换为 Redis 连接对象。
|
||||
除了 URL 参数外,其他参数用于指定 Redis 连接的各种选项,例如 Redis 数据库编号、密码、SSL/TLS 加密等等。可以根据需要选择使用这些选项。
|
||||
|
||||
health_check_interval 是 aioredis.from_url() 方法中的一个可选参数,用于设置 Redis 连接的健康检查间隔时间。
|
||||
健康检查是指在 Redis 连接池中使用的连接对象会定期向 Redis 服务器发送 PING 命令来检查连接是否仍然有效。
|
||||
该参数的默认值是 0,表示不进行健康检查。如果需要启用健康检查,则可以将该参数设置为一个正整数,表示检查间隔的秒数。
|
||||
例如,如果需要每隔 5 秒对 Redis 连接进行一次健康检查,则可以将 health_check_interval 设置为 5
|
||||
:param app:
|
||||
:param status:
|
||||
:return:
|
||||
"""
|
||||
if status:
|
||||
rd = aioredis.from_url(REDIS_DB_URL, decode_responses=True, health_check_interval=1)
|
||||
app.state.redis = rd
|
||||
try:
|
||||
response = await rd.ping()
|
||||
if response:
|
||||
print("Redis 连接成功")
|
||||
else:
|
||||
print("Redis 连接失败")
|
||||
except AuthenticationError as e:
|
||||
raise AuthenticationError(f"Redis 连接认证失败,用户名或密码错误: {e}")
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError(f"Redis 连接超时,地址或者端口错误: {e}")
|
||||
except RedisError as e:
|
||||
raise RedisError(f"Redis 连接失败: {e}")
|
||||
try:
|
||||
await Cache(app.state.redis).cache_tab_names()
|
||||
except ProgrammingError as e:
|
||||
logger.error(f"sqlalchemy.exc.ProgrammingError: {e}")
|
||||
print(f"sqlalchemy.exc.ProgrammingError: {e}")
|
||||
else:
|
||||
print("Redis 连接关闭")
|
||||
await app.state.redis.close()
|
||||
|
||||
|
||||
async def connect_mongo(app: FastAPI, status: bool):
|
||||
"""
|
||||
把 mongo 挂载到 app 对象上面
|
||||
|
||||
博客:https://www.cnblogs.com/aduner/p/13532504.html
|
||||
mongodb 官网:https://www.mongodb.com/docs/drivers/motor/
|
||||
motor 文档:https://motor.readthedocs.io/en/stable/
|
||||
:param app:
|
||||
:param status:
|
||||
:return:
|
||||
"""
|
||||
if status:
|
||||
client: AsyncIOMotorClient = AsyncIOMotorClient(
|
||||
MONGO_DB_URL,
|
||||
maxPoolSize=10,
|
||||
minPoolSize=10,
|
||||
serverSelectionTimeoutMS=5000
|
||||
)
|
||||
app.state.mongo_client = client
|
||||
app.state.mongo = client[MONGO_DB_NAME]
|
||||
# 尝试连接并捕获可能的超时异常
|
||||
try:
|
||||
# 触发一次服务器通信来确认连接
|
||||
data = await client.server_info()
|
||||
print("MongoDB 连接成功", data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"MongoDB 连接失败: {e}")
|
||||
else:
|
||||
print("MongoDB 连接关闭")
|
||||
app.state.mongo_client.close()
|
149
core/exception.py
Normal file
149
core/exception.py
Normal file
@ -0,0 +1,149 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2021/10/19 15:47
|
||||
# @File : exception.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 全局异常处理
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette import status
|
||||
from fastapi import Request
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import FastAPI
|
||||
from core.logger import logger
|
||||
from application.settings import DEBUG
|
||||
|
||||
|
||||
class CustomException(Exception):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
msg: str,
|
||||
code: int = status.HTTP_400_BAD_REQUEST,
|
||||
status_code: int = status.HTTP_200_OK,
|
||||
desc: str = None
|
||||
):
|
||||
self.msg = msg
|
||||
self.code = code
|
||||
self.status_code = status_code
|
||||
self.desc = desc
|
||||
|
||||
|
||||
def register_exception(app: FastAPI):
|
||||
"""
|
||||
异常捕捉
|
||||
"""
|
||||
|
||||
@app.exception_handler(CustomException)
|
||||
async def custom_exception_handler(request: Request, exc: CustomException):
|
||||
"""
|
||||
自定义异常
|
||||
"""
|
||||
if DEBUG:
|
||||
print("请求地址", request.url.__str__())
|
||||
print("捕捉到重写CustomException异常异常:custom_exception_handler")
|
||||
print(exc.desc)
|
||||
print(exc.msg)
|
||||
# 打印栈信息,方便追踪排查异常
|
||||
logger.exception(exc)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"message": exc.msg, "code": exc.code},
|
||||
)
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def unicorn_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
"""
|
||||
重写HTTPException异常处理器
|
||||
"""
|
||||
if DEBUG:
|
||||
print("请求地址", request.url.__str__())
|
||||
print("捕捉到重写HTTPException异常异常:unicorn_exception_handler")
|
||||
print(exc.detail)
|
||||
# 打印栈信息,方便追踪排查异常
|
||||
logger.exception(exc)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"code": exc.status_code,
|
||||
"message": exc.detail,
|
||||
}
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""
|
||||
重写请求验证异常处理器
|
||||
"""
|
||||
if DEBUG:
|
||||
print("请求地址", request.url.__str__())
|
||||
print("捕捉到重写请求验证异常异常:validation_exception_handler")
|
||||
print(exc.errors())
|
||||
# 打印栈信息,方便追踪排查异常
|
||||
logger.exception(exc)
|
||||
msg = exc.errors()[0].get("msg")
|
||||
if msg == "field required":
|
||||
msg = "请求失败,缺少必填项!"
|
||||
elif msg == "value is not a valid list":
|
||||
print(exc.errors())
|
||||
msg = f"类型错误,提交参数应该为列表!"
|
||||
elif msg == "value is not a valid int":
|
||||
msg = f"类型错误,提交参数应该为整数!"
|
||||
elif msg == "value could not be parsed to a boolean":
|
||||
msg = f"类型错误,提交参数应该为布尔值!"
|
||||
elif msg == "Input should be a valid list":
|
||||
msg = f"类型错误,输入应该是一个有效的列表!"
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content=jsonable_encoder(
|
||||
{
|
||||
"message": msg,
|
||||
"body": exc.body,
|
||||
"code": status.HTTP_400_BAD_REQUEST
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def value_exception_handler(request: Request, exc: ValueError):
|
||||
"""
|
||||
捕获值异常
|
||||
"""
|
||||
if DEBUG:
|
||||
print("请求地址", request.url.__str__())
|
||||
print("捕捉到值异常:value_exception_handler")
|
||||
print(exc.__str__())
|
||||
# 打印栈信息,方便追踪排查异常
|
||||
logger.exception(exc)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content=jsonable_encoder(
|
||||
{
|
||||
"message": exc.__str__(),
|
||||
"code": status.HTTP_400_BAD_REQUEST
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def all_exception_handler(request: Request, exc: Exception):
|
||||
"""
|
||||
捕获全部异常
|
||||
"""
|
||||
if DEBUG:
|
||||
print("请求地址", request.url.__str__())
|
||||
print("捕捉到全局异常:all_exception_handler")
|
||||
print(exc.__str__())
|
||||
# 打印栈信息,方便追踪排查异常
|
||||
logger.exception(exc)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=jsonable_encoder(
|
||||
{
|
||||
"message": "接口异常!",
|
||||
"code": status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
}
|
||||
),
|
||||
)
|
22
core/logger.py
Normal file
22
core/logger.py
Normal file
@ -0,0 +1,22 @@
|
||||
import os
|
||||
import time
|
||||
from loguru import logger
|
||||
from application.settings import BASE_DIR
|
||||
|
||||
"""
|
||||
# 日志简单配置
|
||||
# 具体其他配置 可自行参考 https://github.com/Delgan/loguru
|
||||
"""
|
||||
|
||||
# 移除控制台输出
|
||||
logger.remove(handler_id=None)
|
||||
|
||||
log_path = os.path.join(BASE_DIR, 'logs')
|
||||
if not os.path.exists(log_path):
|
||||
os.mkdir(log_path)
|
||||
|
||||
log_path_info = os.path.join(log_path, f'info_{time.strftime("%Y-%m-%d")}.log')
|
||||
log_path_error = os.path.join(log_path, f'error_{time.strftime("%Y-%m-%d")}.log')
|
||||
|
||||
info = logger.add(log_path_info, rotation="00:00", retention="3 days", enqueue=True, encoding="UTF-8", level="INFO")
|
||||
error = logger.add(log_path_error, rotation="00:00", retention="3 days", enqueue=True, encoding="UTF-8", level="ERROR")
|
157
core/middleware.py
Normal file
157
core/middleware.py
Normal file
@ -0,0 +1,157 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2021/10/19 15:47
|
||||
# @File : middleware.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : 中间件
|
||||
|
||||
"""
|
||||
官方文档——中间件:https://fastapi.tiangolo.com/tutorial/middleware/
|
||||
官方文档——高级中间件:https://fastapi.tiangolo.com/advanced/middleware/
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
from fastapi import Request, Response
|
||||
from core.logger import logger
|
||||
from fastapi import FastAPI
|
||||
from fastapi.routing import APIRoute
|
||||
from user_agents import parse
|
||||
from application.settings import OPERATION_RECORD_METHOD, MONGO_DB_ENABLE, IGNORE_OPERATION_FUNCTION, \
|
||||
DEMO_WHITE_LIST_PATH, DEMO, DEMO_BLACK_LIST_PATH
|
||||
from utils.response import ErrorResponse
|
||||
from apps.vadmin.record.crud import OperationRecordDal
|
||||
from core.database import mongo_getter
|
||||
from utils import status
|
||||
|
||||
|
||||
def write_request_log(request: Request, response: Response):
|
||||
http_version = f"http/{request.scope['http_version']}"
|
||||
content_length = response.raw_headers[0][1]
|
||||
process_time = response.headers["X-Process-Time"]
|
||||
content = f"basehttp.log_message: '{request.method} {request.url} {http_version}' {response.status_code}" \
|
||||
f"{response.charset} {content_length} {process_time}"
|
||||
logger.info(content)
|
||||
|
||||
|
||||
def register_request_log_middleware(app: FastAPI):
|
||||
"""
|
||||
记录请求日志中间件
|
||||
:param app:
|
||||
:return:
|
||||
"""
|
||||
|
||||
@app.middleware("http")
|
||||
async def request_log_middleware(request: Request, call_next):
|
||||
start_time = time.time()
|
||||
response = await call_next(request)
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
write_request_log(request, response)
|
||||
return response
|
||||
|
||||
|
||||
def register_operation_record_middleware(app: FastAPI):
|
||||
"""
|
||||
操作记录中间件
|
||||
用于将使用认证的操作全部记录到 mongodb 数据库中
|
||||
:param app:
|
||||
:return:
|
||||
"""
|
||||
|
||||
@app.middleware("http")
|
||||
async def operation_record_middleware(request: Request, call_next):
|
||||
start_time = time.time()
|
||||
response = await call_next(request)
|
||||
if not MONGO_DB_ENABLE:
|
||||
return response
|
||||
telephone = request.scope.get('telephone', None)
|
||||
user_id = request.scope.get('user_id', None)
|
||||
user_name = request.scope.get('user_name', None)
|
||||
route = request.scope.get('route')
|
||||
if not telephone:
|
||||
return response
|
||||
elif request.method not in OPERATION_RECORD_METHOD:
|
||||
return response
|
||||
elif route.name in IGNORE_OPERATION_FUNCTION:
|
||||
return response
|
||||
process_time = time.time() - start_time
|
||||
user_agent = parse(request.headers.get("user-agent"))
|
||||
system = f"{user_agent.os.family} {user_agent.os.version_string}"
|
||||
browser = f"{user_agent.browser.family} {user_agent.browser.version_string}"
|
||||
query_params = dict(request.query_params.multi_items())
|
||||
path_params = request.path_params
|
||||
if isinstance(request.scope.get('body'), str):
|
||||
body = request.scope.get('body')
|
||||
else:
|
||||
body = request.scope.get('body').decode()
|
||||
if body:
|
||||
body = json.loads(body)
|
||||
params = {
|
||||
"body": body,
|
||||
"query_params": query_params if query_params else None,
|
||||
"path_params": path_params if path_params else None,
|
||||
}
|
||||
content_length = response.raw_headers[0][1]
|
||||
assert isinstance(route, APIRoute)
|
||||
document = {
|
||||
"process_time": process_time,
|
||||
"telephone": telephone,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"request_api": request.url.__str__(),
|
||||
"client_ip": request.client.host,
|
||||
"system": system,
|
||||
"browser": browser,
|
||||
"request_method": request.method,
|
||||
"api_path": route.path,
|
||||
"summary": route.summary,
|
||||
"description": route.description,
|
||||
"tags": route.tags,
|
||||
"route_name": route.name,
|
||||
"status_code": response.status_code,
|
||||
"content_length": content_length,
|
||||
"create_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"params": json.dumps(params)
|
||||
}
|
||||
await OperationRecordDal(mongo_getter(request)).create_data(document)
|
||||
return response
|
||||
|
||||
|
||||
def register_demo_env_middleware(app: FastAPI):
|
||||
"""
|
||||
演示环境中间件
|
||||
:param app:
|
||||
:return:
|
||||
"""
|
||||
|
||||
@app.middleware("http")
|
||||
async def demo_env_middleware(request: Request, call_next):
|
||||
path = request.scope.get("path")
|
||||
if request.method != "GET":
|
||||
print("路由:", path, request.method)
|
||||
if DEMO and request.method != "GET":
|
||||
if path in DEMO_BLACK_LIST_PATH:
|
||||
return ErrorResponse(
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
msg="演示环境,禁止操作"
|
||||
)
|
||||
elif path not in DEMO_WHITE_LIST_PATH:
|
||||
return ErrorResponse(msg="演示环境,禁止操作")
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def register_jwt_refresh_middleware(app: FastAPI):
|
||||
"""
|
||||
JWT刷新中间件
|
||||
:param app:
|
||||
:return:
|
||||
"""
|
||||
|
||||
@app.middleware("http")
|
||||
async def jwt_refresh_middleware(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
refresh = request.scope.get('if-refresh', 0)
|
||||
response.headers["if-refresh"] = str(refresh)
|
||||
return response
|
177
core/mongo_manage.py
Normal file
177
core/mongo_manage.py
Normal file
@ -0,0 +1,177 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any
|
||||
from bson import ObjectId
|
||||
from bson.errors import InvalidId
|
||||
from bson.json_util import dumps
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from motor.motor_asyncio import AsyncIOMotorDatabase
|
||||
from pymongo.results import InsertOneResult, UpdateResult
|
||||
from core.exception import CustomException
|
||||
from utils import status
|
||||
|
||||
|
||||
class MongoManage:
|
||||
"""
|
||||
mongodb 数据库管理器
|
||||
博客:https://www.cnblogs.com/aduner/p/13532504.html
|
||||
mongodb 官网:https://www.mongodb.com/docs/drivers/motor/
|
||||
motor 文档:https://motor.readthedocs.io/en/stable/
|
||||
"""
|
||||
|
||||
# 倒叙
|
||||
ORDER_FIELD = ["desc", "descending"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: AsyncIOMotorDatabase = None,
|
||||
collection: str = None,
|
||||
schema: Any = None,
|
||||
is_object_id: bool = True
|
||||
):
|
||||
"""
|
||||
初始化
|
||||
:param db:
|
||||
:param collection: 集合
|
||||
:param schema:
|
||||
:param is_object_id: _id 列是否为 ObjectId 格式
|
||||
"""
|
||||
self.db = db
|
||||
self.collection = db[collection] if collection else None
|
||||
self.schema = schema
|
||||
self.is_object_id = is_object_id
|
||||
|
||||
async def get_data(
|
||||
self,
|
||||
_id: str = None,
|
||||
v_return_none: bool = False,
|
||||
v_schema: Any = None,
|
||||
**kwargs
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取单个数据,默认使用 ID 查询,否则使用关键词查询
|
||||
:param _id: 数据 ID
|
||||
:param v_return_none: 是否返回空 None,否则抛出异常,默认抛出异常
|
||||
:param v_schema: 指定使用的序列化对象
|
||||
"""
|
||||
if _id and self.is_object_id:
|
||||
kwargs["_id"] = ObjectId(_id)
|
||||
params = self.filter_condition(**kwargs)
|
||||
data = await self.collection.find_one(params)
|
||||
if not data and v_return_none:
|
||||
return None
|
||||
elif not data:
|
||||
raise CustomException("查找失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
|
||||
elif data and v_schema:
|
||||
return jsonable_encoder(v_schema(**data))
|
||||
return data
|
||||
|
||||
async def create_data(self, data: dict | Any) -> InsertOneResult:
|
||||
"""
|
||||
创建数据
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
data = jsonable_encoder(data)
|
||||
data['create_datetime'] = datetime.datetime.now()
|
||||
data['update_datetime'] = datetime.datetime.now()
|
||||
result = await self.collection.insert_one(data)
|
||||
# 判断插入是否成功
|
||||
if result.acknowledged:
|
||||
return result
|
||||
else:
|
||||
raise CustomException("创建新数据失败", code=status.HTTP_ERROR)
|
||||
|
||||
async def put_data(self, _id: str, data: dict | Any) -> UpdateResult:
|
||||
"""
|
||||
更新数据
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
data = jsonable_encoder(data)
|
||||
new_data = {'$set': data}
|
||||
result = await self.collection.update_one({'_id': ObjectId(_id) if self.is_object_id else _id}, new_data)
|
||||
|
||||
if result.matched_count > 0:
|
||||
return result
|
||||
else:
|
||||
raise CustomException("更新失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
async def delete_data(self, _id: str):
|
||||
"""
|
||||
删除数据
|
||||
"""
|
||||
result = await self.collection.delete_one({'_id': ObjectId(_id) if self.is_object_id else _id})
|
||||
|
||||
if result.deleted_count > 0:
|
||||
return True
|
||||
else:
|
||||
raise CustomException("删除失败,未查找到对应数据", code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
async def get_datas(
|
||||
self,
|
||||
page: int = 1,
|
||||
limit: int = 10,
|
||||
v_schema: Any = None,
|
||||
v_order: str = None,
|
||||
v_order_field: str = None,
|
||||
v_return_objs: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
使用 find() 要查询的一组文档。 find() 没有I / O,也不需要 await 表达式。它只是创建一个 AsyncIOMotorCursor 实例
|
||||
当您调用 to_list() 或为循环执行异步时 (async for) ,查询实际上是在服务器上执行的。
|
||||
"""
|
||||
|
||||
params = self.filter_condition(**kwargs)
|
||||
cursor = self.collection.find(params)
|
||||
|
||||
if v_order or v_order_field:
|
||||
v_order_field = v_order_field if v_order_field else 'create_datetime'
|
||||
v_order = -1 if v_order in self.ORDER_FIELD else 1
|
||||
cursor.sort(v_order_field, v_order)
|
||||
|
||||
if limit != 0:
|
||||
# 对查询应用排序(sort),跳过(skip)或限制(limit)
|
||||
cursor.skip((page - 1) * limit).limit(limit)
|
||||
|
||||
datas = []
|
||||
async for row in cursor:
|
||||
data = json.loads(dumps(row))
|
||||
datas.append(data)
|
||||
|
||||
if not datas or v_return_objs:
|
||||
return datas
|
||||
elif v_schema:
|
||||
datas = [jsonable_encoder(v_schema(**data)) for data in datas]
|
||||
elif self.schema:
|
||||
datas = [jsonable_encoder(self.schema(**data)) for data in datas]
|
||||
return datas
|
||||
|
||||
async def get_count(self, **kwargs) -> int:
|
||||
"""
|
||||
获取统计数据
|
||||
"""
|
||||
params = self.filter_condition(**kwargs)
|
||||
return await self.collection.count_documents(params)
|
||||
|
||||
@classmethod
|
||||
def filter_condition(cls, **kwargs):
|
||||
"""
|
||||
过滤条件
|
||||
"""
|
||||
params = {}
|
||||
for k, v in kwargs.items():
|
||||
if not v:
|
||||
continue
|
||||
elif isinstance(v, tuple):
|
||||
if v[0] == "like" and v[1]:
|
||||
params[k] = {'$regex': v[1]}
|
||||
elif v[0] == "between" and len(v[1]) == 2:
|
||||
params[k] = {'$gte': f"{v[1][0]} 00:00:00", '$lt': f"{v[1][1]} 23:59:59"}
|
||||
elif v[0] == "ObjectId" and v[1]:
|
||||
try:
|
||||
params[k] = ObjectId(v[1])
|
||||
except InvalidId:
|
||||
raise CustomException("任务编号格式不正确!")
|
||||
else:
|
||||
params[k] = v
|
||||
return params
|
51
core/validator.py
Normal file
51
core/validator.py
Normal file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @version : 1.0
|
||||
# @Create Time : 2021/10/18 22:19
|
||||
# @File : validator.py
|
||||
# @IDE : PyCharm
|
||||
# @desc : pydantic 模型重用验证器
|
||||
|
||||
"""
|
||||
官方文档:https://pydantic-docs.helpmanual.io/usage/validators/#reuse-validators
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def vali_telephone(value: str) -> str:
|
||||
"""
|
||||
手机号验证器
|
||||
:param value: 手机号
|
||||
:return: 手机号
|
||||
"""
|
||||
if not value or len(value) != 11 or not value.isdigit():
|
||||
raise ValueError("请输入正确手机号")
|
||||
|
||||
regex = r'^1(3\d|4[4-9]|5[0-35-9]|6[67]|7[013-8]|8[0-9]|9[0-9])\d{8}$'
|
||||
|
||||
if not re.match(regex, value):
|
||||
raise ValueError("请输入正确手机号")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def vali_email(value: str) -> str:
|
||||
"""
|
||||
邮箱地址验证器
|
||||
:param value: 邮箱
|
||||
:return: 邮箱
|
||||
"""
|
||||
if not value:
|
||||
raise ValueError("请输入邮箱地址")
|
||||
|
||||
regex = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||||
|
||||
if not re.match(regex, value):
|
||||
raise ValueError("请输入正确邮箱地址")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user