# Source: fastapi-template/src/service/__init__.py (138 lines, 4.0 KiB, Python)

import logging
import time
from traceback import format_exc
import struct
import sqlalchemy
from sqlalchemy import text
from fastapi_sqlalchemy import db
from jinja2 import Template
from pymysql import err
from pymysql.constants import FIELD_TYPE
from pymysql.converters import conversions, convert_time
__author__ = 'Woodstock'

# All messages from this module go through the logger named "uvicorn.error".
logger = logging.getLogger("uvicorn.error")

# Module-level placeholder; nothing in the visible code assigns or reads it.
engine = None


def _bit_to_int(raw):
    """Decode a raw BIT column value as an unsigned big-endian 64-bit integer.

    The value is left-padded with zero bytes up to 8 bytes before unpacking.
    """
    padding = bytes([0x00]) * (8 - len(raw))
    return struct.unpack(">Q", padding + raw)[0]


# Override pymysql's shared converter table:
#   BIT      -> integer (see _bit_to_int)
#   TIME     -> pymysql's convert_time
#   DATETIME -> passed through ``str`` instead of pymysql's default converter
conversions.update({
    FIELD_TYPE.BIT: _bit_to_int,
    FIELD_TYPE.TIME: convert_time,
    FIELD_TYPE.DATETIME: str,
})
class EDatabase(Exception):
    """Database-layer exception that exposes a ``result()`` accessor."""

    def __init__(self, *args):
        """Forward all positional arguments to ``Exception`` unchanged."""
        super().__init__(*args)

    def result(self):
        """Return ``None``; this exception carries no result payload."""
        return None
class DB:
    """Thin wrapper around a SQLAlchemy session for running raw SQL text.

    ``execute`` runs a statement and stores the result object on the
    instance; the fetch helpers (``all``, ``first``, ``scalar``,
    ``scalars``, ``scalar_list``) then consume and close that stored result.
    """

    # Class-level settings; none of them is read by the methods of this class.
    CACHE = None
    ACCESS_TOKEN = None
    IS_LOCAL_TEST = False
    # SQLAlchemy version string, captured when the class is defined.
    version = sqlalchemy.__version__

    # Driver error codes that are re-raised without an error-level log entry.
    # NOTE(review): these look like transient MySQL conditions (1205 lock wait
    # timeout, 1213 deadlock, 2003/2013 connection failures, 111 connection
    # refused) that a caller is expected to retry -- confirm with the callers.
    _QUIET_ERROR_CODES = (111, 1205, 1213, 2003, 2013)

    def __init__(self, session=None):
        """Create a wrapper bound to *session*.

        When a session is passed in, every operation on this object uses
        that same session.
        """
        self.session = session
        self.result = None
        self.rowcount = -1
        self.lastrowid = -1

    @staticmethod
    def _row_to_dict(row):
        """Convert one result row to a plain ``dict``.

        Rows that provide ``_asdict()`` (SQLAlchemy 1.4 and later) use it;
        older row objects fall back to their ``keys()`` / ``values()`` pair.
        Detecting the capability instead of comparing version strings keeps
        this working on releases newer than 1.4.
        """
        as_dict = getattr(row, "_asdict", None)
        if callable(as_dict):
            return as_dict()
        return dict(zip(row.keys(), row.values()))

    def execute(self, sql, **kwargs):
        """Execute *sql* on the bound session and remember its result.

        If *sql* contains ``{`` it is first rendered as a Jinja2 template
        with *kwargs*; the same *kwargs* are then passed as bound parameters
        to the statement.

        :param sql: SQL text, optionally containing Jinja2 template syntax.
        :param kwargs: template variables and bound parameter values.
        :return: ``self``, so a fetch helper can be chained.
        :raises Exception: any error raised by the session is logged and
            re-raised unchanged.
        """
        if "{" in sql:
            # NOTE(review): values substituted by the template become part of
            # the SQL text itself and are NOT bound parameters. Only trusted
            # values should be used in template expressions.
            rendered = Template(sql).render(**kwargs)
        else:
            rendered = sql
        rendered = "\n".join(line.strip() for line in rendered.split("\n"))
        # Message reused for every log record about this statement.
        log_string = f"执行SQL: {rendered.strip()}\n参数: {str(kwargs)[:400]}"
        started = time.time()
        try:
            self.result = self.session.execute(text(rendered), params=dict(kwargs))
        except (
            err.OperationalError,
            err.InternalError,
            sqlalchemy.exc.OperationalError,
            sqlalchemy.exc.InternalError,
        ) as e:
            # SQLAlchemy wraps driver errors and exposes the original one as
            # ``.orig``; a raw pymysql error has no such attribute, so fall
            # back to the exception itself instead of failing on the lookup.
            orig = getattr(e, "orig", None)
            if orig is None:
                orig = e
            args = getattr(orig, "args", ())
            code = args[0] if args else None
            if code in self._QUIET_ERROR_CODES:
                raise
            if code == 1044:
                logger.error(f"(pymysql.err.OperationalError) (1044, Access denied for {self.session.bind.url})")
                logger.error(format_exc())
                raise
            log_string += f"\n执行结果:\n 执行失败了 ,错误信息:{e}"
            logger.error(log_string)
            raise
        except Exception as e:
            log_string += f"\n执行结果:\n 执行失败了 ,错误信息:{e}"
            logger.error(log_string)
            raise
        finally:
            # Always record the statement and its elapsed time, even on failure.
            logger.info(f"{log_string}\n耗时:{time.time() - started}")
        # Number of affected rows.
        self.rowcount = self.result.rowcount
        # Primary key value reported for an insert.
        self.lastrowid = self.result.lastrowid
        logger.debug(f"执行结果:\n 受影响行数:{self.rowcount}")
        return self

    def all(self):
        """Return every row of the stored result as a list of dicts.

        The result is closed afterwards, also when iteration fails.
        """
        try:
            return [self._row_to_dict(row) for row in self.result]
        finally:
            self.result.close()

    def scalars(self):
        """Return the set of first-column values of the stored result.

        :return: a ``set``; duplicates and row order are not preserved.
        """
        try:
            return {row[0] for row in self.result}
        finally:
            self.result.close()

    def scalar_list(self):
        """Return the first-column values of the stored result as a list.

        :return: a ``list`` in row order.
        """
        try:
            return [row[0] for row in self.result]
        finally:
            self.result.close()

    def first(self):
        """Return the first row as a dict, or ``None`` if there are no rows."""
        try:
            row = next(iter(self.result), None)
            return None if row is None else self._row_to_dict(row)
        finally:
            self.result.close()

    def scalar(self):
        """Return the first column of the first row, or ``None`` if empty."""
        try:
            row = next(iter(self.result), None)
            return None if row is None else row[0]
        finally:
            self.result.close()
def execute_sql(sql, **kwargs):
    """Execute *sql* through a fresh ``DB`` wrapper bound to ``db.session``.

    :param sql: SQL text, forwarded unchanged to ``DB.execute``.
    :param kwargs: forwarded unchanged to ``DB.execute``.
    :return: whatever ``DB.execute`` returns.
    """
    runner = DB(db.session)
    outcome = runner.execute(sql, **kwargs)
    return outcome