import logging
import struct
import time
from traceback import format_exc

import sqlalchemy
from sqlalchemy import exc as sa_exc
from sqlalchemy import text
from fastapi_sqlalchemy import db
from jinja2 import Template
from pymysql import err
from pymysql.constants import FIELD_TYPE
from pymysql.converters import conversions, convert_time

__author__ = 'Woodstock'

logger = logging.getLogger("uvicorn.error")

engine = None

# Process-wide overrides of pymysql's shared ``conversions`` table: these affect
# every pymysql connection in the process, not only those made through DB.
# BIT values are bytes (up to 8 for BIT(64)); left-pad to 8 bytes and decode as a
# big-endian unsigned 64-bit integer.
conversions[FIELD_TYPE.BIT] = lambda b: struct.unpack(">Q", (bytes([0x00]) * (8 - len(b)) + b))[0]
conversions[FIELD_TYPE.TIME] = convert_time
# Replace pymysql's DATETIME converter with ``str`` so values are not turned into
# datetime objects.
conversions[FIELD_TYPE.DATETIME] = str

# Driver error types whose MySQL error code is inspected in DB.execute. SQLAlchemy
# wraps DBAPI errors in its own exception classes (driver error kept on ``.orig``),
# so both families are listed.
_MYSQL_ERRORS = (
    err.OperationalError,
    err.InternalError,
    sa_exc.OperationalError,
    sa_exc.InternalError,
)

# MySQL/OS error codes that are re-raised without the error-level log:
# 111 connection refused (errno), 1205 lock wait timeout, 1213 deadlock,
# 2003 can't connect to server, 2013 lost connection during query.
# NOTE(review): presumably callers retry these — confirm.
_QUIET_RERAISE_CODES = frozenset({111, 1205, 1213, 2003, 2013})

# MySQL ER_DBACCESS_DENIED_ERROR.
_ACCESS_DENIED_CODE = 1044


def _row_to_dict(row):
    """Convert one SQLAlchemy result row into a plain ``dict``.

    SQLAlchemy 1.4+ rows provide ``_asdict()`` (2.x rows no longer have
    ``keys()``); older ``RowProxy`` objects only provide ``keys()``/``values()``.
    """
    if hasattr(row, "_asdict"):
        return row._asdict()
    return dict(zip(row.keys(), row.values()))


class EDatabase(Exception):
    """Database-layer exception; ``result()`` always returns ``None``."""

    def __init__(self, *args):
        super().__init__(*args)

    def result(self):
        return None


class DB:
    """Thin wrapper that runs (optionally Jinja-templated) SQL on a session.

    ``execute`` returns ``self`` so a fetch method can be chained, e.g.
    ``DB(session).execute(sql, id=1).first()``.
    """

    # Class-level settings; not used by the methods of this class.
    CACHE = None
    ACCESS_TOKEN = None
    IS_LOCAL_TEST = False
    # Installed SQLAlchemy version; kept for callers that read it.
    version = sqlalchemy.__version__

    def __init__(self, session=None):
        """
        :param session: SQLAlchemy session; when given, every operation on this
            DB object uses that same session.
        """
        self.session = session
        self.result = None
        self.rowcount = -1
        self.lastrowid = -1

    def execute(self, sql, **kwargs):
        """Render and execute ``sql`` on ``self.session``.

        If ``sql`` contains ``{`` it is treated as a Jinja2 template and rendered
        with ``kwargs``; the same ``kwargs`` are then passed as bound parameters
        (``:name`` placeholders) to SQLAlchemy ``text()``.

        NOTE(review): Jinja rendering splices values straight into the SQL text;
        only render trusted values and pass user input as bound parameters.

        :return: ``self``, with ``result``, ``rowcount`` and ``lastrowid`` set.
        :raises: any exception from ``session.execute`` is logged (except the
            quiet codes above) and re-raised unchanged.
        """
        if "{" in sql:
            temp_sql = Template(sql).render(**kwargs)
        else:
            temp_sql = sql
        # Strip per-line indentation so the logged SQL is compact.
        temp_sql = "\n".join(line.strip() for line in temp_sql.split("\n"))

        # Text used for log output.
        log_string = f"执行SQL: {temp_sql.strip()}\n参数: {str(kwargs)[:400]}"
        started = time.perf_counter()
        try:
            self.result = self.session.execute(text(temp_sql), params=dict(kwargs))
        except _MYSQL_ERRORS as e:
            # SQLAlchemy exceptions carry the driver error on ``.orig``; raw pymysql
            # errors do not, so fall back to the exception itself.
            orig = getattr(e, "orig", e)
            args = getattr(orig, "args", ())
            code = args[0] if args else None
            if code in _QUIET_RERAISE_CODES:
                raise
            if code == _ACCESS_DENIED_CODE:
                bind_url = getattr(getattr(self.session, "bind", None), "url", None)
                # repr() of a SQLAlchemy URL masks the password; str() may not.
                logger.error(f"(pymysql.err.OperationalError) (1044, Access denied for {bind_url!r})")
                logger.error(format_exc())
                raise
            log_string += f"\n执行结果:\n 执行失败了 ,错误信息:{e}"
            logger.error(log_string)
            raise
        except Exception as e:
            log_string += f"\n执行结果:\n 执行失败了 ,错误信息:{e}"
            logger.error(log_string)
            raise
        finally:
            # Logged on success and failure alike, with the elapsed seconds.
            logger.info(f"{log_string}\n耗时:{time.perf_counter() - started}")

        # Number of affected rows.
        self.rowcount = self.result.rowcount
        # DBAPI cursor.lastrowid: the generated primary key after an INSERT.
        self.lastrowid = self.result.lastrowid
        logger.debug(f"执行结果:\n 受影响行数:{self.rowcount}")
        return self

    def all(self):
        """Return every row as a ``dict`` in a ``list``; closes the result."""
        try:
            return [_row_to_dict(row) for row in self.result]
        finally:
            self.result.close()

    def scalars(self):
        """Return the first column of every row as a ``set`` (duplicates dropped)."""
        try:
            return {row[0] for row in self.result}
        finally:
            self.result.close()

    def scalar_list(self):
        """Return the first column of every row as a ``list``, in row order."""
        try:
            return [row[0] for row in self.result]
        finally:
            self.result.close()

    def first(self):
        """Return the first row as a ``dict``, or ``None`` if there are no rows."""
        try:
            row = next(iter(self.result), None)
            return None if row is None else _row_to_dict(row)
        finally:
            self.result.close()

    def scalar(self):
        """Return the first column of the first row, or ``None`` if there are no rows."""
        try:
            row = next(iter(self.result), None)
            return None if row is None else row[0]
        finally:
            self.result.close()


def execute_sql(sql, **kwargs):
    """Execute ``sql`` on the request-scoped ``fastapi_sqlalchemy`` session.

    :return: the ``DB`` instance, so a fetch method can be chained.
    """
    db_instance = DB(db.session)
    return db_instance.execute(sql, **kwargs)