# Code-viewer listing header (pasted residue): 86 lines, 3.0 KiB, Python.
import random
import secrets
import string
from datetime import datetime, timedelta

from flask_login import UserMixin
from flask_sqlalchemy import SQLAlchemy
from werkzeug.security import generate_password_hash, check_password_hash

from app import db


class User(UserMixin, db.Model):
    """Registered account, persisted in the ``users`` table.

    The clear-text password is never stored: ``set_password`` keeps only the
    werkzeug hash in ``password_hash``, and ``check_password`` compares a
    candidate password against that hash.
    """

    __tablename__ = 'users'

    # Surrogate primary key.
    id = db.Column(db.Integer, primary_key=True)
    # Login identifier: unique and indexed for lookups.
    email = db.Column(db.String(255), index=True, unique=True, nullable=False)
    # werkzeug password hash (see set_password / check_password).
    password_hash = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(100), nullable=False)
    age = db.Column(db.SmallInteger, nullable=False)
    # Encoded as 0 = male, 1 = female (per the column comment).
    gender = db.Column(db.SmallInteger, comment='0-男, 1-女', nullable=False)
    # Optional parent/guardian contact information.
    parent_contact = db.Column(db.String(255), comment='家长联系方式', nullable=True)
    # Whether the email address has been verified.
    is_verified = db.Column(db.Boolean, comment='邮箱是否验证', default=False)
    # Naive UTC timestamps filled in by SQLAlchemy column defaults.
    created_at = db.Column(db.DateTime, index=True, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, onupdate=datetime.utcnow, default=datetime.utcnow)

    def set_password(self, password):
        """Hash *password* and store the result in ``password_hash``."""
        hashed = generate_password_hash(password)
        self.password_hash = hashed

    def check_password(self, password):
        """Return True if *password* matches the stored password hash."""
        stored = self.password_hash
        return check_password_hash(stored, password)

    def __repr__(self):
        return '<User {}>'.format(self.email)


class EmailVerification(db.Model):
    """One-time numeric code issued to an email address.

    Rows are created by ``generate_code`` and consumed by ``verify_code``.
    A consumed row keeps ``is_used = True``. It stays in the table until it
    has expired and a later ``generate_code`` call for the same address
    deletes it.
    """

    __tablename__ = 'email_verifications'

    id = db.Column(db.Integer, primary_key=True)
    # Address the code was issued for (plain string, not a foreign key).
    email = db.Column(db.String(255), nullable=False, index=True)
    # 6-digit numeric code, stored as a zero-padded string.
    verification_code = db.Column(db.String(6), nullable=False, comment='6位数字验证码')
    # Naive UTC instant after which the code is no longer accepted.
    expires_at = db.Column(db.DateTime, nullable=False, index=True)
    # Set once the code has been successfully verified (single use).
    is_used = db.Column(db.Boolean, default=False, comment='是否已使用')
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    @classmethod
    def generate_code(cls, email, expire_minutes=5):
        """Issue and persist a new verification code for *email*.

        Expired codes previously issued for the same address are deleted in
        the same transaction as the insert.

        Args:
            email: Address the code is issued for.
            expire_minutes: Validity window of the new code, in minutes.

        Returns:
            The 6-digit code as a zero-padded string.
        """
        # One timestamp for both the cleanup cutoff and the new expiry.
        now = datetime.utcnow()

        # Purge this address's expired codes before issuing a new one.
        cls.query.filter(
            cls.email == email,
            cls.expires_at < now,
        ).delete()

        # The code is a secret proving mailbox ownership, so draw it from a
        # CSPRNG (``secrets``) instead of the predictable ``random`` module.
        # randbelow(10**6) is uniform over 0..999999; :06d zero-pads it.
        code = f'{secrets.randbelow(10 ** 6):06d}'

        verification = cls(
            email=email,
            verification_code=code,
            expires_at=now + timedelta(minutes=expire_minutes),
        )
        db.session.add(verification)
        db.session.commit()
        return code

    @classmethod
    def verify_code(cls, email, code):
        """Check *code* for *email* and consume it if it is valid.

        A code is valid when it was issued for *email*, has not expired and
        has not been used yet. On success the matching row is marked as
        used, so it cannot be replayed, and the session is committed.

        Args:
            email: Address the code was sent to.
            code: Code entered by the user. It is converted with ``str()``
                and surrounding whitespace is ignored.

        Returns:
            True if a valid code was found and consumed, otherwise False.
        """
        if code is None:
            return False
        # Codes issued by generate_code are digits only, so stripping pasted
        # whitespace cannot produce a false match.
        code = str(code).strip()
        if not code:
            return False

        verification = cls.query.filter(
            cls.email == email,
            cls.verification_code == code,
            cls.expires_at > datetime.utcnow(),
            cls.is_used == False,  # noqa: E712 -- SQL expression, not a Python truth test
        ).first()
        if verification is None:
            return False

        verification.is_used = True
        db.session.commit()
        return True

    def __repr__(self):
        return f'<EmailVerification {self.email}>'