first commit
app/core/common_utils.py | 319 lines (new file)

@@ -0,0 +1,319 @@
import orjson

try:
    from collections import Mapping
except:  # noqa: E722
    from collections.abc import Mapping
import inspect
import importlib
import os
import re
import sys
from typing import Any, Dict, List, Optional
from unicodedata import normalize
from distutils.version import LooseVersion
import logging
import datetime

from flask import request
from flask_sqlalchemy import model
from sqlalchemy import UniqueConstraint
import marshmallow
from marshmallow import Schema

from .webargs import use_kwargs as base_use_kwargs, parser

from flask.json import JSONEncoder

logger = logging.getLogger(__name__)


class ParamsDict(dict):
    """A dict whose ``update`` returns a new ``ParamsDict`` instead of
    mutating in place.

    Example::

        @use_kwargs(PageParams.update({...}))
        def list_users(page, page_size, order_by):
            pass

    """

    def update(self, other=None):
        """Return a new ``ParamsDict`` merged with ``other`` (a Mapping or an
        iterable of key/value pairs). ``self`` is left unchanged.
        """
        ret = ParamsDict(self.copy())
        if other is not None:
            for k, v in other.items() if isinstance(other, Mapping) else other:
                ret[k] = v
        return ret
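

# Illustrative usage sketch (not part of the original module): ``PageParams``
# below is a hypothetical base parameter map; the real project may define it
# elsewhere (e.g. as pagination webargs fields).
def _params_dict_example():
    PageParams = ParamsDict(page=1, page_size=10, order_by='id')
    # ``update`` returns a merged copy; the original mapping is untouched.
    search_params = PageParams.update({'keyword': ''})
    assert PageParams == {'page': 1, 'page_size': 10, 'order_by': 'id'}
    assert search_params['keyword'] == ''
    return search_params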


# Function version of "model instance to dict": convert a SQLAlchemy row to
# a ``{column_name: str(value)}`` mapping.
def row2dict(row):
    return {
        c.name: str(getattr(row, c.name))
        for c in row.__table__.columns
    }
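

# Illustrative usage sketch (not part of the original module): ``model_cls``
# stands for any hypothetical SQLAlchemy declarative model class.
def _row2dict_example(session, model_cls):
    row = session.query(model_cls).first()
    # Every column value is stringified, e.g.
    # {'id': '1', 'username': 'alice', 'created_at': '2024-01-01 00:00:00'}
    return row2dict(row)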


class dict2object(dict):
    """
    Dict that behaves like an object: keys can be read and set via attribute
    access (``obj.key``).
    """

    def __getattr__(self, name: str) -> Any:
        if name in self.keys():
            return self[name]
        raise AttributeError('object has no attribute {}'.format(name))

    def __setattr__(self, name: str, value: Any) -> None:
        if not isinstance(name, str):
            raise TypeError('key must be string type.')
        self[name] = value
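

# Illustrative usage sketch (not part of the original module).
def _dict2object_example():
    obj = dict2object({'id': 1, 'username': 'alice'})
    assert obj.username == 'alice'   # attribute access reads the key
    obj.email = 'alice@example.com'  # attribute assignment writes the key
    assert obj['email'] == 'alice@example.com'
    return obj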


def secure_filename(filename: str) -> str:
    """Borrowed from werkzeug.utils.secure_filename.

    Pass it a filename and it will return a secure version of it. This
    filename can then safely be stored on a regular file system and passed
    to :func:`os.path.join`.

    On windows systems the function also makes sure that the file is not
    named after one of the special device files.

    >>> secure_filename(u'哈哈.zip')
    '哈哈.zip'
    >>> secure_filename('My cool movie.mov')
    'My_cool_movie.mov'
    >>> secure_filename('../../../etc/passwd')
    'etc_passwd'
    >>> secure_filename(u'i contain cool \xfcml\xe4uts.txt')
    'i_contain_cool_umlauts.txt'

    """

    for sep in os.path.sep, os.path.altsep:
        if sep:
            filename = filename.replace(sep, ' ')

    filename = normalize('NFKD', '_'.join(filename.split()))

    filename_strip_re = re.compile(u'[^A-Za-z0-9\u4e00-\u9fa5_.-]')
    filename = filename_strip_re.sub('', filename).strip('._')

    # on nt a couple of special files are present in each folder. We
    # have to ensure that the target file is not such a filename. In
    # this case we prepend an underline
    windows_device_files = (
        'CON', 'AUX', 'COM1', 'COM2', 'COM3', 'COM4', 'LPT1',
        'LPT2', 'LPT3', 'PRN', 'NUL',
    )
    if os.name == 'nt' and filename and \
            filename.split('.')[0].upper() in windows_device_files:
        filename = '_' + filename

    return filename


def _get_init_args(instance, base_class):
    """Get the ``__init__`` args of ``base_class`` and their current values
    on ``instance``.
    """
    getargspec = inspect.getfullargspec
    argspec = getargspec(base_class.__init__)

    defaults = argspec.defaults
    kwargs = {}
    if defaults is not None:
        no_defaults = argspec.args[:-len(defaults)]
        has_defaults = argspec.args[-len(defaults):]

        kwargs = {k: getattr(instance, k) for k in no_defaults
                  if k != 'self' and hasattr(instance, k)}
        kwargs.update({k: getattr(instance, k, defaults[i])
                       for i, k in enumerate(has_defaults)})
        assert len(kwargs) == len(argspec.args) - 1, 'exclude `self`'
    return kwargs
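

# Illustrative usage sketch (not part of the original module): shows how the
# helper recovers constructor arguments from a live instance.
def _get_init_args_example():
    class Base:
        def __init__(self, name, retries=3):
            self.name = name
            self.retries = retries

    class Child(Base):
        pass

    instance = Child('demo', retries=5)
    # -> {'name': 'demo', 'retries': 5}
    return _get_init_args(instance, Base)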


def use_kwargs(argmap, schema_kwargs: Optional[Dict] = None, **kwargs: Any):
    """Fix ``Schema(partial=True)`` not working when used with
    ``@webargs.flaskparser.use_kwargs``. For more details see ``webargs.core``.

    Args:

        argmap (marshmallow.Schema,dict,callable): Either a
            `marshmallow.Schema`, a `dict` of argname ->
            `marshmallow.fields.Field` pairs, or a callable that returns a
            `marshmallow.Schema` instance.
        schema_kwargs (dict): kwargs for argmap.

    Returns:
        dict: A dictionary of parsed arguments.

    """
    schema_kwargs = schema_kwargs or {}

    argmap = parser._get_schema(argmap, request)

    if not (argmap.partial or schema_kwargs.get('partial')):
        return base_use_kwargs(argmap, **kwargs)

    def factory(request):
        argmap_kwargs = _get_init_args(argmap, Schema)
        argmap_kwargs.update(schema_kwargs)

        # force set force_all=False
        only = parser.parse(argmap, request).keys()

        argmap_kwargs.update({
            'partial': False,  # otherwise missing=None does not work
            'only': only or None,
            'context': {"request": request},
        })
        if tuple(LooseVersion(marshmallow.__version__).version)[0] < 3:
            argmap_kwargs['strict'] = True

        return argmap.__class__(**argmap_kwargs)

    return base_use_kwargs(factory, **kwargs)
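

# Illustrative usage sketch (not part of the original module): with a partial
# schema, this ``use_kwargs`` only passes through the fields actually sent in
# the request. ``UserSchema`` and the view are hypothetical.
def _use_kwargs_example(app):
    class UserSchema(Schema):
        username = marshmallow.fields.Str()
        email = marshmallow.fields.Str()

    @app.route('/users', methods=['POST'])
    @use_kwargs(UserSchema(partial=True))
    def create_user(**payload):
        # Only the submitted fields appear in ``payload``.
        return payload

    return create_user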


def import_subs(locals_, modules_only: bool = False) -> List[str]:
    """Auto import submodules, used in ``__init__.py``.

    Args:

        locals_: `locals()`.
        modules_only: Only collect modules to __all__.

    Examples::

        # app/models/__init__.py
        from hobbit_core.utils import import_subs

        __all__ = import_subs(locals())

    Model subclasses, Schema subclasses and Schema instances are collected
    automatically. Other objects must be listed in ``submodule.__all__``.
    """
    package = locals_['__package__']
    path = locals_['__path__']
    top_module = sys.modules[package]

    all_ = []
    for name in os.listdir(path[0]):
        if not name.endswith(('.py', '.pyc')) or name.startswith('__init__.'):
            continue

        module_name = name.split('.')[0]
        submodule = importlib.import_module(f".{module_name}", package)
        all_.append(module_name)

        if modules_only:
            continue

        if hasattr(submodule, '__all__'):
            for name in getattr(submodule, '__all__'):
                if not isinstance(name, str):
                    raise Exception(f'Invalid object {name} in __all__, '
                                    'must contain only strings.')
                setattr(top_module, name, getattr(submodule, name))
                all_.append(name)
        else:
            for name, obj in submodule.__dict__.items():
                if isinstance(obj, (model.DefaultMeta, Schema)) or \
                        (inspect.isclass(obj) and
                         (issubclass(obj, Schema) or
                          obj.__name__.endswith('Service'))):
                    setattr(top_module, name, obj)
                    all_.append(name)
    return all_


def bulk_create_or_update_on_duplicate(
        db, model_cls, items, updated_at='updated_at', batch_size=500):
    """Bulk upsert. Supports MySQL and PostgreSQL.
    https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html

    Args:

        db: Instance of `SQLAlchemy`.
        model_cls: Model class.
        items: List of dicts, for example `[{key: value}, {key: value}, ...]`.
        updated_at: Field that records the row update time.
        batch_size: Max rows per executed statement.

    Returns:
        dict: A dictionary containing rowcount and items_count.
    """
    if not items:
        logger.warning("bulk_create_or_update_on_duplicate save to "
                       f"{model_cls} failed, empty items")
        return {'rowcount': 0, 'items_count': 0}

    items_count = len(items)
    table_name = model_cls.__tablename__
    fields = list(items[0].keys())
    unique_keys = [c.name for i in model_cls.__table_args__ if isinstance(
        i, UniqueConstraint) for c in i]
    columns = [c.name for c in model_cls.__table__.columns if c.name not in (
        'id', 'created_at')]

    if updated_at in columns and updated_at not in fields:
        fields.append(updated_at)
        updated_at_time = datetime.datetime.now()
        for item in items:
            item[updated_at] = updated_at_time

    assert set(fields) == set(columns), \
        'item fields not equal to columns in model: new: ' + \
        f'{set(fields) - set(columns)}, deleted: {set(columns) - set(fields)}'

    for item in items:
        for column in unique_keys:
            if column in item and item[column] is None:
                item[column] = ''

    engine = db.get_engine(bind=getattr(model_cls, '__bind_key__', None))
    if engine.name == 'postgresql':
        sql_on_update = ', '.join([
            f' {field} = excluded.{field}'
            for field in fields if field not in unique_keys])
        sql = f"""INSERT INTO {table_name} ({", ".join(fields)}) VALUES
            ({", ".join([f':{key}' for key in fields])})
            ON CONFLICT ({", ".join(unique_keys)}) DO UPDATE SET
            {sql_on_update}"""
    elif engine.name == 'mysql':
        sql_on_update = ', '.join([
            f' `{field}` = new.{field}' for field in fields
            if field not in unique_keys])
        sql = f"""INSERT INTO {table_name} (`{"`, `".join(fields)}`) VALUES
            ({", ".join([f':{key}' for key in fields])}) AS new
            ON DUPLICATE KEY UPDATE
            {sql_on_update}"""
    else:
        raise Exception(f'unsupported db: {engine.name}')

    rowcounts = 0
    while len(items) > 0:
        batch, items = items[:batch_size], items[batch_size:]
        try:
            result = db.session.execute(sql, batch, bind=engine)
        except Exception as e:
            logger.error(e, exc_info=True)
            logger.info(sql)
            raise e
        rowcounts += result.rowcount
    logger.info(f'{model_cls} save_data: rowcount={rowcounts}, '
                f'items_count: {items_count}')
    return {'rowcount': rowcounts, 'items_count': items_count}
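

# Illustrative usage sketch (not part of the original module): ``db`` is the
# Flask-SQLAlchemy instance and ``Device`` a hypothetical model declaring a
# UniqueConstraint in its ``__table_args__``.
def _bulk_upsert_example(db, Device):
    items = [
        {'sn': 'A001', 'status': 'online'},
        {'sn': 'A002', 'status': 'offline'},
    ]
    # Rows whose unique key already exists are updated, new ones inserted.
    result = bulk_create_or_update_on_duplicate(db, Device, items)
    # -> {'rowcount': <affected rows>, 'items_count': 2}
    return result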


def orjson_serializer(obj):
    """
    Note that `orjson.dumps()` returns a byte array, while SQLAlchemy expects
    a string, hence the `decode()` call.
    """
    return orjson.dumps(
        obj, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NAIVE_UTC).decode()
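

# Illustrative usage sketch (not part of the original module): the serializer
# is typically handed to the engine so JSON columns are encoded with orjson,
# e.g. SQLAlchemy's ``create_engine(..., json_serializer=orjson_serializer)``
# or the equivalent Flask-SQLAlchemy engine options shown here.
_ENGINE_OPTIONS_EXAMPLE = {
    'json_serializer': orjson_serializer,
}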