RODY/app/core/common_utils.py
552068321@qq.com 6f7de660aa first commit
2022-11-04 17:37:08 +08:00

320 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import orjson
try:
from collections import Mapping
except: # noqa E722
from collections.abc import Mapping
import inspect
import importlib
import os
import re
import sys
from typing import Any, Dict, List, Optional
from unicodedata import normalize
from distutils.version import LooseVersion
import logging
import datetime
from flask import request
from flask_sqlalchemy import model
from sqlalchemy import UniqueConstraint
import marshmallow
from marshmallow import Schema
from .webargs import use_kwargs as base_use_kwargs, parser
from flask.json import JSONEncoder
logger = logging.getLogger(__name__)
class ParamsDict(dict):
    """Dict whose ``update`` produces a new mapping rather than mutating.

    Example::

        @use_kwargs(PageParams.update({...}))
        def list_users(page, page_size, order_by):
            pass
    """

    def update(self, other=None):
        """Return a new ``ParamsDict`` with this dict's items plus *other*'s.

        The receiver itself is not modified. *other* may be a ``Mapping`` or
        an iterable of ``(key, value)`` pairs; ``None`` yields a plain copy.
        """
        merged = ParamsDict(self.copy())
        if other is None:
            return merged

        pairs = other.items() if isinstance(other, Mapping) else other
        for key, value in pairs:
            merged[key] = value
        return merged
# Function version
def row2dict(row):
    """Map every column name of ``row.__table__`` to ``str()`` of its value.

    Values are read from *row* with ``getattr`` and always stringified, so
    ``None`` becomes the text ``'None'``.
    """
    result = {}
    for column in row.__table__.columns:
        key = column.name
        result[key] = str(getattr(row, key))
    return result
class dict2object(dict):
    """Dict subclass whose keys can also be read and written as attributes."""

    def __getattr__(self, name: str) -> Any:
        """Return the value stored under key *name*.

        Raises:
            AttributeError: when *name* is not a key of the dict.
        """
        if name not in self:
            raise AttributeError('object has no attribute {}'.format(name))
        return self[name]

    def __setattr__(self, name: str, value: Any) -> None:
        """Store *value* under key *name*.

        Raises:
            TypeError: when *name* is not a ``str``.
        """
        if isinstance(name, str):
            self[name] = value
            return
        raise TypeError('key must be string type.')
def secure_filename(filename: str) -> str:
    """Return a sanitised version of *filename* (after werkzeug's helper).

    Path separators become whitespace, whitespace runs collapse to ``_``,
    the text is NFKD-normalised, and every character outside ASCII letters,
    digits, ``_ . -`` and the CJK range U+4E00..U+9FA5 is dropped. Leading
    and trailing dots/underscores are stripped. On Windows (``os.name ==
    'nt'``) a result whose stem is a reserved device name gets a leading
    underscore. The result may be an empty string.

    >>> secure_filename('哈哈.zip')
    '哈哈.zip'
    >>> secure_filename('My cool movie.mov')
    'My_cool_movie.mov'
    >>> secure_filename('../../../etc/passwd')
    'etc_passwd'
    >>> secure_filename('i contain cool ümläuts.txt')
    'i_contain_cool_umlauts.txt'
    """
    separators = [sep for sep in (os.path.sep, os.path.altsep) if sep]
    for separator in separators:
        filename = filename.replace(separator, ' ')

    collapsed = '_'.join(filename.split())
    normalized = normalize('NFKD', collapsed)
    cleaned = re.sub(
        '[^A-Za-z0-9\u4e00-\u9fa5_.-]', '', normalized).strip('._')

    # Only Windows treats these stems as device files; elsewhere (or for an
    # empty result) there is nothing left to adjust.
    if os.name != 'nt' or not cleaned:
        return cleaned

    reserved_stems = (
        'CON', 'AUX', 'COM1', 'COM2', 'COM3', 'COM4', 'LPT1',
        'LPT2', 'LPT3', 'PRN', 'NUL',
    )
    stem = cleaned.split('.')[0].upper()
    if stem in reserved_stems:
        return '_' + cleaned
    return cleaned
