2025-07-04 19:07:35 +08:00

137 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
购物车模型
"""
from datetime import datetime
from config.database import db
from app.models.product import Product, ProductInventory
class Cart(db.Model):
    """Shopping-cart line item: one product (optionally a specific SKU) held by a user.

    Unit price, stock and availability are not stored on the row; they are
    derived on demand from the matching ``ProductInventory`` row (or, failing
    that, from the ``Product`` itself).
    """
    __tablename__ = 'cart'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    product_id = db.Column(db.Integer, db.ForeignKey('products.id'), nullable=False)
    # Chosen SKU; when empty, the product's default inventory row is used.
    sku_code = db.Column(db.String(100))
    # Spec text stored alongside the SKU (format not visible here).
    spec_combination = db.Column(db.String(255))
    quantity = db.Column(db.Integer, nullable=False, default=1)
    # Naive UTC timestamps (datetime.utcnow).
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships; the backrefs also expose ``cart_items`` on User and Product.
    user = db.relationship('User', backref='cart_items')
    product = db.relationship('Product', backref='cart_items')

    def get_sku_info(self):
        """Return the ProductInventory row backing this item, or None.

        Looks the row up by ``sku_code`` when one is set; otherwise falls back
        to the product's default inventory row (``is_default=1``). Each call
        issues a database query.
        """
        if self.sku_code:
            # NOTE(review): filters on sku_code alone, which assumes SKU codes
            # are unique across products — confirm against the schema.
            return ProductInventory.query.filter_by(sku_code=self.sku_code).first()
        # No SKU chosen: use the product's default inventory row.
        return ProductInventory.query.filter_by(
            product_id=self.product_id,
            is_default=1,
        ).first()

    def _price_from_sku(self, sku_info):
        """Return the unit price given an already-fetched inventory row (may be None)."""
        if sku_info:
            return sku_info.get_final_price()
        # No inventory row: fall back to the product's own price, or 0.
        return float(self.product.price) if self.product and self.product.price else 0

    @staticmethod
    def _stock_from_sku(sku_info):
        """Return the stock of an already-fetched inventory row, or 0 if there is none."""
        return sku_info.stock if sku_info else 0

    def _is_listed(self):
        """Return True when the product exists and is on sale (status == 1)."""
        return bool(self.product) and self.product.status == 1

    def get_price(self):
        """Return the unit price for this item (see ``_price_from_sku`` for the fallback)."""
        return self._price_from_sku(self.get_sku_info())

    def get_total_price(self):
        """Return the line subtotal: unit price times quantity."""
        return self.get_price() * self.quantity

    def get_stock(self):
        """Return the available stock for this item's SKU, or 0 if no inventory row exists."""
        return self._stock_from_sku(self.get_sku_info())

    def is_available(self):
        """Return True if the product is on sale and stock covers the requested quantity."""
        # Check the listing first so unlisted products never trigger a stock query.
        if not self._is_listed():
            return False
        # NOTE(review): assumes ``stock`` is numeric; a NULL stock would raise here.
        return self.get_stock() >= self.quantity

    def to_dict(self):
        """Serialize this item (plus derived price/stock/availability) to a dict.

        The inventory row is fetched once and reused for every derived field,
        instead of re-querying it for each one.
        """
        sku_info = self.get_sku_info()
        price = self._price_from_sku(sku_info)
        stock = self._stock_from_sku(sku_info)
        product = self.product
        return {
            'id': self.id,
            'user_id': self.user_id,
            'product_id': self.product_id,
            'product_name': product.name if product else '',
            'product_image': product.main_image if product else '',
            'brand': product.brand if product else '',
            'sku_code': self.sku_code,
            'spec_combination': self.spec_combination,
            'quantity': self.quantity,
            'price': price,
            'total_price': price * self.quantity,
            'stock': stock,
            'is_available': self._is_listed() and stock >= self.quantity,
            'created_at': self.created_at.isoformat() if self.created_at else None
        }

    @classmethod
    def add_to_cart(cls, user_id, product_id, sku_code=None, spec_combination=None, quantity=1):
        """Add a product to a user's cart and commit; return the resulting Cart row.

        If a row with the same user, product and SKU already exists, its
        quantity is increased by ``quantity`` (its ``spec_combination`` is left
        unchanged); otherwise a new row is inserted.
        """
        # filter_by(sku_code=None) renders as "sku_code IS NULL", so SKU-less
        # items are merged as well.
        existing_item = cls.query.filter_by(
            user_id=user_id,
            product_id=product_id,
            sku_code=sku_code
        ).first()
        if existing_item:
            existing_item.quantity += quantity
            existing_item.updated_at = datetime.utcnow()
            db.session.commit()
            return existing_item

        cart_item = cls(
            user_id=user_id,
            product_id=product_id,
            sku_code=sku_code,
            spec_combination=spec_combination,
            quantity=quantity
        )
        db.session.add(cart_item)
        db.session.commit()
        return cart_item

    @classmethod
    def get_user_cart(cls, user_id):
        """Return all cart rows for ``user_id``, newest first."""
        return (
            cls.query.filter_by(user_id=user_id)
            .order_by(cls.created_at.desc())
            .all()
        )

    @classmethod
    def get_cart_count(cls, user_id):
        """Return the number of cart rows (line items) for ``user_id``, not the sum of quantities."""
        return cls.query.filter_by(user_id=user_id).count()

    @classmethod
    def get_cart_total(cls, user_id):
        """Return the summed subtotal of the user's available cart items.

        Items whose product is not on sale, or whose stock is below the
        requested quantity, are skipped. Each listed item's inventory row is
        fetched once and reused for both the stock check and the price.
        """
        total = 0
        for item in cls.get_user_cart(user_id):
            if not item._is_listed():
                continue
            sku_info = item.get_sku_info()
            if item._stock_from_sku(sku_info) < item.quantity:
                continue
            total += item._price_from_sku(sku_info) * item.quantity
        return total

    def __repr__(self):
        return f'<Cart {self.user_id}-{self.product_id}>'