diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..5b9e7a4 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,17 @@ +FROM python:3.8.6-alpine + +RUN set -eux && sed -i 's/dl-cdn.alpinelinux.org/mirrors.ustc.edu.cn/g' /etc/apk/repositories + +RUN apk update && apk add python3-dev gcc libc-dev libffi-dev + +RUN apk add jpeg-dev zlib-dev freetype-dev lcms2-dev openjpeg-dev tiff-dev tk-dev tcl-dev + +WORKDIR /app + +COPY . . + +RUN pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ + +ENV FAST_API_ENV=dev + +CMD ["/usr/local/bin/uvicorn", "main:fast_api_app", "--reload", "--host", "0.0.0.0"] diff --git a/api/test.py b/api/test.py index 73d66ce..f2f3cbf 100644 --- a/api/test.py +++ b/api/test.py @@ -48,14 +48,18 @@ async def read_users_me(current_user: User = Depends(get_current_user)): # 数据库查询 -@router.get("/query", response_model=sqlalchemy_to_pydantic(DevicesPlace)) -def query(mid: str = Query(..., description='mid'), region_id: Optional[int] = Query(None, description="区域Id")): - if region_id is None: - record = db.session.query(DevicesPlace).filter(DevicesPlace.mid == mid).first() +@router.get("/query") +def query(mid: Optional[str] = Query(None, description='mid'), + region_id: int = Query(..., description="区域Id"), + limit: int = 20, + page_count: int = 1): + if mid is None: + records = db.session.query(DevicesPlace).filter(DevicesPlace.region_id == region_id)\ + .order_by(DevicesPlace.id).limit(limit).offset(limit*(page_count-1)).all() else: - record = db.session.query(DevicesPlace).filter(DevicesPlace.mid == mid)\ - .filter(DevicesPlace.region_id == region_id).first() - return record + records = db.session.query(DevicesPlace).filter(DevicesPlace.mid == mid)\ + .filter(DevicesPlace.region_id == region_id).order_by(DevicesPlace.id).limit(limit).offset(limit*(page_count-1)).all() + return records # Redis 缓存查询 diff --git a/requirements.txt b/requirements.txt index c1de5d1..a3c8ab8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,21 @@ aiofiles==0.6.0 aioredis==1.3.1 +bcrypt==3.2.0 +email-validator==1.1.2 fastapi==0.63.0 +fastapi-plugins==0.6.0 FastAPI-SQLAlchemy==0.2.1 -MarkupSafe==1.1.1 pydantic==1.8.1 pydantic-sqlalchemy==0.0.8.post1 PyMySQL==0.10.1 +python-multipart==0.0.5 pytest==6.2.2 requests==2.24.0 sqlacodegen==2.3.0 SQLAlchemy==1.3.23 starlette==0.13.6 -uvicorn==0.13.4 \ No newline at end of file +uvicorn==0.13.4 +PyJWT==2.0.1 +passlib==1.7.4 +Pillow==8.0.1 +captcha==0.3 \ No newline at end of file diff --git a/utils/captcha_tools.py b/utils/captcha_tools.py new file mode 100644 index 0000000..db6fd37 --- /dev/null +++ b/utils/captcha_tools.py @@ -0,0 +1,94 @@ + +import random +import logging +import base64 +from os import urandom +from io import BytesIO +from PIL import ImageFilter +from captcha.image import ImageCaptcha, random_color + +__author__ = 'Woodstock' + +__doc__ = """ +验证码图片生成 +""" + +logger = logging.getLogger(__name__) + + +class CaptchaCode(object): + BASES = '02345689' # 1acdefhjkmnprtuvwxyACDEFGHJKLMNPQRTUVWXY' # 去除部分不易识别的字母 + + @staticmethod + def generate(code_len=4): + """ + 验证码长度 + :param code_len: + :return: + """ + code = urandom(code_len) # 验证码长度 + n = len(CaptchaCode.BASES) + return "".join([CaptchaCode.BASES[int(c) % n - 1] for c in code]) + + +class CustomCaptcha(ImageCaptcha): + def __init__(self, width=160, height=60, fonts=None, font_sizes=None): + super(CustomCaptcha, self).__init__(width, height, fonts, + font_sizes=self._scale_font(font_sizes, width, height)) + + def _scale_font(self, font_sizes, width, height): + """ + :param font_sizes: + :param width: + :param height: + :return: + """ + if not font_sizes: + r = width / 320 + height / 120 + font_sizes = set([int(fs * r) for fs in (42, 50, 56)]) + logger.debug(str(font_sizes)) + return font_sizes + + def generate_image_custom(self, chars, dots_count=30, dots_size=3, curve=1): + """ + :param chars: + :param dots_count: + :param dots_size: + :param curve: + :return: + """ + background = random_color(238, 255) + color = random_color(0, 200, random.randint(220, 255)) + im = self.create_captcha_image(chars, color, background) + if dots_count > 0 and dots_size > 0: + self.create_noise_dots(im, color, dots_size, dots_count) + + if curve: + self.create_noise_curve(im, color) + im = im.filter(ImageFilter.SMOOTH) + buf = BytesIO() + im.save(buf, format='PNG') + buf.seek(0) + return buf + + def generate_image_base64_str(self, chars, dots_count=30, dots_size=3, curve=1): + """ + 生成图片验证码二进制字符串 + :param chars: + :param dots_count: + :param dots_size: + :param curve: + :return: + """ + buf = self.generate_image_custom(chars, dots_count, dots_size, curve) + byte_data = buf.getvalue() + base64_str = str(base64.b64encode(byte_data), 'utf-8') + return base64_str + + +if __name__ == '__main__': + captcha_code = CaptchaCode.generate() + print(captcha_code) + base64_str = CustomCaptcha().generate_image_base64_str(captcha_code) + print(base64_str) + print(len(base64_str)) diff --git a/utils/sms_tools.py b/utils/sms_tools.py new file mode 100644 index 0000000..99fb3db --- /dev/null +++ b/utils/sms_tools.py @@ -0,0 +1,149 @@ + +import json +import uuid +import time +import hmac +import random +import base64 +import hashlib +import logging +import datetime +import requests +import traceback +from urllib.parse import urlencode, quote + + +__author__ = 'Woodstock' + +__doc__ = '''阿里大鱼短信接口''' +logger = logging.getLogger(__name__) + + +class AliDaYuProvider(object): + def __init__(self, key, secret, url="https://dysmsapi.aliyuncs.com/", template_code='', sign_name='水明堂', code_count=6): + """ + :param key: + :param secret: + :param url: + :param template_code: + :param sign_name: + """ + self.key = key + self.secret = secret + self.url = url + self.template_id = template_code + self.sign_name = sign_name + self.get_url = None + self.params = None + self.code_count = code_count + + def send(self, to, data): + """ + 发送模板短信 + :param to: + :param data: + :return: + """ + return self.send_template_sms(to=to, data=data, temp_id=self.template_id, sign_name=self.sign_name) + + def send_template_sms(self, to, data, temp_id, sign_name): + """ + 发送模板短信 + :param to: + :param data: 内容数据 格式为数组 例如:{'code': '123456'},如不需替换请填 '' + :param temp_id: + :param sign_name: + :return: + """ + if isinstance(data, dict): + data = json.dumps(data) + + self.params = { + 'PhoneNumbers': to, + 'SignName': sign_name, + 'TemplateCode': temp_id, + 'TemplateParam': data, + 'OutId': '123' + } + try: + self.sign('GET') + resp = self.get_response() + if resp['Code'] == 'OK': + return 0 + else: + return -1 + except: + logger.error(traceback.format_exc()) + return -1 + + def sign(self, method): + """ + 签名 + :param method: + :return: + """ + self.params.update({ + "SignatureMethod": "HMAC-SHA1", + "SignatureNonce": str(uuid.uuid4()), + "AccessKeyId": self.key, + "SignatureVersion": "1.0", + "Timestamp": datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ'), # '2017-07-12T02:42:19Z', + "Format": "json", # XML + "Action": "SendSms", + "Version": "2017-05-25", + "RegionId": "cn-hangzhou" + }) + params_str = urlencode(sorted(self.params.items())) + params_str = "&".join([method, "%2F", quote(params_str)]).replace("%2B", '%2520') # 因url将空格改成加号 + print(params_str) + key = "{}&".format(self.secret) + sign_str = hmac_sha1_base64(key, params_str) + params_sign_str = "&".join([urlencode({'Signature': sign_str}), urlencode(sorted(self.params.items()))]) + self.get_url = "?".join([self.url, params_sign_str]) + + def get_response(self): + """ + 获取发送结果 + :return: + """ + start = time.time() + response = requests.get(self.get_url) + result_json = response.content.decode() + result = json.loads(result_json) + end = time.time() + print(result_json) + print("发送阿里大鱼短信:\n 耗时:{}\n 请求参数:{}\n 返回结果:{}".format(end - start, json.dumps(self.params), result_json)) + + return result + + def send_sms_code_msg(self, to, code): + sms_msg = json.dumps(dict(code=code)) + return self.send_template_sms(to, sms_msg, temp_id=self.template_id, sign_name=self.sign_name) + + +def gen_sms_code(code_count=6): + random_str = str(random.random()) + if code_count < 16: + code = random_str[-code_count:] + else: + code = random_str[-6:] + return code + + +def hmac_sha1_base64(key, content): + """ + sha1签名后base64 + :param key: + :param content: + :return: + """ + return base64.b64encode(hmac.new(key.encode("utf-8"), content.encode("utf-8"), hashlib.sha1).digest()).decode() + + +if __name__ == '__main__': + cfg = {"type": "alidayu", "config": {"key": "LTAI9pTVsFg68Tjw", "secret": "WToG39WC6eLkdnzxhKzlNYEqDV2WFd", + "template_code": "SMS_130918971", "sign_name": "水名堂"}} + api = AliDaYuProvider(**cfg['config']) + sms_code = gen_sms_code() + res = api.send_sms_code_msg("15359827092", sms_code) + print(res) \ No newline at end of file diff --git a/utils/token_tools.py b/utils/token_tools.py new file mode 100644 index 0000000..ba36a4f --- /dev/null +++ b/utils/token_tools.py @@ -0,0 +1,77 @@ +import jwt +import aioredis +from datetime import datetime, timedelta +from fastapi import status, Depends +from fastapi_plugins import depends_redis +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from fastapi.exceptions import HTTPException +from utils.user_tools import get_user_by_email, authenticate_user +from common import schemas + +SECRET_KEY = "5cf366d7b06714683756ca019c6283ed440bf771edd281d8" +tokenUrl = "/token" + +oauth2_scheme = OAuth2PasswordBearer( + tokenUrl=tokenUrl, + scopes={"me": "Read information about the current user.", "items": "Read items."}, +) + + +async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), + cache: aioredis.Redis = Depends(depends_redis)): + """验证成功后返回有效的用户访问令牌""" + user = authenticate_user(form_data) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token_expires = timedelta(minutes=60) + access_token = create_access_token(data={"sub": user.Email}, expires_delta=access_token_expires) + # 将access_token 存到缓存里面 + await cache.setex("TOKEN:" + str(user.Id), access_token_expires.seconds, access_token) + return dict(access_token=access_token, access_type="bearer") + + +def create_access_token(*, data: dict, + expires_delta: timedelta = None, + algorithm="HS256"): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm) + return encoded_jwt + + +async def decode_access_token(token, cache: aioredis.Redis, algorithm="HS256"): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[algorithm]) + email: str = payload.get("sub") + if email is None: + raise credentials_exception + token_data = schemas.TokenData(email=email) + except jwt.PyJWTError: + raise credentials_exception + user = get_user_by_email(email=token_data.email) + if user is None: + raise credentials_exception + # 查询缓存 + cached_token = await cache.get("TOKEN:" + str(user.Id), encoding="utf-8") + if cached_token is None or cached_token != token: + raise credentials_exception + return user + + +async def get_current_user(token: str = Depends(oauth2_scheme), + cache: aioredis.Redis = Depends(depends_redis)): + current_user = await decode_access_token(token, cache) + return current_user diff --git a/utils/user_tools.py b/utils/user_tools.py new file mode 100644 index 0000000..8d12aa4 --- /dev/null +++ b/utils/user_tools.py @@ -0,0 +1,92 @@ +import re +from typing import Optional +from fastapi import status +from fastapi.security import OAuth2PasswordRequestForm +from fastapi.exceptions import HTTPException +from fastapi_sqlalchemy import db +from models.users import SysUser +from common.schemas import UserCreateUpdate +from common.password import verify_password, get_password_hash + + +def validate_email(email: str) -> bool: + pattern = r"[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(?:\.[a-zA-Z0-9_-]+)" + if re.match(pattern, email) is not None: + return True + else: + return False + + +def validate_phone_number(phone_number: str) -> object: + pattern = r"1[356789]\d{9}" + if re.match(pattern, phone_number) is not None: + return True + else: + return False + + +def get_user_by_email(email: str) -> Optional[SysUser]: + if validate_email(email) is True: + user = db.session.query(SysUser).filter(SysUser.Email == email).first() + return user + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="无效的邮箱地址" + ) + + +def get_user_by_phone_number(phone_number: str) -> Optional[SysUser]: + if validate_phone_number(phone_number) is True: + user = db.session.query(SysUser).filter(SysUser.PhoneNumber == phone_number).first() + return user + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="无效的手机号码格式" + ) + + +def get_user_by_username(username: str) -> Optional[SysUser]: + user = db.session.query(SysUser).filter(SysUser.UserName == username).first() + return user + + +def authenticate_user(form_data: OAuth2PasswordRequestForm) -> Optional[SysUser]: + if validate_email(form_data.username) is True: + user = get_user_by_email(form_data.username) + else: + if validate_phone_number(form_data.username) is True: + user = get_user_by_phone_number(form_data.username) + else: + user = get_user_by_username(form_data.username) + if not user: + return None + if not verify_password(form_data.password, user.Password): + return None + return user + + +def auth_user(username: str, password: str) -> Optional[SysUser]: + if validate_email(username) is True: + user = get_user_by_email(username) + else: + if validate_phone_number(username) is True: + user = get_user_by_phone_number(username) + else: + user = get_user_by_username(username) + if not user: + return None + if not verify_password(password, user.Password): + return None + return user + + +def create_user(user_data: UserCreateUpdate): + user_data.Password = get_password_hash(user_data.Password) + user = SysUser(**user_data.dict()) + db.session.add(user) + db.session.commit() + return user + +