def _get_init_args(instance, base_class):
"""Get instance's __init__ args and it's value when __call__.
"""
getargspec = inspect.getfullargspec
argspec = getargspec(base_class.__init__)
defaults = argspec.defaults
kwargs = {}
if defaults is not None:
no_defaults = argspec.args[:-len(defaults)]
has_defaults = argspec.args[-len(defaults):]
kwargs = {k: getattr(instance, k) for k in no_defaults
if k != 'self' and hasattr(instance, k)}
kwargs.update({k: getattr(instance, k) if hasattr(instance, k) else
getattr(instance, k, defaults[i])
for i, k in enumerate(has_defaults)})
assert len(kwargs) == len(argspec.args) - 1, 'exclude `self`'
return kwargs
def use_kwargs(argmap, schema_kwargs: Optional[Dict] = None, **kwargs: Any):
    """Wrap the base ``use_kwargs`` so ``Schema(partial=True)`` is honoured.

    When neither the resolved schema nor *schema_kwargs* asks for
    ``partial``, this simply delegates to the base ``use_kwargs``. Otherwise
    a per-request factory is handed over instead: it parses the request
    once, then rebuilds the schema class restricted (``only``) to the keys
    that parse produced, with ``partial=False``.

    Args:
        argmap (marshmallow.Schema,dict,callable): Either a
            `marshmallow.Schema`, `dict` of argname ->
            `marshmallow.fields.Field` pairs, or a callable that returns a
            `marshmallow.Schema` instance.
        schema_kwargs (dict): Extra constructor kwargs for the rebuilt
            schema; they override the values read from the resolved schema.
        **kwargs: Passed through unchanged to the base ``use_kwargs``.

    Returns:
        Whatever the base ``use_kwargs`` returns for the schema or factory.
    """
    overrides = schema_kwargs or {}
    schema = parser._get_schema(argmap, request)

    wants_partial = schema.partial or overrides.get('partial')
    if not wants_partial:
        return base_use_kwargs(schema, **kwargs)

    def factory(request):
        init_kwargs = _get_init_args(schema, Schema)
        init_kwargs.update(overrides)

        # Keys present in this request decide which fields the rebuilt
        # schema keeps; an empty result means "no restriction" (None).
        supplied = parser.parse(schema, request).keys()
        init_kwargs.update(
            partial=False,  # fix missing=None not work
            only=supplied or None,
            context={"request": request},
        )

        major = tuple(LooseVersion(marshmallow.__version__).version)[0]
        if major < 3:
            init_kwargs['strict'] = True
        return schema.__class__(**init_kwargs)

    return base_use_kwargs(factory, **kwargs)
def import_subs(locals_, modules_only: bool = False) -> List[str]:
    """Import every submodule of a package; meant for ``__init__.py``.

    Each ``.py``/``.pyc`` file (other than ``__init__.*``) in the first
    entry of the package's ``__path__`` is imported. Unless *modules_only*
    is set, selected objects are also re-exported on the package module:

    * if the submodule defines ``__all__``, exactly those names;
    * otherwise objects that are instances of ``model.DefaultMeta`` or
      ``Schema``, classes deriving from ``Schema``, and classes whose name
      ends with ``Service``.

    Args:
        locals_: ``locals()`` of the package's ``__init__`` (must contain
            ``__package__`` and ``__path__``).
        modules_only: Only collect module names.

    Returns:
        Names suitable for ``__all__``: module names plus re-exported names.

    Raises:
        Exception: if a submodule's ``__all__`` holds a non-string entry.

    Examples::

        # app/models/__init__.py
        from hobbit_core.utils import import_subs
        __all__ = import_subs(locals())
    """
    package = locals_['__package__']
    search_path = locals_['__path__']
    parent = sys.modules[package]

    def should_export(obj):
        if isinstance(obj, (model.DefaultMeta, Schema)):
            return True
        if not inspect.isclass(obj):
            return False
        return issubclass(obj, Schema) or obj.__name__.endswith('Service')

    exported = []
    for filename in os.listdir(search_path[0]):
        is_module_file = filename.endswith(('.py', '.pyc'))
        if filename.startswith('__init__.') or not is_module_file:
            continue

        module_name = filename.split('.')[0]
        submodule = importlib.import_module(f".{module_name}", package)
        exported.append(module_name)
        if modules_only:
            continue

        if not hasattr(submodule, '__all__'):
            # No explicit export list: pick objects by type/name convention.
            for attr, obj in submodule.__dict__.items():
                if should_export(obj):
                    setattr(parent, attr, obj)
                    exported.append(attr)
            continue

        for attr in submodule.__all__:
            if not isinstance(attr, str):
                raise Exception(f'Invalid object {attr} in __all__, '
                                f'must contain only strings.')
            setattr(parent, attr, getattr(submodule, attr))
            exported.append(attr)
    return exported
