当推荐系统或风控模型的在线推理(Online Inference)请求 QPS 从一万攀升到十万时,最大的瓶颈往往不再是模型计算本身,而是实时特征的获取延迟。批处理生成的特征时效性太差,无法捕捉用户最新的意图;而一个为在线服务设计的、能够支撑高并发读写、且能在毫秒级响应的实时特征存储系统,便成为决定业务成败的关键基础设施。
摆在我们面前的核心挑战是:设计一个系统,它能持续不断地接收海量用户行为事件流,近乎实时地计算出特征,并提供一个稳定、低延迟的查询接口,供下游的 PyTorch 推理服务在每次预测前调用。具体指标要求是:特征写入到可查询的延迟在5秒内,特征查询接口的 p99 延迟低于 15ms。
方案 A: 以内存缓存为核心的快速迭代方案
一个直观的架构是利用内存数据库的极致读取性能。这个方案的逻辑流如下:
- 事件源: 用户行为事件被发送到 Google Cloud Pub/Sub。
- 处理层: 一个无服务器函数(如 Cloud Functions)或一个独立的消费者服务订阅 Pub/Sub 主题,进行简单的特征计算。
- 存储与服务: 将计算出的特征写入一个高速内存缓存,如 Redis 或 Memcached。这个缓存层直接面向 PyTorch 推理服务,提供特征查询。
- 持久化: 为防止数据丢失,处理层会异步地将特征数据备份到 BigQuery 或 Google Cloud Storage 等持久化存储中。
graph TD subgraph "事件源" A[用户行为] --> B(Google Cloud Pub/Sub) end subgraph "处理与存储" B --> C{Cloud Function/Consumer} C -- "写入特征" --> D[Redis Cluster] C -- "异步备份" --> E[BigQuery] end subgraph "在线服务" F[PyTorch Inference Service] -- "查询特征 (p99 < 5ms)" --> D end
优势分析:
- 极低的读取延迟: Redis 的内存操作能轻松将 p99 读取延迟控制在 5ms 以内,完全满足我们的性能指标。
- 实现简单: Pub/Sub 到 Cloud Function 再到 Redis 的链路相对直接,开发和部署速度快。
劣势分析:
- 数据易失性与恢复复杂性: Redis 是内存数据库,尽管有 AOF/RDB 持久化机制,但在故障场景下,数据丢失的风险依然存在。从 BigQuery 进行数据恢复以预热(Warm-up)一个大规模 Redis 集群,将是一个极其缓慢且复杂的操作,期间服务不可用。
- 成本高昂: 特征数据量往往非常庞大。如果每个用户需要存储 1KB 的特征,那么一亿用户就需要 100GB 的存储空间。完全依赖内存,其成本将随着用户量和特征维度的增长而线性飙升,很快变得难以承受。
- 一致性问题: 异步备份到 BigQuery 的机制,意味着 Redis 和 BigQuery 之间存在数据不一致的窗口。在进行模型分析或特征回溯时,这种不一致性会带来麻烦。
- 扩展性限制: 虽然 Redis Cluster 支持水平扩展,但当数据量达到 TB 级别时,集群的管理和运维成本会急剧增加。
在真实项目中,这种架构更适合作为最终存储层的前置缓存,而非核心存储。将其作为唯一的数据源,对于需要高可靠性和成本效益的长期项目来说,技术债务过高。
方案 B: 以 Cassandra 为核心的持久化存储方案
考虑到方案 A 在持久化、成本和扩展性上的根本缺陷,我们设计了第二套方案,将核心存储替换为一个为高并发写和水平扩展而生的分布式 NoSQL 数据库:Apache Cassandra。
- 事件源: 同样使用 Google Cloud Pub/Sub 接收事件流。
- 处理层: 一个健壮的消费者集群(例如,部署在 GKE 上的 Python 服务)负责处理消息,保证至少一次(At-least-once)的消费语义。
- 核心存储: 处理后的特征直接写入一个自管或托管的 Cassandra 集群。Cassandra 的数据模型将被精心设计以优化读取性能。
- 服务层: 一个轻量级的 Flask 应用作为特征服务 API。它提供一个高性能的 RESTful 或 gRPC 接口供 PyTorch 推理服务调用。同时,该应用内嵌一个简单的 Server-Side Rendering (SSR) 监控面板,供团队内部监控特征状态。
graph TD subgraph "事件源" A[用户行为] --> B(Google Cloud Pub/Sub) end subgraph "处理层 (GKE)" B --> C[Python Consumer Pool] end subgraph "存储与服务" C -- "写入特征 (高吞吐)" --> D[Cassandra Cluster] F[PyTorch Inference Service] -- "查询特征 (p99 < 15ms)" --> E D -- "读取" --> E(Flask Feature Service) end subgraph "内部监控" G[工程师] -- "访问" --> H(SSR Dashboard) E -- "包含" --> H end
优势分析:
- 水平扩展与高可用: Cassandra 的无主(Masterless)架构使其能够通过简单地增加节点来线性扩展读写吞吐量。其内置的副本机制和跨可用区部署能力保证了数据的高可用性。
- 写性能优异: Cassandra 基于 LSM-Tree 的存储引擎对写入操作极其友好,非常适合我们这种持续不断的事件流写入场景。
- 成本效益: Cassandra 主要使用磁盘存储(配合内存缓存),相比纯内存方案,单位存储成本大幅降低,能够经济地存储海量历史特征。
- 数据持久化: 数据写入 Cassandra 后即被持久化,并根据副本策略同步到多个节点,数据安全性远高于方案 A。
- 可调一致性: Cassandra 允许在每次读写请求中指定一致性级别(如
ONE
,QUORUM
),让我们可以根据业务场景在性能和数据一致性之间做灵活权衡。
劣势分析:
- 读取延迟相对较高: 相比纯内存的 Redis,Cassandra 的 p99 读取延迟通常在 10-20ms 范围。这需要通过精细的数据建模和调优来确保其满足我们的 15ms 指标。
- 数据建模复杂: Cassandra 的性能严重依赖于其数据模型。查询模式必须在表设计阶段就已确定,不支持灵活的 ad-hoc 查询。错误的 partition key 设计会导致热点问题,严重影响性能。
- 运维复杂度: 无论是自建还是使用托管服务,Cassandra 集群的运维、监控和调优都比 Redis 更为复杂。
最终决策与理由
我们最终选择了方案 B。尽管方案 A 在读取延迟上看似更有优势,但其在数据安全、成本和长期可扩展性方面的巨大缺陷使其在生产环境中风险过高。一个健壮的系统必须为未来的增长和潜在的故障做好准备。
方案 B 的核心是接受一个略高的读取延迟(15ms vs 5ms),来换取系统的整体韧性、可扩展性和成本可控性。通过精心设计 Cassandra 的数据模型和优化服务层,我们有信心将延迟控制在可接受的范围内。而 Flask 服务内嵌的 SSR 监控面板,则是一个务实的选择,它允许我们用最小的开发成本快速构建一个内部运维工具,而无需引入复杂的前端技术栈。
核心实现概览
以下是方案 B 关键组件的生产级代码实现。
1. Cassandra 数据模型设计
这是整个系统的基石。我们的查询模式非常单一:根据 user_id
获取该用户的所有实时特征。因此,user_id
是最理想的分区键(Partition Key)。
-- cqlsh
CREATE KEYSPACE IF NOT EXISTS feature_store
WITH REPLICATION = { 'class' : 'NetworkTopologyStrategy', 'datacenter1' : 3 };
USE feature_store;
CREATE TABLE IF NOT EXISTS user_realtime_features (
user_id text,
feature_name text,
feature_value blob,
updated_at timestamp,
PRIMARY KEY (user_id, feature_name)
) WITH CLUSTERING ORDER BY (feature_name ASC)
AND compaction = { 'class' : 'TimeWindowCompactionStrategy', 'compaction_window_unit' : 'DAYS', 'compaction_window_size' : 1 };
-- 说明:
-- 1. Partition Key: `user_id`。所有属于同一个用户的特征都存储在同一个分区中,确保一次查询即可获取全部。
-- 2. Clustering Key: `feature_name`。这使得分区内的数据按特征名排序,虽然我们当前是获取全部,但为未来按名称范围查询提供了可能。
-- 3. `feature_value` 使用 blob 类型,可以存储序列化后的任意数据结构(如 Numpy array, JSON string)。
-- 4. `TimeWindowCompactionStrategy` (TWCS) 适用于时间序列数据,因为新特征会不断覆盖旧特征,TWCS 可以高效地丢弃过期数据,减少磁盘空间占用和读放大。
2. Pub/Sub 消费者
这个服务从 Pub/Sub 拉取消息,执行特征转换,并写入 Cassandra。
# feature_consumer/main.py
import os
import json
import logging
from concurrent.futures import TimeoutError
from typing import Callable
from google.cloud import pubsub_v1
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, ConsistencyLevel
# --- 配置 ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
PROJECT_ID = os.environ.get("GCP_PROJECT_ID")
SUBSCRIPTION_ID = os.environ.get("PUBSUB_SUBSCRIPTION_ID")
CASSANDRA_HOSTS = os.environ.get("CASSANDRA_HOSTS", "127.0.0.1").split(',')
CASSANDRA_PORT = int(os.environ.get("CASSANDRA_PORT", 9042))
CASSANDRA_USER = os.environ.get("CASSANDRA_USER")
CASSANDRA_PASSWORD = os.environ.get("CASSANDRA_PASSWORD")
KEYSPACE = "feature_store"
TABLE = "user_realtime_features"
# --- Cassandra 连接 ---
auth_provider = PlainTextAuthProvider(username=CASSANDRA_USER, password=CASSANDRA_PASSWORD)
cluster = Cluster(CASSANDRA_HOSTS, port=CASSANDRA_PORT, auth_provider=auth_provider)
session = cluster.connect(KEYSPACE)
# 预编译语句以提高性能
insert_stmt = session.prepare(
f"INSERT INTO {TABLE} (user_id, feature_name, feature_value, updated_at) VALUES (?, ?, ?, toTimestamp(now()))"
)
def process_event(message_data: dict) -> list:
"""
一个简化的特征工程函数。
在真实项目中,这里会有更复杂的逻辑。
"""
user_id = message_data.get("user_id")
event_type = message_data.get("event_type")
if not user_id or not event_type:
return []
# 示例:计算用户最近点击的商品类别
if event_type == "product_click":
category = message_data.get("category", "unknown")
# 将字符串编码为 bytes
return [(user_id, "last_clicked_category", category.encode('utf-8'))]
# 可以添加更多事件类型的处理
return []
def callback(message: pubsub_v1.subscriber.message.Message) -> None:
try:
data = json.loads(message.data.decode("utf-8"))
logging.info(f"Received message: {data}")
features = process_event(data)
if not features:
message.ack()
return
# 使用 BatchStatement 批量写入一个用户的所有更新特征
batch = BatchStatement(consistency_level=ConsistencyLevel.LOCAL_QUORUM)
for user_id, feature_name, feature_value in features:
batch.add(insert_stmt, (user_id, feature_name, feature_value))
session.execute(batch)
logging.info(f"Successfully wrote {len(features)} features for user {features[0][0]}")
message.ack()
except json.JSONDecodeError:
logging.error("Failed to decode message data.")
message.nack()
except Exception as e:
logging.error(f"An error occurred: {e}", exc_info=True)
# Nack a message so it can be re-delivered.
message.nack()
def main():
subscriber = pubsub_v1.SubscriberClient()
subscription_path = subscriber.subscription_path(PROJECT_ID, SUBSCRIPTION_ID)
streaming_pull_future = subscriber.subscribe(subscription_path, callback=callback)
logging.info(f"Listening for messages on {subscription_path}..")
with subscriber:
try:
# 当 future.result() 被调用时,它会阻塞直到订阅被取消。
streaming_pull_future.result()
except TimeoutError:
streaming_pull_future.cancel()
streaming_pull_future.result()
except KeyboardInterrupt:
streaming_pull_future.cancel()
if __name__ == "__main__":
# 单元测试思路:
# 1. Mock `session.execute` and `message.ack/nack`
# 2. 测试 `callback` 函数在不同消息格式下的行为(正常、JSON错误、处理异常)
# 3. 测试 `process_event` 的特征转换逻辑是否正确
main()
3. Flask 特征服务 (API + SSR Dashboard)
这个 Flask 应用是系统的查询入口。
# feature_service/app.py
import os
import logging
from datetime import datetime
from flask import Flask, jsonify, render_template
from cassandra.cluster import Cluster, Session
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import dict_factory
# --- 配置 ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
app = Flask(__name__)
CASSANDRA_HOSTS = os.environ.get("CASSANDRA_HOSTS", "127.0.0.1").split(',')
CASSANDRA_PORT = int(os.environ.get("CASSANDRA_PORT", 9042))
CASSANDRA_USER = os.environ.get("CASSANDRA_USER")
CASSANDRA_PASSWORD = os.environ.get("CASSANDRA_PASSWORD")
KEYSPACE = "feature_store"
TABLE = "user_realtime_features"
# --- 全局 Cassandra Session ---
# 在生产环境中,应该使用 gunicorn 等 WSGI 服务器,
# 连接会在 fork 的 worker 进程中被正确处理。
session: Session = None
def get_cassandra_session():
global session
if session is None:
try:
auth_provider = PlainTextAuthProvider(username=CASSANDRA_USER, password=CASSANDRA_PASSWORD)
cluster = Cluster(CASSANDRA_HOSTS, port=CASSANDRA_PORT, auth_provider=auth_provider)
session = cluster.connect(KEYSPACE)
session.row_factory = dict_factory # 返回字典而不是元组,更易于处理
logging.info("Cassandra session initialized.")
except Exception as e:
logging.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
# 允许应用启动,但在请求时会失败,便于 Kubernetes 的 liveness probe
return session
@app.before_first_request
def initialize_connection():
get_cassandra_session()
# 预编译查询
select_stmt = None
def get_select_stmt():
global select_stmt
if select_stmt is None and session:
select_stmt = session.prepare(f"SELECT feature_name, feature_value, updated_at FROM {TABLE} WHERE user_id = ?")
return select_stmt
# --- API 端点 ---
@app.route('/features/v1/user/<string:user_id>', methods=['GET'])
def get_user_features(user_id):
"""
为机器学习模型提供特征的高性能端点。
"""
if not user_id:
return jsonify({"error": "user_id is required"}), 400
cass_session = get_cassandra_session()
if not cass_session:
return jsonify({"error": "database connection not available"}), 503
try:
stmt = get_select_stmt()
if not stmt:
return jsonify({"error": "database statement not prepared"}), 503
rows = cass_session.execute(stmt, (user_id,))
# 这里的关键是序列化。blob 类型需要解码。
# 假设所有特征值都是 utf-8 编码的字符串。
features = {row['feature_name']: row['feature_value'].decode('utf-8') for row in rows}
return jsonify(features)
except Exception as e:
logging.error(f"Error fetching features for user {user_id}: {e}", exc_info=True)
return jsonify({"error": "internal server error"}), 500
# --- SSR 监控面板 ---
@app.route('/internal/dashboard', methods=['GET'])
def dashboard():
"""
一个简单的 SSR 页面,用于内部监控。
"""
# 这是一个示例,展示如何从 Cassandra 读取一些聚合信息。
# 注意:在 Cassandra 中做聚合通常是反模式,这里仅用于小规模的内部仪表盘。
# 更好的做法是用 Spark 等工具计算统计信息并写入另一张表。
cass_session = get_cassandra_session()
if not cass_session:
return "Database connection not available", 503
try:
# 仅获取最近更新的几个用户作为样本展示
# 这是一个低效的查询,严禁在生产 API 中使用!
rows = list(cass_session.execute(f"SELECT user_id, updated_at FROM {TABLE} LIMIT 10"))
stats = {
"service_status": "OK",
"db_connected": True,
"last_updated_samples": [
{"user_id": r['user_id'], "last_update": r['updated_at'].strftime('%Y-%m-%d %H:%M:%S UTC')}
for r in rows
],
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
stats = {
"service_status": "Error",
"db_connected": False,
"error_message": str(e),
"timestamp": datetime.utcnow().isoformat()
}
# 使用 Jinja2 模板进行服务器端渲染
return render_template('dashboard.html', stats=stats)
if __name__ == '__main__':
# 仅用于本地开发
app.run(host='0.0.0.0', port=8080)
对应的 SSR 模板文件:
<!-- feature_service/templates/dashboard.html -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Real-time Feature Store Dashboard</title>
<style>
body { font-family: sans-serif; line-height: 1.6; padding: 2em; }
.container { max-width: 800px; margin: 0 auto; }
.status { padding: 1em; border-radius: 5px; }
.ok { background-color: #d4edda; color: #155724; }
.error { background-color: #f8d7da; color: #721c24; }
table { width: 100%; border-collapse: collapse; margin-top: 1em; }
th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
</style>
</head>
<body>
<div class="container">
<h1>Real-time Feature Store Dashboard</h1>
<div class="status {{ 'ok' if stats.service_status == 'OK' else 'error' }}">
<strong>Service Status:</strong> {{ stats.service_status }} <br>
<strong>Report Time (UTC):</strong> {{ stats.timestamp }}
</div>
{% if stats.service_status == 'OK' %}
<h2>Recent Feature Updates (Sample)</h2>
<p>Showing a small sample of recently updated user features. This is not a comprehensive list.</p>
<table>
<thead>
<tr>
<th>User ID</th>
<th>Last Update Time (UTC)</th>
</tr>
</thead>
<tbody>
{% for sample in stats.last_updated_samples %}
<tr>
<td>{{ sample.user_id }}</td>
<td>{{ sample.last_update }}</td>
</tr>
{% else %}
<tr>
<td colspan="2">No recent updates found.</td>
</tr>
{% endfor %}
</tbody>
</table>
{% else %}
<h2>Error Details</h2>
<pre>{{ stats.error_message }}</pre>
{% endif %}
</div>
</body>
</html>
4. PyTorch 推理服务调用示例
这个片段展示了下游服务如何使用这个特征 API。
# pytorch_inference/client.py
import requests
import torch
import time
# 假设这是一个已加载的 PyTorch 模型
# model = torch.load('my_model.pth')
# model.eval()
FEATURE_SERVICE_URL = "http://feature-service.default.svc.cluster.local/features/v1/user/{}"
def predict(user_id: str):
"""
执行一次完整的在线推理流程。
"""
# 1. 从特征服务获取实时特征
start_time = time.time()
try:
response = requests.get(FEATURE_SERVICE_URL.format(user_id), timeout=0.05) # 50ms 超时
response.raise_for_status()
features = response.json()
except requests.exceptions.RequestException as e:
print(f"Failed to get features for {user_id}: {e}")
# 在真实项目中,这里应该有回退逻辑,例如使用默认值或缓存的特征
return None
fetch_latency = (time.time() - start_time) * 1000
print(f"Feature fetch latency for user {user_id}: {fetch_latency:.2f} ms")
# 2. 特征预处理 (这是一个伪代码,真实逻辑会更复杂)
# last_category = features.get("last_clicked_category", "unknown")
# category_tensor = torch.tensor([category_to_id[last_category]])
# 3. 模型推理
# with torch.no_grad():
# output = model(category_tensor)
# prediction = torch.argmax(output, dim=1).item()
# return prediction
return features # 仅返回获取到的特征用于演示
if __name__ == "__main__":
test_user_id = "user-12345" # 假设这个用户的数据已被消费者写入
result = predict(test_user_id)
if result:
print(f"Features retrieved for {test_user_id}: {result}")
架构的扩展性与局限性
当前架构在满足核心需求的同时,也存在明确的边界和未来演进的方向。
一个显著的局限性是特征工程逻辑较为简单,直接耦合在消费者服务中。当特征逻辑变得复杂,例如需要基于时间窗口进行滚动聚合(如“用户过去1小时内的点击次数”),当前的架构将力不从心。这类场景需要引入真正的流处理引擎,如 Apache Flink 或 Spark Streaming,来替代简单的消费者。演进后的架构将是 Pub/Sub -> Flink/Spark -> Cassandra。
其次,对于某些极端热门的用户(例如网红或机器人账户),可能会在 Cassandra 中产生热点分区。虽然 Cassandra 在处理不均匀负载方面比许多其他数据库要好,但这仍然是一个需要监控和应对的问题。解决方案可能包括在 user_id
后面加上一个时间桶(如 user_id_20231027
)来分散写入压力,但这会增加读取时的复杂性。
最后,SSR 监控面板功能有限。它无法提供详细的性能指标、延迟分布或历史趋势图。一个成熟的系统需要将 Flask 服务、消费者和 Cassandra 的关键指标对接到 Prometheus,并使用 Grafana 进行可视化和告警,这才是可观测性的完整解决方案。当前内嵌的 SSR 仪表盘,其价值在于提供了一个“零成本”的、即时可用的内部状态检查窗口,满足了项目初期的快速迭代需求。