fix: 重置开发模板

This commit is contained in:
chenwj 2023-01-06 14:39:08 +08:00
parent 34510b0686
commit 4cf61a525b
21 changed files with 663 additions and 221 deletions

162
.gitignore vendored
View File

@ -1,22 +1,142 @@
.idea
.vscode
*.pyc
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
conf/conf-*.ini
my.setting
server.setting
static/img/banners/*
static/img/modals/*
static/img/avatars/*
static/img/filters/*
static/img/qrcode/*
static/img/qrcode_zip/*
static/img/qrcode_temp/*
log/*.log
log/*.lock
log/
--ini
output.log
push_to_gitee.sh
/static/img/cash_out/
/venv/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
.idea/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/

25
.gitlab-ci.yml Normal file
View File

@ -0,0 +1,25 @@
variables:
PROJECT_NAME: fastapi_app_template
BASE_VERSION: 1
DOCKER_IMAGE_DOMAIN: docker.yingzhen1688.com
LATEST_VERSION: latest
before_script:
- PATH=.:$PATH
- VERSION=$BASE_VERSION"."$CI_JOB_ID
stages:
- build
deploy-staging:dep:
stage: build
only:
- master
tags:
- master
script:
- echo "===== start build =========="
- docker build -t ${DOCKER_IMAGE_DOMAIN}/library/${PROJECT_NAME}:${LATEST_VERSION} .
- docker push ${DOCKER_IMAGE_DOMAIN}/library/${PROJECT_NAME}:${LATEST_VERSION}
- docker rmi ${DOCKER_IMAGE_DOMAIN}/library/${PROJECT_NAME}:${LATEST_VERSION}
- echo "===== end build !!!!!! ====="

20
conf/conf-dev.ini Normal file
View File

@ -0,0 +1,20 @@
[common]
secret_key=5cf366d7b06714683756ca019c6283ed440bf771edd281d8
lifetime_seconds=3600
[mysql]
username=root
password=123456
host=192.168.2.94
port=3306
database=admgs_v2
[redis]
redis_host=192.168.2.94
redis_port=6379
[rabbitmq]
rabbitmq_host=192.168.2.94
rabbitmq_user=root
rabbitmq_password=123123

20
conf/conf-prod.ini Normal file
View File

@ -0,0 +1,20 @@
[common]
secret_key=5cf366d7b06714683756ca019c6283ed440bf771edd281d8
lifetime_seconds=3600
[mysql]
username=root
password=123456
host=192.168.2.94
port=3306
database=admgs_v2
[redis]
redis_host=192.168.2.94
redis_port=6379
[rabbitmq]
rabbitmq_host=192.168.2.94
rabbitmq_user=root
rabbitmq_password=123123

View File

@ -1,8 +1,9 @@
import os
from configparser import ConfigParser
from typing import Optional
from pydantic import BaseSettings
from fastapi_plugins import RedisSettings
from pydantic import BaseSettings
class ReConfigParser(ConfigParser):
@ -13,12 +14,12 @@ class ReConfigParser(ConfigParser):
def optionxform(self, optionstr):
return optionstr
# class CommonConfig(BaseSettings):
# SECRET_KEY: str = os.urandom(32)
# PROJECT_NAME: str
# API_V1_STR: str
# # 允许访问的origins
# BACKEND_CORS_ORIGINS: str
class CommonConfig(BaseSettings):
SECRET_KEY: str = os.urandom(32)
PROJECT_NAME: str
API_V1_STR: str
# 允许访问的origins
BACKEND_CORS_ORIGINS: str
class MySQLConfig(BaseSettings):
@ -42,6 +43,12 @@ class RedisConfig(RedisSettings):
redis_connection_timeout: int = 2
class RabbitmqConfig(BaseSettings):
rabbitmq_host: str
rabbitmq_user: str
rabbitmq_password: str
def init_config():
"""初始化配置文件"""
print("加载配置文件...")
@ -52,9 +59,11 @@ def init_config():
else:
config.read(os.path.join('.', 'conf', 'conf-dev.ini'), encoding='utf-8')
# common_config = CommonConfig(**dict(config.items('common')))
common_config = CommonConfig(**dict(config.items('common')))
mysql_config = MySQLConfig(**dict(config.items('mysql')))
print(mysql_config)
redis_config = RedisConfig(**dict(config.items('redis')))
rabbitmq_config = RabbitmqConfig(**dict(config.items("rabbitmq")))
# return common_config, mysql_config, redis_config, rabbitmq_config
return mysql_config, redis_config
except Exception as e:
print(e)

View File

@ -1,8 +1,10 @@
import logging.config as logging_config
import os
import fastapi_plugins
from fastapi import FastAPI
from fastapi_sqlalchemy import DBSessionMiddleware
import fastapi_plugins
import logging.config as logging_config
from config import init_config
@ -27,7 +29,7 @@ def create_app():
await fastapi_plugins.redis_plugin.terminate()
# 在这里添加API route
from api import test
from src.api import test
app.include_router(test.router)
return app

View File

@ -1,14 +1,15 @@
# Test Router 页面
from typing import Optional
from fastapi import APIRouter, Query, Depends
from fastapi.security import OAuth2PasswordBearer
import aioredis
from fastapi import APIRouter, Depends, Query
from fastapi.security import OAuth2PasswordBearer
from fastapi_plugins import depends_redis
from fastapi_sqlalchemy import db
from models.devices_place import DevicesPlace
from pydantic_sqlalchemy import sqlalchemy_to_pydantic
from pydantic import BaseModel
from pydantic_sqlalchemy import sqlalchemy_to_pydantic
from models.devices_place import DevicesPlace
router = APIRouter(prefix="/test")

43
src/dtos/__init__.py Normal file
View File

@ -0,0 +1,43 @@
from typing import Any, List, Union
from pydantic import BaseModel
class ErrorModel(BaseModel):
code: int = 0
messages: str = ""
details: str = ""
class PageItemModel(BaseModel):
totalRecords: int = 0
totalPages: int = 0
currentPage: int = 0
items: List[Any] = []
class BaseResponse(BaseModel):
result: Union[dict, list, Any] = None
error: ErrorModel = ErrorModel()
class ListResponse(BaseResponse):
result: PageItemModel = PageItemModel()
class CommonSuccess(BaseResponse):
result: dict = dict(success=True)
class SendCaptchaSuccess(BaseResponse):
result: dict = dict(success=True, token="<user_token:str>", image="<base64_str:str>")
class BadRequestError(BaseResponse):
result: Any = None
error: ErrorModel = ErrorModel(**dict(code=400, messages="BAD REQUEST", details="请求参数有误"))
class ServerInternalError(BaseResponse):
result: Any = None
error: ErrorModel = ErrorModel(**dict(code=500, messages="INTERNAL SERVER ERROR", details="服务器内部错误"))

View File

@ -1,17 +1,16 @@
import random
import logging
import base64
from os import urandom
import logging
import random
from io import BytesIO
from PIL import ImageFilter
from os import urandom
from captcha.image import ImageCaptcha, random_color
from PIL import ImageFilter
__author__ = 'Woodstock'
__doc__ = """
验证码图片生成
"""
__doc__ = """验证码图片生成"""
logger = logging.getLogger(__name__)

37
src/utils/common_tools.py Normal file
View File

@ -0,0 +1,37 @@
import random
import re
from datetime import datetime
from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
return pwd_context.hash(password)
def create_batch():
"""创建批号"""
return datetime.now().strftime('%Y%m%d%H%M%S') + str(random.random())[-6:]
def validate_email(email: str) -> bool:
"""验证邮箱"""
pattern = r"[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(?:\.[a-zA-Z0-9_-]+)"
if re.match(pattern, email) is not None:
return True
else:
return False
def validate_phone_number(phone_number: str) -> bool:
"""验证手机号"""
pattern = r"1[356789]\d{9}"
if re.match(pattern, phone_number) is not None:
return True
else:
return False

View File

@ -0,0 +1,63 @@
from fastapi import Request, status
from fastapi.exceptions import HTTPException, RequestValidationError
from fastapi.responses import JSONResponse
from src.dtos import BaseResponse, ErrorModel
class ErrorCode:
REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS"
LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS"
LOGIN_USER_NOT_VERIFIED = "LOGIN_USER_NOT_VERIFIED"
RESET_PASSWORD_BAD_TOKEN = "RESET_PASSWORD_BAD_TOKEN"
VERIFY_USER_BAD_TOKEN = "VERIFY_USER_BAD_TOKEN"
VERIFY_USER_ALREADY_VERIFIED = "VERIFY_USER_ALREADY_VERIFIED"
VERIFY_USER_TOKEN_EXPIRED = "VERIFY_USER_TOKEN_EXPIRED"
HTTP_200 = "OK"
HTTP_400 = "BAD_REQUEST"
HTTP_401 = "UNAUTHORIZED"
HTTP_403 = "FORBIDDEN"
HTTP_404 = "NOT_FOUND"
HTTP_405 = "METHOD_NOT_ALLOWED"
HTTP_500 = "INTERNAL_SERVER_ERROR"
HTTP_502 = "BAD_GATEWAY"
HTTP_503 = "SERVICE_UNAVAILABLE"
HTTP_504 = "GATEWAY_TIMEOUT"
def request_validation_error_handler(request: Request, exc: RequestValidationError):
errors = exc.errors()
print(errors)
details = []
for error in errors:
err_type = error["type"].split(".")
if len(err_type) == 2:
if err_type[0] == "value_error":
if err_type[1] == "missing":
if len(error["loc"]) == 2:
msg = f"{error['loc'][0]}: '{error['loc'][1]}' is required!"
else:
msg = "value is missing!"
else:
msg = error["msg"].replace("this value", f"'{error['loc'][1]}' value")
else:
msg = error["msg"]
details.append(msg)
else:
details.append(error["msg"])
err = dict(code=status.HTTP_400_BAD_REQUEST, messages=ErrorCode.HTTP_400, details="; ".join(details))
res = BaseResponse(**dict(result=None, error=err))
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=res.dict())
def http_exception_handler(request: Request, exc: HTTPException):
if exc.status_code == 400:
err = ErrorModel(**dict(code=status.HTTP_400_BAD_REQUEST, messages=ErrorCode.HTTP_400, details=exc.detail))
elif exc.status_code == 401:
err = ErrorModel(**dict(code=status.HTTP_401_UNAUTHORIZED, messages=ErrorCode.HTTP_401, details=exc.detail))
elif exc.status_code == 500:
err = ErrorModel(**dict(code=status.HTTP_500_INTERNAL_SERVER_ERROR, messages=ErrorCode.HTTP_500, details=exc.detail))
else:
err = None # 未定义的错误消息
res = BaseResponse(**dict(result=None, error=err))
return JSONResponse(status_code=exc.status_code, content=res.dict())

View File

@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
import logging
import mimetypes
import os
import random
import shutil
import time
import traceback
import urllib.request
from fastapi import UploadFile
from .qiniu_tools import QiniuYuProvider
# http_base = "" # 外网范围基础地址
# path_base = "" # 文件保存基础地址
logger = logging.getLogger(__name__)
def init(baseurl, base_dir):
"""初始化"""
global http_base, path_base
http_base = baseurl
path_base = base_dir
def del_cloud_file(cfg, key):
"""
删除云服务文件
:param cfg:
:param key:
:return:
"""
api = QiniuYuProvider(**cfg)
return api.del_file(key=key)
def save_cloud_file(file: UploadFile, cfg, user=None):
"""
保存文件至云服务
:param file:
:param cfg:
:param user:
:return:
"""
try:
ext = get_file_ext(file.filename)
if ext not in ['jpg', 'png', 'jpeg', 'gif', 'aac', 'mp3', 'amr', 'apk', 'webm', 'mp4', 'mov', 'flv', 'mkv',
'3gp', 'avi']:
return 403, '文件格式不合法,请上传正确的文件格式', None
filename = build_file_name(ext) # 新文件名
file.filename = filename
uploaded_file_path = os.path.join("files", filename)
# flask 原有的图片保存方式
# file.save(uploaded_file_path)
# fastApi 实现的图片保存方式
with open(uploaded_file_path, "wb+") as buffer:
shutil.copyfileobj(file.file, buffer)
api = QiniuYuProvider(**cfg)
res = api.upload_file(file=uploaded_file_path, types=ext)
if res[0] == 200:
if os.path.exists(uploaded_file_path):
os.remove(uploaded_file_path)
return res
except Exception as e:
logger.error(traceback.format_exc())
return 500, '保存文件失败', None
def save_file(file, save_db=None, user=None):
"""保存文件,返回文件路径"""
try:
ext = get_file_ext(file.filename)
if ext not in ['jpg', 'png', 'jpeg', 'gif', 'aac', 'mp3', 'amr', 'apk', 'webm', 'mp4', 'mov', 'flv', 'mkv',
'3gp', 'avi']:
return 403, '文件格式不合法,请上传正确的文件格式', None
filename = build_file_name(ext) # 新文件名
date_str = time.strftime('%Y%m%d')
base_path = os.path.join(path_base, 'static', 'mr', date_str)
if not os.path.exists(base_path): # 确保当日的文件夹存在
os.makedirs(base_path)
filepath = os.path.join(base_path, filename) # 实际文件路径
file.save(filepath) # 保存文件
if save_db:
data = {
'http_base': http_base,
'file_base': path_base,
'file_name': file.filename, # 原始文件名
'file_path': '/static/mr/{date_str}/{filename}'.format(date_str=date_str, filename=filename),
'file_type': get_file_ext(file.filename), # 文件扩展名
'mime_type': mimetypes.guess_type(file.filename)[0], # 媒体类型
'file_size': get_file_size(filepath),
'file_size_str': get_file_size_str(filepath)
}
return save_db(data, user)
image_url = "{}/static/mr/{}/{}".format(http_base, date_str, filename)
return 200, "", image_url
except Exception as e:
logger.error(e)
return 500, '保存文件失败', None
def save_image(body):
"""
保存文件返回文件url
@param body 文件内容字典包括url,md5,ext,size等信息
"""
url = body.get('url', '')
ext = body.get('ext', 'jpg')
filename = build_file_name(ext) # 新文件名
if url != '':
date_str = time.strftime('%Y%m%d')
base_path = os.path.join(path_base, 'static', 'im', date_str)
if not os.path.exists(base_path): # 确保当日的文件夹存在
os.makedirs(base_path)
filepath = os.path.join(base_path, filename) # 实际文件路径
try:
urllib.request.urlretrieve(url, filepath)
url = "{}/static/im/{}/{}".format(http_base, date_str, filename)
except Exception as e:
logger.error(e)
return url # 返回图片url
def get_file_ext(filename):
"""获取文件扩展名"""
if '.' in filename:
return filename.rsplit('.', 1)[1].lower()
return None
def build_file_name(ext):
"""生成文件名"""
filename = time.strftime('%H%M%S') + str(random.random())[-6:]
return '{}.{}'.format(filename, ext) # 新文件名
def get_file_path(filename):
"""获取文件网站路径"""
return "{path_base}/{filename}".format(path_base=path_base, filename=filename)
def get_file_size(file_path):
"""获取文件大小"""
file_size = os.path.getsize(file_path)
return file_size
def get_file_size_str(file_path):
"""获取文件大小字符串"""
file_size = get_file_size(file_path)
def strofsize(integer, remainder, level):
if integer >= 1024:
remainder = integer % 1024
integer //= 1024
level += 1
return strofsize(integer, remainder, level)
else:
return integer, remainder, level
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
integer, remainder, level = strofsize(file_size, 0, 0)
if level + 1 > len(units):
level = -1
return '{}.{:>03d} {}'.format(integer, remainder, units[level])

103
src/utils/qiniu_tools.py Normal file
View File

@ -0,0 +1,103 @@
import datetime
import logging
import random
import traceback
import uuid
import qiniu.config
from qiniu import Auth, BucketManager, etag, put_data, put_file
__author__ = 'Woodstock'
__doc__ = """七牛上传"""
logger = logging.getLogger(__name__)
class QiniuYuProvider(object):
def __init__(self, type, name, access_key, secret_key, bucket_name, policy, url='http://dealer-public.hyd169.com'):
"""
:param type:
:param name:
:param access_key:
:param secret_key:
:param bucket_name:
:param policy:
:param url:
"""
self.type = type
self.name = name
self.access_key = access_key
self.secret_key = secret_key
self.bucket_name = bucket_name
self.policy = policy
self.url = url
def del_file(self, key):
"""
删除已上传的文件
:param key:
:return:
"""
try:
q = Auth(self.access_key, self.secret_key)
bucket = BucketManager(q)
# 删除bucket_name 中的文件 key
ret, info = bucket.delete(self.bucket_name, key)
status_code = info.status_code
if status_code == 200:
return 200, '删除图片成功'
return 500, '删除图片失败'
except Exception as e:
logger.error(traceback.format_exc())
return 500, '删除图片异常'
def upload_file(self, file, types=None):
"""
上传文件
:param file:
:param types:
:return:
"""
try:
q = Auth(self.access_key, self.secret_key)
key = self.generate_unique_key(token=uuid.uuid4(), types=types)
if not key or key == '':
raise ValueError('Generate unique key error')
token = q.upload_token(self.bucket_name, key)
ret, info = put_file(token, key, file)
status_code = info.status_code
if status_code == 200:
assert ret['key'] == key
assert ret['hash'] == etag(file)
data = dict(
req_id=info.req_id,
key=ret['key'],
hash=ret['hash'],
url=self.url + '/' + ret['key'],
type=types,
file=file
)
return 200, '上传成功', data
return 500, '上传失败', None
except Exception as e:
logger.error(traceback.format_exc())
return 500, '上传失败', None
@staticmethod
def generate_unique_key(token, types):
"""
:param token:
:param types:
:return:
"""
import pytz
tz = pytz.timezone('Asia/Shanghai')
key = str(token) + '-' + str(datetime.datetime.now(tz).microsecond) + '-' + str(random.random()) + '.' + str(
types)
return key

View File

@ -1,17 +1,17 @@
import json
import uuid
import time
import hmac
import random
import base64
import hashlib
import logging
import datetime
import requests
import hashlib
import hmac
import json
import logging
import random
import time
import traceback
from urllib.parse import urlencode, quote
import uuid
from urllib.parse import quote, urlencode
import requests
__author__ = 'Woodstock'

View File

@ -1,4 +1,5 @@
import unittest
import requests

View File

@ -1,77 +0,0 @@
import jwt
import aioredis
from datetime import datetime, timedelta
from fastapi import status, Depends
from fastapi_plugins import depends_redis
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.exceptions import HTTPException
from utils.user_tools import get_user_by_email, authenticate_user
from common import schemas
SECRET_KEY = "5cf366d7b06714683756ca019c6283ed440bf771edd281d8"
tokenUrl = "/token"
oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=tokenUrl,
scopes={"me": "Read information about the current user.", "items": "Read items."},
)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(),
cache: aioredis.Redis = Depends(depends_redis)):
"""验证成功后返回有效的用户访问令牌"""
user = authenticate_user(form_data)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=60)
access_token = create_access_token(data={"sub": user.Email}, expires_delta=access_token_expires)
# 将access_token 存到缓存里面
await cache.setex("TOKEN:" + str(user.Id), access_token_expires.seconds, access_token)
return dict(access_token=access_token, access_type="bearer")
def create_access_token(*, data: dict,
expires_delta: timedelta = None,
algorithm="HS256"):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm)
return encoded_jwt
async def decode_access_token(token, cache: aioredis.Redis, algorithm="HS256"):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[algorithm])
email: str = payload.get("sub")
if email is None:
raise credentials_exception
token_data = schemas.TokenData(email=email)
except jwt.PyJWTError:
raise credentials_exception
user = get_user_by_email(email=token_data.email)
if user is None:
raise credentials_exception
# 查询缓存
cached_token = await cache.get("TOKEN:" + str(user.Id), encoding="utf-8")
if cached_token is None or cached_token != token:
raise credentials_exception
return user
async def get_current_user(token: str = Depends(oauth2_scheme),
cache: aioredis.Redis = Depends(depends_redis)):
current_user = await decode_access_token(token, cache)
return current_user

View File

@ -1,92 +0,0 @@
import re
from typing import Optional
from fastapi import status
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.exceptions import HTTPException
from fastapi_sqlalchemy import db
from models.users import SysUser
from common.schemas import UserCreateUpdate
from common.password import verify_password, get_password_hash
def validate_email(email: str) -> bool:
pattern = r"[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(?:\.[a-zA-Z0-9_-]+)"
if re.match(pattern, email) is not None:
return True
else:
return False
def validate_phone_number(phone_number: str) -> object:
pattern = r"1[356789]\d{9}"
if re.match(pattern, phone_number) is not None:
return True
else:
return False
def get_user_by_email(email: str) -> Optional[SysUser]:
if validate_email(email) is True:
user = db.session.query(SysUser).filter(SysUser.Email == email).first()
return user
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的邮箱地址"
)
def get_user_by_phone_number(phone_number: str) -> Optional[SysUser]:
if validate_phone_number(phone_number) is True:
user = db.session.query(SysUser).filter(SysUser.PhoneNumber == phone_number).first()
return user
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的手机号码格式"
)
def get_user_by_username(username: str) -> Optional[SysUser]:
user = db.session.query(SysUser).filter(SysUser.UserName == username).first()
return user
def authenticate_user(form_data: OAuth2PasswordRequestForm) -> Optional[SysUser]:
if validate_email(form_data.username) is True:
user = get_user_by_email(form_data.username)
else:
if validate_phone_number(form_data.username) is True:
user = get_user_by_phone_number(form_data.username)
else:
user = get_user_by_username(form_data.username)
if not user:
return None
if not verify_password(form_data.password, user.Password):
return None
return user
def auth_user(username: str, password: str) -> Optional[SysUser]:
if validate_email(username) is True:
user = get_user_by_email(username)
else:
if validate_phone_number(username) is True:
user = get_user_by_phone_number(username)
else:
user = get_user_by_username(username)
if not user:
return None
if not verify_password(password, user.Password):
return None
return user
def create_user(user_data: UserCreateUpdate):
user_data.Password = get_password_hash(user_data.Password)
user = SysUser(**user_data.dict())
db.session.add(user)
db.session.commit()
return user