138 lines
4.0 KiB
Python
138 lines
4.0 KiB
Python
import logging
|
|
import time
|
|
from traceback import format_exc
|
|
import struct
|
|
import sqlalchemy
|
|
from sqlalchemy import text
|
|
from fastapi_sqlalchemy import db
|
|
from jinja2 import Template
|
|
from pymysql import err
|
|
from pymysql.constants import FIELD_TYPE
|
|
from pymysql.converters import conversions, convert_time
|
|
|
|
__author__ = 'Woodstock'

# Log through the "uvicorn.error" logger (presumably so output lands in the uvicorn server log).
logger = logging.getLogger("uvicorn.error")

# NOTE(review): never assigned in this module view; presumably set elsewhere — confirm.
engine = None


def _bit_to_int(raw):
    """Decode a big-endian MySQL BIT payload (bytes, at most 8 long) into an unsigned int."""
    # Left-pad with zero bytes to the 8 bytes required by the ">Q" format.
    padded = b"\x00" * (8 - len(raw)) + raw
    (value,) = struct.unpack(">Q", padded)
    return value


# Patch pymysql's global decoder table; this affects every pymysql connection in the process.
conversions.update({
    FIELD_TYPE.BIT: _bit_to_int,
    FIELD_TYPE.TIME: convert_time,
    FIELD_TYPE.DATETIME: str,
})
|
|
|
|
|
|
class EDatabase(Exception):
    """Database-layer exception; takes the same positional arguments as ``Exception``."""

    def __init__(self, *args):
        super().__init__(*args)

    def result(self):
        """Always return ``None``.

        NOTE(review): presumably a hook so callers can ask a failed operation for a
        default result — confirm against callers.
        """
        return None
