182 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			182 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from flask import Blueprint, render_template, request, redirect, url_for, flash, session, jsonify
 | 
						||
from werkzeug.security import generate_password_hash, check_password_hash
 | 
						||
from app.models.user import User, db
 | 
						||
from app.utils.email import send_verification_email, generate_verification_code
 | 
						||
import logging
 | 
						||
from functools import wraps
 | 
						||
import time
 | 
						||
from datetime import datetime, timedelta
 | 
						||
 | 
						||
# 创建蓝图
 | 
						||
user_bp = Blueprint('user', __name__)
 | 
						||
 | 
						||
 | 
						||
# 使用内存字典代替Redis存储验证码
 | 
						||
class VerificationStore:
 | 
						||
    def __init__(self):
 | 
						||
        self.codes = {}  # 存储格式: {email: {'code': code, 'expires': timestamp}}
 | 
						||
 | 
						||
    def setex(self, email, seconds, code):
 | 
						||
        """设置验证码并指定过期时间"""
 | 
						||
        expiry = datetime.now() + timedelta(seconds=seconds)
 | 
						||
        self.codes[email] = {'code': code, 'expires': expiry}
 | 
						||
        return True
 | 
						||
 | 
						||
    def get(self, email):
 | 
						||
        """获取验证码,如果过期则返回None"""
 | 
						||
        if email not in self.codes:
 | 
						||
            return None
 | 
						||
 | 
						||
        data = self.codes[email]
 | 
						||
        if datetime.now() > data['expires']:
 | 
						||
            # 验证码已过期,删除它
 | 
						||
            self.delete(email)
 | 
						||
            return None
 | 
						||
 | 
						||
        return data['code']
 | 
						||
 | 
						||
    def delete(self, email):
 | 
						||
        """删除验证码"""
 | 
						||
        if email in self.codes:
 | 
						||
            del self.codes[email]
 | 
						||
        return True
 | 
						||
 | 
						||
 | 
						||
# 使用内存存储验证码
 | 
						||
verification_codes = VerificationStore()
 | 
						||
 | 
						||
 | 
						||
def login_required(f):
 | 
						||
    @wraps(f)
 | 
						||
    def decorated_function(*args, **kwargs):
 | 
						||
        if 'user_id' not in session:
 | 
						||
            return redirect(url_for('user.login'))
 | 
						||
        return f(*args, **kwargs)
 | 
						||
 | 
						||
    return decorated_function
 | 
						||
 | 
						||
 | 
						||
@user_bp.route('/login', methods=['GET', 'POST'])
 | 
						||
def login():
 | 
						||
    # 保持原代码不变
 | 
						||
    if request.method == 'POST':
 | 
						||
        username = request.form.get('username')
 | 
						||
        password = request.form.get('password')
 | 
						||
        remember_me = request.form.get('remember_me') == 'on'
 | 
						||
 | 
						||
        if not username or not password:
 | 
						||
            return render_template('login.html', error='用户名和密码不能为空')
 | 
						||
 | 
						||
        # 检查用户是否存在
 | 
						||
        user = User.query.filter((User.username == username) | (User.email == username)).first()
 | 
						||
 | 
						||
        if not user or not user.check_password(password):
 | 
						||
            return render_template('login.html', error='用户名或密码错误')
 | 
						||
 | 
						||
        if user.status == 0:
 | 
						||
            return render_template('login.html', error='账号已被禁用,请联系管理员')
 | 
						||
 | 
						||
        # 登录成功,保存用户信息到会话
 | 
						||
        session['user_id'] = user.id
 | 
						||
        session['username'] = user.username
 | 
						||
        session['role_id'] = user.role_id
 | 
						||
 | 
						||
        if remember_me:
 | 
						||
            # 设置会话过期时间为7天
 | 
						||
            session.permanent = True
 | 
						||
 | 
						||
        # 记录登录日志(可选)
 | 
						||
        # log_user_action('用户登录')
 | 
						||
 | 
						||
        # 重定向到首页
 | 
						||
        return redirect(url_for('index'))
 | 
						||
 | 
						||
    return render_template('login.html')
 | 
						||
 | 
						||
 | 
						||
