fix: 优化项目模板

This commit is contained in:
chenwj 2023-01-07 15:20:38 +08:00
parent db555faee8
commit 6068c6bd8f
11 changed files with 241 additions and 20 deletions

2
.dockerignore Normal file
View File

@ -0,0 +1,2 @@
.git
**/__pycache__

3
.gitignore vendored
View File

@ -140,3 +140,6 @@ dmypy.json
# Cython debug symbols # Cython debug symbols
cython_debug/ cython_debug/
logs/
files/

View File

@ -52,6 +52,7 @@ sqlacodegen.exe --tables permission_info --outfile .\Desktop\fastapi_app\models\
|--- conf-prod.ini *生产环境配置文件 |--- conf-prod.ini *生产环境配置文件
|--- log.ini *日志配置文件 |--- log.ini *日志配置文件
files *上传文件目录 files *上传文件目录
logs *日志文件目录
src *源码目录 src *源码目录
|--- api *接口目录 |--- api *接口目录
|--- biz *逻辑目录 |--- biz *逻辑目录

View File

@ -18,7 +18,7 @@ qualname=
propagate=0 propagate=0
[logger_uvicorn.error] [logger_uvicorn.error]
level=ERROR level=INFO
handlers=error_file handlers=error_file
qualname=uvicorn.error qualname=uvicorn.error
propagate=0 propagate=0
@ -44,12 +44,12 @@ propagate=1
[handler_error_file] [handler_error_file]
class=logging.handlers.RotatingFileHandler class=logging.handlers.RotatingFileHandler
formatter=error formatter=error
kwargs={"filename": "./log/error.log", "maxBytes": 1024 * 1024 * 10, "backupCount": 10, "delay": True} kwargs={"filename": "./logs/error.log", "maxBytes": 1024 * 1024 * 10, "backupCount": 10, "delay": True, "encoding": "utf8"}
[handler_access_file] [handler_access_file]
class=logging.handlers.RotatingFileHandler class=logging.handlers.RotatingFileHandler
formatter=access formatter=access
kwargs={"filename": "./log/access.log", "maxBytes": 1024 * 1024 * 10, "backupCount": 10, "delay": True} kwargs={"filename": "./logs/access.log", "maxBytes": 1024 * 1024 * 10, "backupCount": 10, "delay": True, "encoding": "utf8"}
[formatter_generic] [formatter_generic]
format=%(asctime)s %(pathname)s(%(lineno)d): %(levelname)s %(message)s format=%(asctime)s %(pathname)s(%(lineno)d): %(levelname)s %(message)s

12
main.py
View File

@ -3,9 +3,12 @@ import os
import fastapi_plugins import fastapi_plugins
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.exceptions import HTTPException, RequestValidationError
from fastapi_sqlalchemy import DBSessionMiddleware from fastapi_sqlalchemy import DBSessionMiddleware
from config import init_config from config import init_config
from src.utils.exception import (http_exception_handler,
request_validation_error_handler)
def create_app(): def create_app():
@ -13,6 +16,11 @@ def create_app():
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
# 创建日志文件夹和临时文件上传文件夹
if not os.path.exists("logs"):
os.mkdir("logs")
if not os.path.exists("files"):
os.mkdir("files")
# 日志配置 # 日志配置
# logging_config.fileConfig('conf/log.ini') # logging_config.fileConfig('conf/log.ini')
# 初始化配置文件 # 初始化配置文件
@ -28,6 +36,10 @@ def create_app():
async def shutdown_event(): async def shutdown_event():
await fastapi_plugins.redis_plugin.terminate() await fastapi_plugins.redis_plugin.terminate()
# 添加异常处理
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, request_validation_error_handler)
# 在这里添加API route # 在这里添加API route
from src.api import test from src.api import test
app.include_router(test.router) app.include_router(test.router)

View File

@ -18,4 +18,5 @@ uvicorn==0.13.4
PyJWT==2.0.1 PyJWT==2.0.1
passlib==1.7.4 passlib==1.7.4
Pillow==8.0.1 Pillow==8.0.1
captcha==0.3 captcha==0.3
jinja2==2.11.2

View File