|
|
|
|
|
|
class DB:
    """Run raw (optionally Jinja2-templated) SQL through a SQLAlchemy session.

    ``execute`` returns ``self`` so a fetch helper can be chained, e.g.
    ``DB(session).execute("SELECT ... WHERE id = :id", id=1).all()``.
    Each fetch helper consumes the current result and closes it.
    """

    CACHE = None
    ACCESS_TOKEN = None
    IS_LOCAL_TEST = False

    # Installed SQLAlchemy version string (kept public for existing callers).
    version = sqlalchemy.__version__

    # Driver error codes (exception ``args[0]``) that execute() re-raises without an error log.
    # NOTE(review): these match MySQL/client codes for connection refused (111), lock wait
    # timeout (1205), deadlock (1213), can't connect (2003) and lost connection (2013);
    # presumably a caller retries them — confirm.
    _SILENT_ERROR_CODES = (111, 1205, 1213, 2003, 2013)
    # MySQL "access denied" error code; logged with the (password-masked) connection URL.
    _ACCESS_DENIED_CODE = 1044

    def __init__(self, session=None):
        """Create the helper.

        :param session: when given, every statement run through this object uses this
            same session. When omitted, ``execute`` uses ``fastapi_sqlalchemy.db.session``
            at call time.
        """
        self.session = session
        # Result of the most recent successful execute().
        self.result = None
        # Rows affected by the most recent statement.
        self.rowcount = -1
        # Primary key generated by the most recent INSERT.
        self.lastrowid = -1

    def execute(self, sql, **kwargs):
        """Render (if templated) and execute ``sql``, binding ``kwargs`` as parameters.

        If ``sql`` contains ``{`` it is first rendered as a Jinja2 template with
        ``kwargs``. The same ``kwargs`` are then passed as bound parameters to
        ``text()``. Each line of the final SQL is stripped of surrounding whitespace.

        :return: ``self``, so ``.all()`` / ``.first()`` / ``.scalar()`` ... can be chained.
        :raises: whatever the session raises. Errors are logged before being re-raised,
            except driver codes in ``_SILENT_ERROR_CODES``, which are re-raised unlogged.
        """
        from sqlalchemy import exc as sa_exc

        session = self.session if self.session is not None else db.session

        if "{" in sql:
            # NOTE(review): Jinja output is spliced into the SQL text verbatim. Only use
            # trusted values as template variables; use :name binds for user input.
            temp_sql = Template(sql).render(**kwargs)
        else:
            temp_sql = sql
        temp_sql = "\n".join(line.strip() for line in temp_sql.split("\n"))

        # Message used for logging (parameters truncated to 400 chars).
        log_string = f"执行SQL: {temp_sql.strip()}\n参数: {str(kwargs)[:400]}"
        started = time.perf_counter()
        try:
            self.result = session.execute(text(temp_sql), params=dict(kwargs))
        except (sa_exc.OperationalError, sa_exc.InternalError,
                err.OperationalError, err.InternalError) as e:
            # SQLAlchemy wraps the driver error; ``.orig`` is the underlying pymysql exception.
            # Raw pymysql errors have no ``.orig`` and are inspected directly.
            orig = getattr(e, "orig", None) or e
            orig_args = getattr(orig, "args", ())
            code = orig_args[0] if orig_args else None
            if code in self._SILENT_ERROR_CODES:
                raise
            if code == self._ACCESS_DENIED_CODE:
                logger.error(
                    f"(pymysql.err.OperationalError) (1044, Access denied for {self._masked_bind_url(session)})"
                )
                logger.error(format_exc())
                raise
            log_string += f"\n执行结果:\n 执行失败了 ,错误信息:{e}"
            logger.error(log_string)
            raise
        except Exception as e:
            log_string += f"\n执行结果:\n 执行失败了 ,错误信息:{e}"
            logger.error(log_string)
            raise
        finally:
            logger.info(f"{log_string}\n耗时:{time.perf_counter() - started}")

        # Number of affected rows.
        self.rowcount = self.result.rowcount
        # Primary key produced when inserting data.
        self.lastrowid = self.result.lastrowid
        logger.debug(f"执行结果:\n 受影响行数:{self.rowcount}")
        return self

    @staticmethod
    def _masked_bind_url(session):
        """Return ``session``'s bind URL with the password masked, or ``None`` if unavailable."""
        url = getattr(getattr(session, "bind", None), "url", None)
        # repr() of a SQLAlchemy URL hides the password; str() does not before SQLAlchemy 2.0.
        return repr(url) if url is not None else None

    @staticmethod
    def _row_to_dict(row):
        """Convert one result row into a plain ``{column: value}`` dict on any SQLAlchemy version."""
        if hasattr(row, "_asdict"):  # SQLAlchemy 1.4+ ``Row`` (``keys()`` was removed in 2.0)
            return row._asdict()
        return dict(zip(row.keys(), row.values()))  # SQLAlchemy < 1.4 ``RowProxy``

    def all(self):
        """Return all remaining rows as a list of dicts, then close the result."""
        try:
            return [self._row_to_dict(row) for row in self.result]
        finally:
            self.result.close()

    def scalars(self):
        """Return the set of first-column values of all remaining rows, then close the result."""
        try:
            return {row[0] for row in self.result}
        finally:
            self.result.close()

    def scalar_list(self):
        """Return the first-column values of all remaining rows as a list, then close the result."""
        try:
            return [row[0] for row in self.result]
        finally:
            self.result.close()

    def first(self):
        """Return the first row as a dict (``None`` when there are no rows), then close the result."""
        try:
            row = next(iter(self.result), None)
            return None if row is None else self._row_to_dict(row)
        finally:
            self.result.close()

    def scalar(self):
        """Return the first column of the first row (``None`` when there are no rows), then close the result."""
        try:
            row = next(iter(self.result), None)
            return None if row is None else row[0]
        finally:
            self.result.close()
|
|
|
|
|
|
def execute_sql(sql, **kwargs):
    """Execute ``sql`` on the request-scoped ``db.session``.

    :param sql: SQL text (optionally a Jinja2 template), forwarded to ``DB.execute``.
    :param kwargs: template variables / bound parameters.
    :return: the ``DB`` wrapper returned by ``DB.execute``, ready for ``.all()`` etc.
    """
    return DB(db.session).execute(sql, **kwargs)