@user_bp.route('/register', methods=['GET', 'POST'])
 | 
						||
def register():
 | 
						||
    if request.method == 'POST':
 | 
						||
        username = request.form.get('username')
 | 
						||
        email = request.form.get('email')
 | 
						||
        password = request.form.get('password')
 | 
						||
        confirm_password = request.form.get('confirm_password')
 | 
						||
        verification_code = request.form.get('verification_code')
 | 
						||
 | 
						||
        # 验证表单数据
 | 
						||
        if not username or not email or not password or not confirm_password or not verification_code:
 | 
						||
            return render_template('register.html', error='所有字段都是必填项')
 | 
						||
 | 
						||
        if password != confirm_password:
 | 
						||
            return render_template('register.html', error='两次输入的密码不匹配')
 | 
						||
 | 
						||
        # 检查用户名和邮箱是否已存在
 | 
						||
        if User.query.filter_by(username=username).first():
 | 
						||
            return render_template('register.html', error='用户名已存在')
 | 
						||
 | 
						||
        if User.query.filter_by(email=email).first():
 | 
						||
            return render_template('register.html', error='邮箱已被注册')
 | 
						||
 | 
						||
        # 验证验证码
 | 
						||
        stored_code = verification_codes.get(email)
 | 
						||
        if not stored_code or stored_code != verification_code:
 | 
						||
            return render_template('register.html', error='验证码无效或已过期')
 | 
						||
 | 
						||
        # 创建新用户
 | 
						||
        try:
 | 
						||
            new_user = User(
 | 
						||
                username=username,
 | 
						||
                password=password,  # 密码会在模型中自动哈希
 | 
						||
                email=email,
 | 
						||
                nickname=username  # 默认昵称与用户名相同
 | 
						||
            )
 | 
						||
            db.session.add(new_user)
 | 
						||
            db.session.commit()
 | 
						||
 | 
						||
            # 清除验证码
 | 
						||
            verification_codes.delete(email)
 | 
						||
 | 
						||
            flash('注册成功,请登录', 'success')
 | 
						||
            return redirect(url_for('user.login'))
 | 
						||
        except Exception as e:
 | 
						||
            db.session.rollback()
 | 
						||
            logging.error(f"User registration failed: {str(e)}")
 | 
						||
            return render_template('register.html', error='注册失败,请稍后重试')
 | 
						||
 | 
						||
    return render_template('register.html')
 | 
						||
 | 
						||
 | 
						||
@user_bp.route('/logout')
 | 
						||
def logout():
 | 
						||
    # 清除会话数据
 | 
						||
    session.pop('user_id', None)
 | 
						||
    session.pop('username', None)
 | 
						||
    session.pop('role_id', None)
 | 
						||
    return redirect(url_for('user.login'))
 | 
						||
 | 
						||
 | 
						||
@user_bp.route('/send_verification_code', methods=['POST'])
 | 
						||
def send_verification_code():
 | 
						||
    data = request.get_json()
 | 
						||
    email = data.get('email')
 | 
						||
 | 
						||
    if not email:
 | 
						||
        return jsonify({'success': False, 'message': '请提供邮箱地址'})
 | 
						||
 | 
						||
    # 检查邮箱格式
 | 
						||
    import re
 | 
						||
    if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
 | 
						||
        return jsonify({'success': False, 'message': '邮箱格式不正确'})
 | 
						||
 | 
						||
    # 生成验证码
 | 
						||
    code = generate_verification_code()
 | 
						||
 | 
						||
    # 存储验证码(10分钟有效)
 | 
						||
    verification_codes.setex(email, 600, code)  # 10分钟过期
 | 
						||
 | 
						||
    # 发送验证码邮件
 | 
						||
    if send_verification_email(email, code):
 | 
						||
        return jsonify({'success': True, 'message': '验证码已发送'})
 | 
						||
    else:
 | 
						||
        return jsonify({'success': False, 'message': '邮件发送失败,请稍后重试'})
 |