@ -0,0 +1,195 @@
# import base64
# import hashlib
# import datetime
# import hmac
# import json
# import logging
# from time import time
# from fastapi import Request
# from functools import wraps
# from traceback import format_exc
# __author__ = 'Woodstock'
# logger = logging.getLogger(__name__)
# def response(status=200, message="", data=None, err_code=0, err_detail="", target_url=""):
# result = {
# "result": data,
# "code": status,
# "error": {
# "code": status if not err_code and status != 200 else err_code,
# "message": message,
# "detail": err_detail if err_detail else message
# }
# }
# return result, status # Response(response=JSON.dumps(result), status=status, mimetype="application/json")
# def create_audit_log(func):
# """创建审计日志对象
# :param func: 函数"""
# return {
# 'ServiceName': Request.url,
# 'ServiceDesc': func.__doc__,
# 'MethodName': Request.method,
# 'ExecutionTime': datetime.datetime.now(),
# 'ExecutionDuration': time(),
# 'ClientIpAddress': Request.headers.get("X-Real-IP", request.remote_addr),
# 'AppVersion': g.Version,
# 'OS': g.OS,
# 'Parameters': get_params(),
# 'AccessToken': g.token
# }
# def save_audit_log(audit_log, status, result_type, resp=None):
# """保存审计日志对象
# :param audit_log: 日志对象
# :param status: 状态码
# :param result_type: 类型 0:执行成功1:基础校验错误, 2:权限校验错误, 3:异常
# :param resp: 返回内容
# """
# try:
# if g.TenantId != -1:
# audit_log.update({
# 'TenantId': g.TenantId,
# 'AppId': g.AppId,
# 'ClientName': g.Channel
# })
# else:
# audit_log['TenantId'] = -1
# audit_log['ReturnValue'] = status
# if status != 200 and resp:
# audit_log['ReturnData'] = json.dumps(resp[0], ensure_ascii=False)
# if result_type == 1:
# audit_log['Exception'] = '签名校验错误'
# elif result_type == 2:
# audit_log['Exception'] = '权限校验错误'
# if g.get('User') is not None:
# audit_log['UserId'] = g.User.get('Id')
# audit_log['UserName'] = g.User.get('Name')
# audit_log['ExecutionDuration'] = 1000 * (time() - audit_log['ExecutionDuration'])
# if len(audit_log.get('Parameters', '')) > 2000:
# audit_log['Parameters'] = audit_log['Parameters'][0:2000]
# except:
# pass
# def api_authorize(permission_name=None, is_check=False, may_user=False):
# """ 接口授权验证 (执行成功后,参数中默认添加用户,租户信息)
# :param permission_name: 需要权限名称, 为None标识不校验权限(允许匿名方法)
# :param is_check: 是否校验签名,默认校验参数签名
# :param may_user: 可能有用户
# """
# def api_authorize_wrapper(func):
# @wraps(func)
# def wrapper(*args, **kwargs):
# app_id = Request.headers.get('X-AppId') # 请求发起应用Id
# nonce = Request.headers.get('X-Nonce', None) # 随机数
# timestamp = Request.headers.get('X-Timestamp', None) # 请求发起时间戳
# sign_str = Request.headers.get('X-Sign', None) # 签名字符串
# token = Request.headers.get('X-Authorization', None)
# channel = Request.headers.get('X-Channel', 'Web') # 渠道
# # 是否是本地,测试用
# g.is_local = False
# g.app_id = app_id
# g.nonce = nonce
# g.timestamp = timestamp
# g.sign_str = sign_str
# g.token = token
# g.channel = channel
# g.Version = Request.headers.get('X-Version') # 软件版本
# g.OS = Request.headers.get('X-OS') # 操作系统版本
# audit_log = create_audit_log(func)
# res = check_sign(permission_name, is_check=is_check) # 校验签名
# if res[0] != 200:
# save_audit_log(audit_log, res[0], 1)
# if not args:
# return response(res[0], res[1])
# return response(res[0], res[1])
# if permission_name or may_user: # 非匿名访问, 需要判断是否登录,
# res = check_permission(permission_name, may_user) # 校验权限
# if res[0] != 200:
# save_audit_log(audit_log, res[0], 2)
# # 20190720 两种返回不一致会导致错误
# if not args:
# return response(res[0], res[1])
# return response(res[0], res[1])
# try:
# result = func(*args, **kwargs) # 实际执行的方法
# # 20190720 两种返回不一致会导致错误
# if not args:
# save_audit_log(audit_log, result.data, 0, result)
# elif isinstance(result, tuple):
# save_audit_log(audit_log, result[1], 0, result)
# return result
# except Exception as e:
# logger.warning(str(e))
# logger.error(format_exc())
# audit_log['Exception'] = format_exc()
# # save_audit_log(audit_log, res[0], 9)
# if not args:
# return response(500, str(e))
# return response(500, str(e))
# return wrapper
# return api_authorize_wrapper
# def check_sign(permission_name=None, is_check=True):
# """认证签名, 认证通过, 返回当前应用归属的租户"""
# app_id = g.app_id
# if not app_id:
# return 403, "header 请求头必须添加X-AppId 参数"
# res = common_biz.get_app_secret(app_id=app_id)
# if not res:
# g.TenantId = -1
# return 403, "获取应用服务失败,请刷新页面后重新尝试登陆"
# g.env = json.loads(res['Env']) if res['Env'] else None # 应用配置参数
# g.TenantId = res["TenantId"] # 当前组织机构 Id
# g.Application = res # 当前应用
# g.AppId = app_id # app_id
# channel = request.headers.get('X-Channel', 'Web') # 渠道
# g.Channel = channel
# code, msg = check_maintenance()
# if code != 200:
# return code, msg
# if not is_check:
# return check_encrypt() # 不需要签名认证,直接返回成功
# if channel == "Web":
# return check_encrypt() # 网页应用,不验证签名
# nonce = g.nonce
# timestamp = g.timestamp
# sign_str = g.sign_str
# if not nonce or not timestamp or not app_id or not sign_str:
# return 403, "签名认证失败,缺少签名参数"
# if time() - float(timestamp) > 300: # 超过5分钟,请求超时
# return 403, "签名认证失败,请求已超时"
# if sign_str != build_sign_str(res["AppSecret"], nonce, timestamp) and not g.is_local:
# return 403, "签名验证失败"
# return check_encrypt()
# def build_sign_str(app_secret, nonce, timestamp):
# """生成签名字符串"""
# values = [Request.method, request.path, app_secret, nonce, timestamp]
# values.sort()
# content = "".join(values) # 待签名字符串
# enc = hmac.new(app_secret.encode("utf-8"), content.encode("utf-8"), hashlib.sha256).digest()
# res = base64.b64encode(enc).decode()
# return res.replace("=", "%")
# def check_encrypt():
# """校验是否加解密"""
# return 200, 'ok'
# def get_params():
# """获取参数"""
# if Request.method in ['GET']:
# return json.dumps(request.args.to_dict(), ensure_ascii=False)
# elif request.method in ['POST', 'PUT'] and request.is_json:
# try:
# return json.dumps(request.json, ensure_ascii=False)
# except:
# return ''
# else:
# return ''