def _build_upsert_sql(engine_name, table_name, fields, unique_keys):
    """Render the dialect-specific single-row upsert statement."""
    placeholders = ", ".join(f':{key}' for key in fields)
    updatable = [field for field in fields if field not in unique_keys]

    if engine_name == 'postgresql':
        column_list = ", ".join(fields)
        conflict_target = ", ".join(unique_keys)
        sql_on_update = ', '.join(
            f' {field} = excluded.{field}' for field in updatable)
        return f"""INSERT INTO {table_name} ({column_list}) VALUES
            ({placeholders})
            ON CONFLICT ({conflict_target}) DO UPDATE SET
            {sql_on_update}"""

    if engine_name == 'mysql':
        column_list = "`, `".join(fields)
        # Assignments are separated by a plain comma; each identifier carries
        # its own backticks. (Joining with '`, `' produced stray backticks
        # and invalid SQL as soon as two or more columns were updated.)
        sql_on_update = ', '.join(
            f'`{field}` = new.`{field}`' for field in updatable)
        return f"""INSERT INTO {table_name} (`{column_list}`) VALUES
            ({placeholders}) AS new
            ON DUPLICATE KEY UPDATE
            {sql_on_update}"""

    raise Exception(f'not support db: {engine_name}')


def bulk_create_or_update_on_duplicate(
        db, model_cls, items, updated_at='updated_at', batch_size=500):
    """Insert rows in batches, updating them on unique-key conflict.

    Supports MySQL and PostgreSQL.
    https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html

    The column list is taken from the first item. The dicts in *items* are
    modified in place: when the model has an *updated_at* column that the
    items lack, it is stamped with the current time, and ``None`` values in
    unique-constraint columns are replaced by ``''``.

    Args:
        db: Instance of `SQLAlchemy`.
        model_cls: Model object; its ``__table_args__`` are scanned for
            ``UniqueConstraint`` entries to find the conflict columns.
        items: List of data, example: `[{key: value}, {key: value}, ...]`.
        updated_at: Field which recording row update time.
        batch_size: Max rows per execute; must be at least 1.

    Returns:
        dict: A dictionary contains rowcount and items_count.

    Raises:
        ValueError: if *batch_size* is smaller than 1 (such a value used to
            loop forever on empty batches).
        AssertionError: if the item fields do not match the model columns
            (all columns except ``id`` and ``created_at``).
        Exception: if the engine is neither postgresql nor mysql; execution
            errors are logged and re-raised.

    NOTE(review): table and column names are interpolated into the SQL text
    directly — presumably they always come from the model, not from user
    input; confirm against callers.
    """
    if not items:
        logger.warning("bulk_create_or_update_on_duplicate save to "
                       f"{model_cls} failed, empty items")
        return {'rowcount': 0, 'items_count': 0}

    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    items_count = len(items)
    table_name = model_cls.__tablename__
    fields = list(items[0].keys())
    unique_keys = [c.name for i in model_cls.__table_args__ if isinstance(
        i, UniqueConstraint) for c in i]
    columns = [c.name for c in model_cls.__table__.columns if c.name not in (
        'id', 'created_at')]

    if updated_at in columns and updated_at not in fields:
        fields.append(updated_at)
        updated_at_time = datetime.datetime.now()
        for item in items:
            item[updated_at] = updated_at_time

    assert set(fields) == set(columns), \
        'item fields not equal to columns in modelsnew: ' + \
        f'{set(fields) - set(columns)}, delete: {set(columns) - set(fields)}'

    # Unique-key columns get '' instead of NULL before the upsert.
    for item in items:
        for column in unique_keys:
            if column in item and item[column] is None:
                item[column] = ''

    engine = db.get_engine(bind=getattr(model_cls, '__bind_key__', None))
    sql = _build_upsert_sql(engine.name, table_name, fields, unique_keys)

    rowcounts = 0
    for start in range(0, items_count, batch_size):
        batch = items[start:start + batch_size]
        try:
            result = db.session.execute(sql, batch, bind=engine)
        except Exception as e:
            logger.error(e, exc_info=True)
            logger.info(sql)
            raise
        rowcounts += result.rowcount

    logger.info(f'{model_cls} save_data: rowcount={rowcounts}, '
                f'items_count: {items_count}')
    return {'rowcount': rowcounts, 'items_count': items_count}
def orjson_serializer(obj):
    """Serialize *obj* to a JSON ``str`` with orjson.

    ``orjson.dumps()`` returns a byte array while sqlalchemy expects a
    string, hence the final ``decode()``. The ``OPT_SERIALIZE_NUMPY`` and
    ``OPT_NAIVE_UTC`` options are enabled.
    """
    flags = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NAIVE_UTC
    encoded = orjson.dumps(obj, option=flags)
    return encoded.decode()