From 51e370d017735120dc9775d3ba80188356e90e49 Mon Sep 17 00:00:00 2001 From: chenwj <654891551@qq.com> Date: Fri, 6 Jan 2023 17:44:25 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=A2=9E=E5=8A=A0=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93SQL=E7=9B=B8=E5=85=B3=E6=93=8D=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 12 ++--- src/biz/__init__.py | 113 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 7 deletions(-) diff --git a/config.py b/config.py index 0be5f33..42bba39 100644 --- a/config.py +++ b/config.py @@ -14,12 +14,11 @@ class ReConfigParser(ConfigParser): def optionxform(self, optionstr): return optionstr -class CommonConfig(BaseSettings): - SECRET_KEY: str = os.urandom(32) - PROJECT_NAME: str - API_V1_STR: str - # 允许访问的origins - BACKEND_CORS_ORIGINS: str +# class CommonConfig(BaseSettings): +# SECRET_KEY: str = os.urandom(32) +# PROJECT_NAME: str +# API_V1_STR: str +# BACKEND_CORS_ORIGINS: str class MySQLConfig(BaseSettings): @@ -59,7 +58,6 @@ def init_config(): else: config.read(os.path.join('.', 'conf', 'conf-dev.ini'), encoding='utf-8') # common_config = CommonConfig(**dict(config.items('common'))) - common_config = CommonConfig(**dict(config.items('common'))) mysql_config = MySQLConfig(**dict(config.items('mysql'))) redis_config = RedisConfig(**dict(config.items('redis'))) rabbitmq_config = RabbitmqConfig(**dict(config.items("rabbitmq"))) diff --git a/src/biz/__init__.py b/src/biz/__init__.py index e69de29..bc41aa8 100644 --- a/src/biz/__init__.py +++ b/src/biz/__init__.py @@ -0,0 +1,113 @@ +import logging +import time +from traceback import format_exc + +from fastapi_sqlalchemy import db +from jinja2 import Template +from pymysql import err + +__author__ = 'Woodstock' + +logger = logging.getLogger(__name__) + +class DB: + CACHE = None + ACCESS_TOKEN = None + IS_LOCAL_TEST = False + + def __init__(self, session=None): + """ + 传入session时整个db对象使用同一session操作 + """ + self.session = session + self.result = None + self.rowcount = -1 + self.lastrowid = -1 + + def execute(self, sql, **kwargs): + if "{" in sql: + sql_template = Template(sql) + temp_sql = sql_template.render(**kwargs) + else: + temp_sql = sql + temp_sql = "\n".join([l.strip() for l in temp_sql.split("\n")]) + + # 用于输出日志 + log_string = "执行SQL: {}\n参数: {}".format(temp_sql.strip(), str(kwargs)[:400]) + t = time.time() + try: + pars = dict(**kwargs) + self.result = self.session.execute(temp_sql, params=pars) + except (err.OperationalError, err.InternalError, err.InternalError) as e: + orig = e + if not isinstance(orig, err.InternalError): + orig = orig.orig + if orig.args[0] in [111, 1205, 1213, 2003, 2013]: + raise + if orig.args[0] == 1044: + logger.error("(pymysql.err.OperationalError) (1044, Access denied for {})".format(self.session.bind.url)) + logger.error(format_exc()) + raise + log_string += "\n执行结果:\n 执行失败了 ,错误信息:" + str(e) + logger.error(log_string) + raise + except Exception as e: + log_string += "\n执行结果:\n 执行失败了 ,错误信息:" + str(e) + logger.error(log_string) + raise + finally: + if not temp_sql.lower().startswith("select") and not temp_sql.lower().startswith('insert into `sysauditlogs`'): # 不记录查询语句 + logger.info("{}\n耗时:{}".format(log_string, time.time() - t)) + # 受影响行数 + self.rowcount = self.result.rowcount + # 插入数据时返回的主键数据 + self.lastrowid = self.result.lastrowid + logger.debug("执行结果:\n 受影响行数:%d" % self.rowcount) + return self + + def all(self): + # 获取查询列表。把ResultProxy和RowProxy类型封装成python的list和dict类型 + data = [ i._asdict() for i in self.result] + self.result.close() + return data + + def scalars(self): + """ + 返回首个值的集合 + :return: + """ + data = set(i[0] for i in self.result) + self.result.close() + return data + + def scalar_list(self): + """ + 获取首个值的列表 + :return: + """ + data = [i[0] for i in self.result] + self.result.close() + return data + + def first(self): + row = None + # 获取第一行数据 + for i in self.result: + row = i._asdict() + break + self.result.close() + return row + + def scalar(self): + """获取第一行第一列的数据""" + one = None + for i in self.result: + one = i[0] + break + self.result.close() + return one + + +def execute_sql(sql, **kwargs): + db_instance = DB(db.session) + return db_instance.execute(sql, **kwargs) \ No newline at end of file