View File

@ -8,7 +8,7 @@ from pymysql import err
__author__ = 'Woodstock' __author__ = 'Woodstock'
logger = logging.getLogger(__name__) logger = logging.getLogger("uvicorn.error")
class DB: class DB:
CACHE = None CACHE = None
@ -56,8 +56,7 @@ class DB:
logger.error(log_string) logger.error(log_string)
raise raise
finally: finally:
if not temp_sql.lower().startswith("select") and not temp_sql.lower().startswith('insert into `sysauditlogs`'): # 不记录查询语句 logger.info(f"{log_string}\n耗时:{time.time() - t}")
logger.info("{}\n耗时:{}".format(log_string, time.time() - t))
# 受影响行数 # 受影响行数
self.rowcount = self.result.rowcount self.rowcount = self.result.rowcount
# 插入数据时返回的主键数据 # 插入数据时返回的主键数据

View File

@ -4,8 +4,8 @@ from pydantic import BaseModel
class ErrorModel(BaseModel): class ErrorModel(BaseModel):
code: int = 0 code: int = 200
messages: str = "" message: str = ""
details: str = "" details: str = ""
@ -18,6 +18,7 @@ class PageItemModel(BaseModel):
class BaseResponse(BaseModel): class BaseResponse(BaseModel):
result: Union[dict, list, Any] = None result: Union[dict, list, Any] = None
code: int = 200
error: ErrorModel = ErrorModel() error: ErrorModel = ErrorModel()
@ -34,10 +35,8 @@ class SendCaptchaSuccess(BaseResponse):
class BadRequestError(BaseResponse): class BadRequestError(BaseResponse):
result: Any = None error: ErrorModel = ErrorModel(**dict(code=400, message="BAD REQUEST", details="请求参数有误"))
error: ErrorModel = ErrorModel(**dict(code=400, messages="BAD REQUEST", details="请求参数有误"))
class ServerInternalError(BaseResponse): class ServerInternalError(BaseResponse):
result: Any = None error: ErrorModel = ErrorModel(**dict(code=500, message="INTERNAL SERVER ERROR", details="服务器内部错误"))
error: ErrorModel = ErrorModel(**dict(code=500, messages="INTERNAL SERVER ERROR", details="服务器内部错误"))

