diff --git a/.gitignore b/.gitignore index 44b67bd..ff87c95 100644 --- a/.gitignore +++ b/.gitignore @@ -1,22 +1,142 @@ -.idea -.vscode -*.pyc +# Created by .ignore support plugin (hsz.mobi) +### Python template +# Byte-compiled / optimized / DLL files __pycache__/ -conf/conf-*.ini -my.setting -server.setting -static/img/banners/* -static/img/modals/* -static/img/avatars/* -static/img/filters/* -static/img/qrcode/* -static/img/qrcode_zip/* -static/img/qrcode_temp/* -log/*.log -log/*.lock -log/ ---ini -output.log -push_to_gitee.sh -/static/img/cash_out/ -/venv/ \ No newline at end of file +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +.idea/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000..bd2c766 --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,25 @@ +variables: + PROJECT_NAME: fastapi_app_template + BASE_VERSION: 1 + DOCKER_IMAGE_DOMAIN: docker.yingzhen1688.com + LATEST_VERSION: latest + +before_script: + - PATH=.:$PATH + - VERSION=$BASE_VERSION"."$CI_JOB_ID + +stages: + - build + +deploy-staging:dep: + stage: build + only: + - master + tags: + - master + script: + - echo "===== start build ==========" + - docker build -t ${DOCKER_IMAGE_DOMAIN}/library/${PROJECT_NAME}:${LATEST_VERSION} . + - docker push ${DOCKER_IMAGE_DOMAIN}/library/${PROJECT_NAME}:${LATEST_VERSION} + - docker rmi ${DOCKER_IMAGE_DOMAIN}/library/${PROJECT_NAME}:${LATEST_VERSION} + - echo "===== end build !!!!!! =====" diff --git a/conf/conf-dev.ini b/conf/conf-dev.ini new file mode 100644 index 0000000..4c97a07 --- /dev/null +++ b/conf/conf-dev.ini @@ -0,0 +1,20 @@ +[common] +secret_key=5cf366d7b06714683756ca019c6283ed440bf771edd281d8 +lifetime_seconds=3600 + +[mysql] +username=root +password=123456 +host=192.168.2.94 +port=3306 +database=admgs_v2 + + +[redis] +redis_host=192.168.2.94 +redis_port=6379 + +[rabbitmq] +rabbitmq_host=192.168.2.94 +rabbitmq_user=root +rabbitmq_password=123123 \ No newline at end of file diff --git a/conf/conf-prod.ini b/conf/conf-prod.ini new file mode 100644 index 0000000..4c97a07 --- /dev/null +++ b/conf/conf-prod.ini @@ -0,0 +1,20 @@ +[common] +secret_key=5cf366d7b06714683756ca019c6283ed440bf771edd281d8 +lifetime_seconds=3600 + +[mysql] +username=root +password=123456 +host=192.168.2.94 +port=3306 +database=admgs_v2 + + +[redis] +redis_host=192.168.2.94 +redis_port=6379 + +[rabbitmq] +rabbitmq_host=192.168.2.94 +rabbitmq_user=root +rabbitmq_password=123123 \ No newline at end of file diff --git a/config.py b/config.py index 00b8310..0be5f33 100644 --- a/config.py +++ b/config.py @@ -1,8 +1,9 @@ import os from configparser import ConfigParser from typing import Optional -from pydantic import BaseSettings + from fastapi_plugins import RedisSettings +from pydantic import BaseSettings class ReConfigParser(ConfigParser): @@ -13,12 +14,12 @@ class ReConfigParser(ConfigParser): def optionxform(self, optionstr): return optionstr -# class CommonConfig(BaseSettings): -# SECRET_KEY: str = os.urandom(32) -# PROJECT_NAME: str -# API_V1_STR: str -# # 允许访问的origins -# BACKEND_CORS_ORIGINS: str +class CommonConfig(BaseSettings): + SECRET_KEY: str = os.urandom(32) + PROJECT_NAME: str + API_V1_STR: str + # 允许访问的origins + BACKEND_CORS_ORIGINS: str class MySQLConfig(BaseSettings): @@ -42,6 +43,12 @@ class RedisConfig(RedisSettings): redis_connection_timeout: int = 2 +class RabbitmqConfig(BaseSettings): + rabbitmq_host: str + rabbitmq_user: str + rabbitmq_password: str + + def init_config(): """初始化配置文件""" print("加载配置文件...") @@ -52,9 +59,11 @@ def init_config(): else: config.read(os.path.join('.', 'conf', 'conf-dev.ini'), encoding='utf-8') # common_config = CommonConfig(**dict(config.items('common'))) + common_config = CommonConfig(**dict(config.items('common'))) mysql_config = MySQLConfig(**dict(config.items('mysql'))) - print(mysql_config) redis_config = RedisConfig(**dict(config.items('redis'))) + rabbitmq_config = RabbitmqConfig(**dict(config.items("rabbitmq"))) + # return common_config, mysql_config, redis_config, rabbitmq_config return mysql_config, redis_config except Exception as e: print(e) diff --git a/main.py b/main.py index d1297a3..1cce106 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,10 @@ +import logging.config as logging_config import os + +import fastapi_plugins from fastapi import FastAPI from fastapi_sqlalchemy import DBSessionMiddleware -import fastapi_plugins -import logging.config as logging_config + from config import init_config @@ -27,7 +29,7 @@ def create_app(): await fastapi_plugins.redis_plugin.terminate() # 在这里添加API route - from api import test + from src.api import test app.include_router(test.router) return app diff --git a/api/__init__.py b/src/api/__init__.py similarity index 100% rename from api/__init__.py rename to src/api/__init__.py diff --git a/api/test.py b/src/api/test.py similarity index 97% rename from api/test.py rename to src/api/test.py index f2f3cbf..ca29949 100644 --- a/api/test.py +++ b/src/api/test.py @@ -1,14 +1,15 @@ # Test Router 页面 from typing import Optional -from fastapi import APIRouter, Query, Depends -from fastapi.security import OAuth2PasswordBearer import aioredis +from fastapi import APIRouter, Depends, Query +from fastapi.security import OAuth2PasswordBearer from fastapi_plugins import depends_redis from fastapi_sqlalchemy import db -from models.devices_place import DevicesPlace -from pydantic_sqlalchemy import sqlalchemy_to_pydantic from pydantic import BaseModel +from pydantic_sqlalchemy import sqlalchemy_to_pydantic + +from models.devices_place import DevicesPlace router = APIRouter(prefix="/test") diff --git a/src/dtos/__init__.py b/src/dtos/__init__.py new file mode 100644 index 0000000..7e84842 --- /dev/null +++ b/src/dtos/__init__.py @@ -0,0 +1,43 @@ +from typing import Any, List, Union + +from pydantic import BaseModel + + +class ErrorModel(BaseModel): + code: int = 0 + messages: str = "" + details: str = "" + + +class PageItemModel(BaseModel): + totalRecords: int = 0 + totalPages: int = 0 + currentPage: int = 0 + items: List[Any] = [] + + +class BaseResponse(BaseModel): + result: Union[dict, list, Any] = None + error: ErrorModel = ErrorModel() + + +class ListResponse(BaseResponse): + result: PageItemModel = PageItemModel() + + +class CommonSuccess(BaseResponse): + result: dict = dict(success=True) + + +class SendCaptchaSuccess(BaseResponse): + result: dict = dict(success=True, token="", image="") + + +class BadRequestError(BaseResponse): + result: Any = None + error: ErrorModel = ErrorModel(**dict(code=400, messages="BAD REQUEST", details="请求参数有误")) + + +class ServerInternalError(BaseResponse): + result: Any = None + error: ErrorModel = ErrorModel(**dict(code=500, messages="INTERNAL SERVER ERROR", details="服务器内部错误")) \ No newline at end of file diff --git a/models/__init__.py b/src/models/__init__.py similarity index 100% rename from models/__init__.py rename to src/models/__init__.py diff --git a/models/devices_place.py b/src/models/devices_place.py similarity index 100% rename from models/devices_place.py rename to src/models/devices_place.py diff --git a/utils/__init__.py b/src/utils/__init__.py similarity index 100% rename from utils/__init__.py rename to src/utils/__init__.py diff --git a/utils/captcha_tools.py b/src/utils/captcha_tools.py similarity index 98% rename from utils/captcha_tools.py rename to src/utils/captcha_tools.py index db6fd37..ed6ef97 100644 --- a/utils/captcha_tools.py +++ b/src/utils/captcha_tools.py @@ -1,17 +1,16 @@ -import random -import logging import base64 -from os import urandom +import logging +import random from io import BytesIO -from PIL import ImageFilter +from os import urandom + from captcha.image import ImageCaptcha, random_color +from PIL import ImageFilter __author__ = 'Woodstock' -__doc__ = """ -验证码图片生成 -""" +__doc__ = """验证码图片生成""" logger = logging.getLogger(__name__) diff --git a/src/utils/common_tools.py b/src/utils/common_tools.py new file mode 100644 index 0000000..f146a4b --- /dev/null +++ b/src/utils/common_tools.py @@ -0,0 +1,37 @@ +import random +import re +from datetime import datetime + +from passlib.context import CryptContext + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +def verify_password(plain_password, hashed_password): + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password): + return pwd_context.hash(password) + + +def create_batch(): + """创建批号""" + return datetime.now().strftime('%Y%m%d%H%M%S') + str(random.random())[-6:] + +def validate_email(email: str) -> bool: + """验证邮箱""" + pattern = r"[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(?:\.[a-zA-Z0-9_-]+)" + if re.match(pattern, email) is not None: + return True + else: + return False + +def validate_phone_number(phone_number: str) -> bool: + """验证手机号""" + pattern = r"1[356789]\d{9}" + if re.match(pattern, phone_number) is not None: + return True + else: + return False + diff --git a/src/utils/exception_tools.py b/src/utils/exception_tools.py new file mode 100644 index 0000000..170e280 --- /dev/null +++ b/src/utils/exception_tools.py @@ -0,0 +1,63 @@ +from fastapi import Request, status +from fastapi.exceptions import HTTPException, RequestValidationError +from fastapi.responses import JSONResponse + +from src.dtos import BaseResponse, ErrorModel + + +class ErrorCode: + REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS" + LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS" + LOGIN_USER_NOT_VERIFIED = "LOGIN_USER_NOT_VERIFIED" + RESET_PASSWORD_BAD_TOKEN = "RESET_PASSWORD_BAD_TOKEN" + VERIFY_USER_BAD_TOKEN = "VERIFY_USER_BAD_TOKEN" + VERIFY_USER_ALREADY_VERIFIED = "VERIFY_USER_ALREADY_VERIFIED" + VERIFY_USER_TOKEN_EXPIRED = "VERIFY_USER_TOKEN_EXPIRED" + HTTP_200 = "OK" + HTTP_400 = "BAD_REQUEST" + HTTP_401 = "UNAUTHORIZED" + HTTP_403 = "FORBIDDEN" + HTTP_404 = "NOT_FOUND" + HTTP_405 = "METHOD_NOT_ALLOWED" + HTTP_500 = "INTERNAL_SERVER_ERROR" + HTTP_502 = "BAD_GATEWAY" + HTTP_503 = "SERVICE_UNAVAILABLE" + HTTP_504 = "GATEWAY_TIMEOUT" + + +def request_validation_error_handler(request: Request, exc: RequestValidationError): + errors = exc.errors() + print(errors) + details = [] + for error in errors: + err_type = error["type"].split(".") + if len(err_type) == 2: + if err_type[0] == "value_error": + if err_type[1] == "missing": + if len(error["loc"]) == 2: + msg = f"{error['loc'][0]}: '{error['loc'][1]}' is required!" + else: + msg = "value is missing!" + else: + msg = error["msg"].replace("this value", f"'{error['loc'][1]}' value") + else: + msg = error["msg"] + details.append(msg) + else: + details.append(error["msg"]) + err = dict(code=status.HTTP_400_BAD_REQUEST, messages=ErrorCode.HTTP_400, details="; ".join(details)) + res = BaseResponse(**dict(result=None, error=err)) + return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=res.dict()) + + +def http_exception_handler(request: Request, exc: HTTPException): + if exc.status_code == 400: + err = ErrorModel(**dict(code=status.HTTP_400_BAD_REQUEST, messages=ErrorCode.HTTP_400, details=exc.detail)) + elif exc.status_code == 401: + err = ErrorModel(**dict(code=status.HTTP_401_UNAUTHORIZED, messages=ErrorCode.HTTP_401, details=exc.detail)) + elif exc.status_code == 500: + err = ErrorModel(**dict(code=status.HTTP_500_INTERNAL_SERVER_ERROR, messages=ErrorCode.HTTP_500, details=exc.detail)) + else: + err = None # 未定义的错误消息 + res = BaseResponse(**dict(result=None, error=err)) + return JSONResponse(status_code=exc.status_code, content=res.dict()) \ No newline at end of file diff --git a/src/utils/file_upload_tools.py b/src/utils/file_upload_tools.py new file mode 100644 index 0000000..431a4d2 --- /dev/null +++ b/src/utils/file_upload_tools.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +import logging +import mimetypes +import os +import random +import shutil +import time +import traceback +import urllib.request + +from fastapi import UploadFile + +from .qiniu_tools import QiniuYuProvider + +# http_base = "" # 外网范围基础地址 +# path_base = "" # 文件保存基础地址 + +logger = logging.getLogger(__name__) + + +def init(baseurl, base_dir): + """初始化""" + global http_base, path_base + http_base = baseurl + path_base = base_dir + + +def del_cloud_file(cfg, key): + """ + 删除云服务文件 + :param cfg: + :param key: + :return: + """ + api = QiniuYuProvider(**cfg) + return api.del_file(key=key) + + +def save_cloud_file(file: UploadFile, cfg, user=None): + """ + 保存文件至云服务 + :param file: + :param cfg: + :param user: + :return: + """ + try: + ext = get_file_ext(file.filename) + if ext not in ['jpg', 'png', 'jpeg', 'gif', 'aac', 'mp3', 'amr', 'apk', 'webm', 'mp4', 'mov', 'flv', 'mkv', + '3gp', 'avi']: + return 403, '文件格式不合法,请上传正确的文件格式', None + filename = build_file_name(ext) # 新文件名 + file.filename = filename + uploaded_file_path = os.path.join("files", filename) + # flask 原有的图片保存方式 + # file.save(uploaded_file_path) + # fastApi 实现的图片保存方式 + with open(uploaded_file_path, "wb+") as buffer: + shutil.copyfileobj(file.file, buffer) + api = QiniuYuProvider(**cfg) + res = api.upload_file(file=uploaded_file_path, types=ext) + if res[0] == 200: + if os.path.exists(uploaded_file_path): + os.remove(uploaded_file_path) + return res + except Exception as e: + logger.error(traceback.format_exc()) + return 500, '保存文件失败', None + + +def save_file(file, save_db=None, user=None): + """保存文件,返回文件路径""" + try: + ext = get_file_ext(file.filename) + if ext not in ['jpg', 'png', 'jpeg', 'gif', 'aac', 'mp3', 'amr', 'apk', 'webm', 'mp4', 'mov', 'flv', 'mkv', + '3gp', 'avi']: + return 403, '文件格式不合法,请上传正确的文件格式', None + filename = build_file_name(ext) # 新文件名 + date_str = time.strftime('%Y%m%d') + base_path = os.path.join(path_base, 'static', 'mr', date_str) + if not os.path.exists(base_path): # 确保当日的文件夹存在 + os.makedirs(base_path) + filepath = os.path.join(base_path, filename) # 实际文件路径 + file.save(filepath) # 保存文件 + if save_db: + data = { + 'http_base': http_base, + 'file_base': path_base, + 'file_name': file.filename, # 原始文件名 + 'file_path': '/static/mr/{date_str}/{filename}'.format(date_str=date_str, filename=filename), + 'file_type': get_file_ext(file.filename), # 文件扩展名 + 'mime_type': mimetypes.guess_type(file.filename)[0], # 媒体类型 + 'file_size': get_file_size(filepath), + 'file_size_str': get_file_size_str(filepath) + } + return save_db(data, user) + image_url = "{}/static/mr/{}/{}".format(http_base, date_str, filename) + return 200, "", image_url + except Exception as e: + logger.error(e) + return 500, '保存文件失败', None + + +def save_image(body): + """ + 保存文件返回文件url + @param body 文件内容字典,包括url,md5,ext,size等信息 + """ + url = body.get('url', '') + ext = body.get('ext', 'jpg') + filename = build_file_name(ext) # 新文件名 + if url != '': + date_str = time.strftime('%Y%m%d') + base_path = os.path.join(path_base, 'static', 'im', date_str) + if not os.path.exists(base_path): # 确保当日的文件夹存在 + os.makedirs(base_path) + filepath = os.path.join(base_path, filename) # 实际文件路径 + try: + urllib.request.urlretrieve(url, filepath) + + url = "{}/static/im/{}/{}".format(http_base, date_str, filename) + except Exception as e: + logger.error(e) + return url # 返回图片url + + +def get_file_ext(filename): + """获取文件扩展名""" + if '.' in filename: + return filename.rsplit('.', 1)[1].lower() + return None + + +def build_file_name(ext): + """生成文件名""" + filename = time.strftime('%H%M%S') + str(random.random())[-6:] + return '{}.{}'.format(filename, ext) # 新文件名 + + +def get_file_path(filename): + """获取文件网站路径""" + return "{path_base}/{filename}".format(path_base=path_base, filename=filename) + + +def get_file_size(file_path): + """获取文件大小""" + file_size = os.path.getsize(file_path) + return file_size + + +def get_file_size_str(file_path): + """获取文件大小字符串""" + file_size = get_file_size(file_path) + + def strofsize(integer, remainder, level): + if integer >= 1024: + remainder = integer % 1024 + integer //= 1024 + level += 1 + return strofsize(integer, remainder, level) + else: + return integer, remainder, level + + units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] + integer, remainder, level = strofsize(file_size, 0, 0) + if level + 1 > len(units): + level = -1 + return '{}.{:>03d} {}'.format(integer, remainder, units[level]) diff --git a/src/utils/qiniu_tools.py b/src/utils/qiniu_tools.py new file mode 100644 index 0000000..5b9db9c --- /dev/null +++ b/src/utils/qiniu_tools.py @@ -0,0 +1,103 @@ + +import datetime +import logging +import random +import traceback +import uuid + +import qiniu.config +from qiniu import Auth, BucketManager, etag, put_data, put_file + +__author__ = 'Woodstock' + +__doc__ = """七牛上传""" + +logger = logging.getLogger(__name__) + + +class QiniuYuProvider(object): + + def __init__(self, type, name, access_key, secret_key, bucket_name, policy, url='http://dealer-public.hyd169.com'): + """ + :param type: + :param name: + :param access_key: + :param secret_key: + :param bucket_name: + :param policy: + :param url: + """ + self.type = type + self.name = name + self.access_key = access_key + self.secret_key = secret_key + self.bucket_name = bucket_name + self.policy = policy + self.url = url + + def del_file(self, key): + """ + 删除已上传的文件 + :param key: + :return: + """ + try: + q = Auth(self.access_key, self.secret_key) + bucket = BucketManager(q) + + # 删除bucket_name 中的文件 key + ret, info = bucket.delete(self.bucket_name, key) + status_code = info.status_code + if status_code == 200: + return 200, '删除图片成功' + return 500, '删除图片失败' + except Exception as e: + logger.error(traceback.format_exc()) + return 500, '删除图片异常' + + def upload_file(self, file, types=None): + """ + 上传文件 + :param file: + :param types: + :return: + """ + try: + q = Auth(self.access_key, self.secret_key) + + key = self.generate_unique_key(token=uuid.uuid4(), types=types) + if not key or key == '': + raise ValueError('Generate unique key error') + + token = q.upload_token(self.bucket_name, key) + ret, info = put_file(token, key, file) + status_code = info.status_code + if status_code == 200: + assert ret['key'] == key + assert ret['hash'] == etag(file) + data = dict( + req_id=info.req_id, + key=ret['key'], + hash=ret['hash'], + url=self.url + '/' + ret['key'], + type=types, + file=file + ) + return 200, '上传成功', data + return 500, '上传失败', None + except Exception as e: + logger.error(traceback.format_exc()) + return 500, '上传失败', None + + @staticmethod + def generate_unique_key(token, types): + """ + :param token: + :param types: + :return: + """ + import pytz + tz = pytz.timezone('Asia/Shanghai') + key = str(token) + '-' + str(datetime.datetime.now(tz).microsecond) + '-' + str(random.random()) + '.' + str( + types) + return key diff --git a/utils/sms_tools.py b/src/utils/sms_tools.py similarity index 98% rename from utils/sms_tools.py rename to src/utils/sms_tools.py index 99fb3db..378ded6 100644 --- a/utils/sms_tools.py +++ b/src/utils/sms_tools.py @@ -1,17 +1,17 @@ -import json -import uuid -import time -import hmac -import random import base64 -import hashlib -import logging import datetime -import requests +import hashlib +import hmac +import json +import logging +import random +import time import traceback -from urllib.parse import urlencode, quote +import uuid +from urllib.parse import quote, urlencode +import requests __author__ = 'Woodstock' diff --git a/test/test_api.py b/test/test_api.py index 1878f5c..d2a45ee 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -1,4 +1,5 @@ import unittest + import requests diff --git a/utils/token_tools.py b/utils/token_tools.py deleted file mode 100644 index ba36a4f..0000000 --- a/utils/token_tools.py +++ /dev/null @@ -1,77 +0,0 @@ -import jwt -import aioredis -from datetime import datetime, timedelta -from fastapi import status, Depends -from fastapi_plugins import depends_redis -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from fastapi.exceptions import HTTPException -from utils.user_tools import get_user_by_email, authenticate_user -from common import schemas - -SECRET_KEY = "5cf366d7b06714683756ca019c6283ed440bf771edd281d8" -tokenUrl = "/token" - -oauth2_scheme = OAuth2PasswordBearer( - tokenUrl=tokenUrl, - scopes={"me": "Read information about the current user.", "items": "Read items."}, -) - - -async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), - cache: aioredis.Redis = Depends(depends_redis)): - """验证成功后返回有效的用户访问令牌""" - user = authenticate_user(form_data) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - headers={"WWW-Authenticate": "Bearer"}, - ) - access_token_expires = timedelta(minutes=60) - access_token = create_access_token(data={"sub": user.Email}, expires_delta=access_token_expires) - # 将access_token 存到缓存里面 - await cache.setex("TOKEN:" + str(user.Id), access_token_expires.seconds, access_token) - return dict(access_token=access_token, access_type="bearer") - - -def create_access_token(*, data: dict, - expires_delta: timedelta = None, - algorithm="HS256"): - to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=15) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm) - return encoded_jwt - - -async def decode_access_token(token, cache: aioredis.Redis, algorithm="HS256"): - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[algorithm]) - email: str = payload.get("sub") - if email is None: - raise credentials_exception - token_data = schemas.TokenData(email=email) - except jwt.PyJWTError: - raise credentials_exception - user = get_user_by_email(email=token_data.email) - if user is None: - raise credentials_exception - # 查询缓存 - cached_token = await cache.get("TOKEN:" + str(user.Id), encoding="utf-8") - if cached_token is None or cached_token != token: - raise credentials_exception - return user - - -async def get_current_user(token: str = Depends(oauth2_scheme), - cache: aioredis.Redis = Depends(depends_redis)): - current_user = await decode_access_token(token, cache) - return current_user diff --git a/utils/user_tools.py b/utils/user_tools.py deleted file mode 100644 index 8d12aa4..0000000 --- a/utils/user_tools.py +++ /dev/null @@ -1,92 +0,0 @@ -import re -from typing import Optional -from fastapi import status -from fastapi.security import OAuth2PasswordRequestForm -from fastapi.exceptions import HTTPException -from fastapi_sqlalchemy import db -from models.users import SysUser -from common.schemas import UserCreateUpdate -from common.password import verify_password, get_password_hash - - -def validate_email(email: str) -> bool: - pattern = r"[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(?:\.[a-zA-Z0-9_-]+)" - if re.match(pattern, email) is not None: - return True - else: - return False - - -def validate_phone_number(phone_number: str) -> object: - pattern = r"1[356789]\d{9}" - if re.match(pattern, phone_number) is not None: - return True - else: - return False - - -def get_user_by_email(email: str) -> Optional[SysUser]: - if validate_email(email) is True: - user = db.session.query(SysUser).filter(SysUser.Email == email).first() - return user - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="无效的邮箱地址" - ) - - -def get_user_by_phone_number(phone_number: str) -> Optional[SysUser]: - if validate_phone_number(phone_number) is True: - user = db.session.query(SysUser).filter(SysUser.PhoneNumber == phone_number).first() - return user - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="无效的手机号码格式" - ) - - -def get_user_by_username(username: str) -> Optional[SysUser]: - user = db.session.query(SysUser).filter(SysUser.UserName == username).first() - return user - - -def authenticate_user(form_data: OAuth2PasswordRequestForm) -> Optional[SysUser]: - if validate_email(form_data.username) is True: - user = get_user_by_email(form_data.username) - else: - if validate_phone_number(form_data.username) is True: - user = get_user_by_phone_number(form_data.username) - else: - user = get_user_by_username(form_data.username) - if not user: - return None - if not verify_password(form_data.password, user.Password): - return None - return user - - -def auth_user(username: str, password: str) -> Optional[SysUser]: - if validate_email(username) is True: - user = get_user_by_email(username) - else: - if validate_phone_number(username) is True: - user = get_user_by_phone_number(username) - else: - user = get_user_by_username(username) - if not user: - return None - if not verify_password(password, user.Password): - return None - return user - - -def create_user(user_data: UserCreateUpdate): - user_data.Password = get_password_hash(user_data.Password) - user = SysUser(**user_data.dict()) - db.session.add(user) - db.session.commit() - return user - -