View File

@ -10,7 +10,6 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password): def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password): def get_password_hash(password):
return pwd_context.hash(password) return pwd_context.hash(password)
@ -22,7 +21,17 @@ def get_total_pages(total, limit):
:return: :return:
""" """
s = total / limit if total % limit == 0 else total / limit + 1 s = total / limit if total % limit == 0 else total / limit + 1
return round(s) return round(s)
def get_page_result(page, limit, rows, data):
"""获取分页结果
page 查询页
limit 每页大小
rows 总条数
data 当前页数据
"""
return {"totalRecords": rows, "totalPages": get_total_pages(rows, limit), "currentPage": page,
"items": data} # data
def create_batch(): def create_batch():
"""创建批号""" """创建批号"""

View File

@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse
from src.dtos import BaseResponse, ErrorModel from src.dtos import BaseResponse, ErrorModel
class ErrorCode: class HttpCodeMsg:
REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS" REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS"
LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS" LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS"
LOGIN_USER_NOT_VERIFIED = "LOGIN_USER_NOT_VERIFIED" LOGIN_USER_NOT_VERIFIED = "LOGIN_USER_NOT_VERIFIED"
@ -45,18 +45,18 @@ def request_validation_error_handler(request: Request, exc: RequestValidationErr
details.append(msg) details.append(msg)
else: else:
details.append(error["msg"]) details.append(error["msg"])
err = dict(code=status.HTTP_400_BAD_REQUEST, messages=ErrorCode.HTTP_400, details="; ".join(details)) err = dict(code=status.HTTP_400_BAD_REQUEST, messages=HttpCodeMsg.HTTP_400, details="; ".join(details))
res = BaseResponse(**dict(result=None, error=err)) res = BaseResponse(**dict(result=None, error=err))
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=res.dict()) return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=res.dict())
def http_exception_handler(request: Request, exc: HTTPException): def http_exception_handler(request: Request, exc: HTTPException):
if exc.status_code == 400: if exc.status_code == 400:
err = ErrorModel(**dict(code=status.HTTP_400_BAD_REQUEST, messages=ErrorCode.HTTP_400, details=exc.detail)) err = ErrorModel(**dict(code=status.HTTP_400_BAD_REQUEST, messages=HttpCodeMsg.HTTP_400, details=exc.detail))
elif exc.status_code == 401: elif exc.status_code == 401:
err = ErrorModel(**dict(code=status.HTTP_401_UNAUTHORIZED, messages=ErrorCode.HTTP_401, details=exc.detail)) err = ErrorModel(**dict(code=status.HTTP_401_UNAUTHORIZED, messages=HttpCodeMsg.HTTP_401, details=exc.detail))
elif exc.status_code == 500: elif exc.status_code == 500:
err = ErrorModel(**dict(code=status.HTTP_500_INTERNAL_SERVER_ERROR, messages=ErrorCode.HTTP_500, details=exc.detail)) err = ErrorModel(**dict(code=status.HTTP_500_INTERNAL_SERVER_ERROR, messages=HttpCodeMsg.HTTP_500, details=exc.detail))
else: else:
err = None # 未定义的错误消息 err = None # 未定义的错误消息
res = BaseResponse(**dict(result=None, error=err)) res = BaseResponse(**dict(result=None, error=err))