ggaaooppeenngg

为什么计算机科学是无限的但生命是有限的

概述

SGLang 的 Pipeline Parallelism (PP) 模式下,HiCache 负责 KV cache 的异步 prefetch(Host→GPU)和 backup(GPU→Host)。本文从 PP 调度器的外层事件循环出发,追踪 Load 和 Write 的完整时序,揭示 write_ack 比 load_ack 多延迟的根本原因。

一、PP 事件循环

每个 iteration 执行以下 4 步:

1
2
3
4
5
iter=N:
① check_hicache_events() ← 查询 HiCache 异步事件(load_ack / write_ack)
② get_next_batch_to_run() ← 选下一批 batch(prefix match → eviction → load)
③ _pp_launch_batch() ← launch forward
④ _pp_process_batch_result() ← 处理上一批 batch 的结果(insert → write)

核心设计process_batch_result 处理的是 mbs[next_mb_id]——上一轮 launch 的 batch。同一个 iter 内,调度器同时在处理两个不同 batch 的生命周期阶段。

二、Load 时序(Host→GPU)

触发时机

iter=N 的 get_next_batch_to_run —— prefill 阶段 prefix match 发现 host_hit,发起 load_back。

完整时序

步骤 发生在 说明
1. load 发起 iter=N get_next_batch_to_run match_prefix → host_hit → load_back → start_loading
2. CUDA copy 启动 iter=N get_next_batch_to_run GPU 从 Host 异步拉取 KV cache,CUDA event 入队
3. forward 逐层等待 iter=N _pp_launch_batch forward 通过 consumer_index 逐层等待对应 layer 的 load 完成
4. load_ack 消费 iter≥N+1 check_hicache_events loading_check → event.query()=True → 消费 ack

关键:load 是 prefetch——在 forward 之前触发,CUDA copy 和 forward 可以重叠(逐层等待、逐层执行)。

时序图

1
2
3
4
5
6
7
8
9
10
11
12
iter=N:
① check_hicache_events()
└─ 消费更早的 load_ack
② get_next_batch_to_run()
└─ match_prefix → host_hit → load_back() → start_loading()
└─ CUDA copy 启动,event 入队
③ _pp_launch_batch()
└─ forward 逐层等待 load 完成

iter=N+1:
① check_hicache_events()
└─ loading_check() → event.query()=True → load_ack ✓

延迟:load_ack 比 load 发起晚 1 iter

三、Write 时序(GPU→Host)

触发时机

iter=N 的 process_batch_result —— 处理上一轮 launch 的 batch 结果,insert 时触发 write_backup。

完整时序

步骤 发生在 说明
1. forward 执行 iter=N-1 _pp_launch_batch Prefill batch 在 GPU 上计算,生成 KV cache
2. write 发起 iter=N process_batch_result 处理 iter=N-1 的 batch → insert → write_backup → start_writing
3. CUDA copy 启动 iter=N process_batch_result GPU 异步写回 Host,CUDA event 入队
4. write_ack 消费 iter≥N+1 check_hicache_events writing_check → event.query()=True → 消费 ack

关键:write 是 post-write——必须在 forward 算完、拿到完整 KV cache 后才能触发。

时序图

1
2
3
4
5
6
7
8
9
10
11
12
13
iter=N-1:
② get_next_batch_to_run() → 选出 Prefill A
③ _pp_launch_batch() → forward(Prefill A)

iter=N:
④ _pp_process_batch_result()
└─ 处理 iter=N-1 的 Prefill A
└─ insert() → write_backup() → start_writing()
└─ CUDA copy 启动,event 入队

iter=N+1:
① check_hicache_events()
└─ writing_check() → event.query()=True → write_ack ✓

延迟:write_ack 比 write 发起晚 1 iter,但 write 本身比 forward 晚 1 iter(因为 process_batch_result 处理上一批 batch)。

四、延迟对比

以同一个 Prefill batch 为基准

追踪一个 Prefill batch 从 launch 到 write_ack 的完整生命周期:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
iter=1:
② get_next_batch_to_run()
└─ 选出 Prefill A
└─ match_prefix → host_hit → load_back() → start_loading()
③ _pp_launch_batch() → forward(Prefill A)

iter=2:
① check_hicache_events()
└─ loading_check() → load_ack ✓
↑ load_ack: 1 iter delay

iter=3:
④ process_batch_result()
└─ 处理 iter=1 的 Prefill A
└─ insert() → write_backup() → start_writing()
① check_hicache_events()
└─ (write_ack 还没完成)

iter=4:
① check_hicache_events()
└─ writing_check() → write_ack ✓
↑ write_ack: 3 iter delay from Prefill launch

对比表

事件 触发时机 ack 入队时机 ack 消费时机 相对于 Prefill launch 的延迟
load iter=1 get_next_batch_to_run iter=1 iter=2 1 iter
write iter=3 process_batch_result(处理 iter=1 的 batch) iter=3 iter=4 3 iter

为什么 write_ack 多延迟?

两个因素叠加:

1. process_batch_result 滞后一轮

它处理 mbs[next_mb_id](上一轮 launch 的 batch),以 PP2 为例偏移 2 iter。所以 write 比 load 晚 2 iter 才触发。

2. CUDA event query 需要等下一轮 check

无论 load 还是 write,ack 入队后最早等下一轮 check_hicache_events 才能消费。两者各加 1 iter。

综合

1
2
Load:  iter=1 发起 → iter=1 ack 入队 → iter=2 消费
Write: iter=1 forward → iter=3 process_result 发起 → iter=3 ack 入队 → iter=4 消费

Load 是 prefetch(forward 之前),Write 是 post-write(forward 之后)。这个架构差异决定了 Write 必然比 Load 多延迟。

五、PP 间同步问题

PP0 和 PP1 共享同一个 HiCache 实例(radix tree + host memory)。由于 output relay 延迟,PP0 和 PP1 的 iter 进度存在 1-2 iter 的偏移,需要三层同步机制保证一致性。

5.1 逻辑时钟保证重放一致性

PP1 的 writing_check/loading_check 不再直接消费 ack,而是将事件通过 Gloo P2P 通道 replay 给 PP0。PP0 作为唯一的事件消费者,按逻辑时钟顺序处理所有 ack,确保 PP0 和 PP1 的 radix tree 操作顺序一致。

问题:早期实现中 PP1 的 writing_check() 绕过 check_hicache_events guard,直接消费 write_ack(即 ack theft),导致 PP0 端 pending event 永远无法完成,radix tree 分叉。

修复:将 PP rank 分支移入 writing_check/loading_check 内部,用 PPHiCacheEventsReq 控制请求替代 dict wrapper,强制 PP1 replay 事件给 PP0。

5.2 Count Sync 保证 CP 一致性

PP0 比 PP1 早 1-2 iter 积累 write ack(output relay 延迟),导致 ack_write_queue 积累差异。PR #22878 通过 piggybacking write-ack consumption counts 在 PP ranks 间同步:

1
2
3
iter=N:
PP0: writing_check() → 消费 3 个 write_ack → count=3
PP1: 等待 PP0 的 count → count=3 → radix tree 操作对齐

Count sync 确保 PP0 和 PP1 对同一个 radix tree node 的 checkpoint(CP)操作一致,避免一个 stage 认为 node 已 backup、另一个 stage 还在等待的情况。

5.3 PP1 同步消费 PP0 的 ack

PP1 不再独立消费 ack,而是通过同步机制确保 PP0 消费 ack 后,PP1 的 radix tree 状态与 PP0 对齐:

1
2
3
4
5
PP0: writing_check() → event.query()=True → 消费 write_ack
→ radix tree: insert() → finalize() → node 状态更新
↓ TP all_reduce sync
PP1: 等待 PP0 完成 → radix tree: 同步执行 finalize()
→ node 状态与 PP0 一致

三层保障

同步层 机制 保证什么
逻辑时钟 Event Replay(Gloo P2P) PP0/PP1 事件处理顺序一致
Count Sync PR #22878 piggyback counts PP0/PP1 checkpoint 状态一致
Ack 同步 PP1 replay → PP0 消费 → TP all_reduce radix tree 节点状态一致

核心目标:无论 PP0 和 PP1 的 iter 偏移多少,radix tree 的结构和节点状态在两个 stage 上始终保持一致。

六、总结

维度 Load Write
触发函数 get_next_batch_to_run process_batch_result
触发契机 prefix match 发现 host_hit insert → _inc_hit_count
与 forward 关系 forward 之前(prefetch) forward 之后(post-write)
CUDA copy 与 forward 可重叠(逐层等待) 串行(forward 算完才触发)
ack 消费延迟 +1 iter +3 iter(含 process_result 偏移)

核心结论

  1. PP 调度器每个 iter 同时做两件事:get_next_batch_to_run 选下一批 batch,process_batch_result 处理上一批 batch 的结果
  2. Load 是 prefetch(forward 之前触发),Write 是 post-write(forward 之后触发)
  3. process_batch_result 处理 mbs[next_mb_id] 的 iter 偏移是 Write 延迟的根本原因
  4. PP 间需要 count sync 补偿 output relay 带来的时序差异

本文基于 SGLang 源码分析,涉及文件:hiradix_cache.pycache_controller.py

Pipeline Parallelism(PP)是多卡推理中的核心技术之一。SGLang 的 PP 实现有一套独立的事件循环和调度机制,和普通的 single-batch 路径完全不同。本文将从代码出发,深入分析 SGLang 的 PP 实现。

为什么需要 Pipeline Parallelism

单机推理的瓶颈

单卡推理时,GPU 利用率受限于模型大小和显存。一个 70B 模型单卡放不下,必须拆分到多卡。

模型并行方案对比

方案 原理 优点 缺点
TP (Tensor Parallelism) 按层切分权重 通信少,延迟低 受限于单节点卡数
PP (Pipeline Parallelism) 按层切分模型 支持任意层数 存在 pipeline bubble
DP (Data Parallelism) 每卡完整模型 简单,吞吐高 显存需求不变

SGLang 通常 TP + PP 组合使用:节点内 TP,跨节点 PP。

SGLang PP 架构概览

PP 路径的独立事件循环

SGLang 的调度器有三条事件循环路径:

  • event_loop_normal — 普通单 batch 调度
  • event_loop_overlap — 计算和通信 overlap 的调度
  • event_loop_pp — PP 路径的独立事件循环

PP 路径有自己独立的事件循环,定义在 scheduler_pp_mixin.pySchedulerPPMixin 类中。它和 normal/overlap 路径完全隔离,因为 PP 的通信模式(send/recv proxy tensors 跨 stage)和单机调度逻辑差异太大,硬塞进同一个 event loop 会很脏。

event_loop_pp 核心流程

1
2
3
4
5
6
7
8
9
10
11
12
# scheduler_pp_mixin.py 第 72-145 行
while True:
for mb_id in range(pp_loop_size): # 遍历 microbatch
1. recv_requests() # 接收第 i 个请求(从上一个 PP stage 或 tokenizer)
2. _pp_send_pyobj_to_next_stage() # 立即把第 i 个请求转发给下一个 PP rank(async,提前发送)
3. get_next_batch_to_run() # 调度当前 microbatch 的 batch
4. _pp_recv_proxy_tensors() # 接收第 i 个 proxy(hidden states,来自上一个 PP stage)
5. _pp_commit_send_output_work_and_preprocess...() # 接收并预处理上一个 microbatch 的 output tensors
6. _pp_launch_batch() # 执行第 i 个 batch 的 forward
7. _pp_send_dict_to_next_stage(msg_type="proxy") # 发送第 i 个 proxy(hidden states)给下一个 PP stage
8. _pp_send_dict_to_next_stage(msg_type="output") # 发送当前 stage 的 output 给下一个 stage(延迟/暂存)
9. _pp_process_batch_result() # 处理上一个 microbatch 的 batch result

核心概念:commit = 等待上一轮的异步操作完成

所有 _pp_commit_comm_work(work) 做的事情就是:

1
2
3
4
def _pp_commit_comm_work(self: Scheduler, work: List[P2PWork]) -> None:
for p2p_work in work:
p2p_work.work.wait() # 阻塞等待 NCCL/Gloo 异步通信完成
work.clear()

即:commit 是一个同步屏障,确保上一轮发起的异步 send 已经完成,发送缓冲区可以安全复用。这是一种延迟等待模式——发送时不阻塞,到下一轮再等待完成,从而让发送和 GPU 计算重叠。

Microbatch 索引计算

1
2
next_mb_id = (mb_id + 1) % self.pp_loop_size              # 当前 rank 下一个要处理的 microbatch 槽位
next_first_rank_mb_id = (mb_id + self.pp_size) % self.pp_loop_size # first rank 对应的"下一轮"槽位
  • next_mb_id:当前 rank 下一个要处理的 microbatch 槽位(偏移 1,因为每次迭代处理一个 mb)
  • next_first_rank_mb_id:first rank 对应的”下一轮”microbatch 槽位(偏移 pp_size,因为流水线深度——first rank 产出的 batch 要经过 pp_size 个迭代才到 last rank 完成)

步骤与函数的对应关系

步骤 函数 作用
recv ith req recv_requests() 接收请求(PP0 从 zmq,PP>0 从前一个 rank P2P)
send ith req to next _pp_send_pyobj_to_next_stage() 把请求异步转发给下一个 PP rank(提前发送,隐藏通信延迟)
get_next_batch_to_run get_next_batch_to_run() 从 waiting_queue 调度请求组成 batch
recv ith proxy _pp_recv_proxy_tensors() 接收上游 stage 的 hidden_states + residual(forward 输入)
recv prev outputs + preprocess _pp_commit_send_output_work_and_preprocess_output_tensors() 接收上一个 microbatch 的 output、GPU→CPU 拷贝、解包
run ith batch _pp_launch_batch() 执行当前 microbatch 的 forward 计算
send ith proxy to next _pp_send_dict_to_next_stage(msg_type="proxy") 发送 forward 输出的 hidden_states 给下一个 stage
send outputs to next (stashed) _pp_send_dict_to_next_stage(msg_type="output") 发送最终输出(next_token_ids),last rank→rank 0
process prev batch result _pp_process_batch_result() 更新请求状态、判断是否结束、流式输出给 tokenizer

关键设计

  • send req 提前:收到请求后立即转发给下一个 rank(step 2),不等 forward 完成,隐藏 P2P 通信延迟
  • output 延迟/暂存:output tensors 的发送和上一个 mb 的 result 处理是交叉的(step 5 处理 prev,step 8 发送 current),形成 pipeline 重叠
  • 延迟等待异步通信_pp_commit_comm_work 在下一轮迭代中调用 .wait() 确保上一轮的 async send 已完成,让发送和 GPU 计算重叠
  • 每个 microbatch 槽位独立管理状态
  • async send + recv 实现计算和通信的 overlap
  • 主循环轮转处理每个 microbatch,填充 pipeline bubble

完整流程示例(pp_size=2, pp_async_batch_depth=0)

以 pp_size=2, pp_async_batch_depth=0 为例,pp_loop_size=2,只有槽位 0 和 1:

1
2
3
4
5
6
参数:
pp_size = 2(PP0 = first rank, PP1 = last rank)
pp_async_batch_depth = 0
pp_loop_size = 2(槽位:mb_id ∈ {0, 1})
next_mb_id = (mb_id + 1) % 2 # 另一个槽位
next_first_rank_mb_id = (mb_id + 2) % 2 = mb_id # 注意:next_first_rank_mb_id == mb_id

单次迭代(mb_id = N)的完整步骤

步骤 1:恢复上下文

1
2
3
4
running_batch = running_mbs[N]
last_batch = last_mbs[N]
next_first_rank_mb_id = N # 因为 (N + 2) % 2 = N
next_mb_id = (N + 1) % 2 # 另一个槽位

步骤 2:接收并处理请求

1
2
recv_reqs = recv_requests()
process_input_requests(recv_reqs)
  • PP0:从 tokenizer_manager / detokenizer 收请求(zmq 通道)
  • PP1:从 PP0 收转发过来的请求(P2P,含 PPHiCacheEventsReq 控制消息)

步骤 3:HiCache 事件同步 + 转发请求

1
2
3
4
5
6
if enable_hierarchical_cache:
tree_cache.check_hicache_events()

if not is_last_rank: # 即 PP0
_pp_commit_comm_work(send_req_work) # ← 等待上一轮的 send_req 异步操作完成
send_req_work = async_send(pp_send_payload) → PP1

commit 含义:上一轮(槽位 (N+1)%2)发起的 send_req_work 可能还没完成,这里阻塞等它完成后,才能安全地复用发送缓冲区发送本轮的请求。

步骤 4:调度决策

1
2
3
mbs[N] = get_next_batch_to_run()
running_mbs[N] = running_batch
cur_batch = mbs[N]

步骤 5:接收 Proxy 张量

1
2
3
4
if cur_batch is not None:
pp_proxy_tensors = _pp_recv_proxy_tensors()
# PP0: 直接返回 None(第一个 stage 不需要接收)
# PP1: 阻塞等待从 PP0 接收 hidden_states + residual (msg_type="proxy")

步骤 6:等待上一轮 Proxy 发送完成

1
_pp_commit_comm_work(send_proxy_work)  # ← 等待上一轮的 proxy 异步发送完成

commit 含义:上一轮(槽位 (N+1)%2)PP0 发起的 proxy 异步发送可能还没完成,这里确保它完成,释放发送缓冲区。

步骤 7:GPU Forward 计算

1
2
3
4
5
if cur_batch:
result, launch_event = _pp_launch_batch(N, pp_proxy_tensors, ...)
# 在 forward_stream 上执行 run_batch
# 记录 CUDA event
# PP1(last rank): 将 output 张量 push 到 last_rank_comm_queue

步骤 8:发送 Output + 接收 Output + 预处理

因为 pp_async_batch_depth == 0,这一步在 forward 之后执行(无法与 GPU 计算 overlap)。

1
2
3
4
5
6
7
8
9
10
11
12
_pp_commit_comm_work(send_output_work)  # ← 等待上一轮的 output 异步发送完成

# 发送 output:
# PP1(last rank): 如果 mbs[next_first_rank_mb_id=N] 不为空,
# 从 last_rank_comm_queue 取出 output,异步发送给 PP0 (msg_type="output")
# PP0(非 last rank): 如果有上一轮暂存的 pp_outputs,转发给 PP1

# 接收 output + 预处理:
# 如果 mbs[next_mb_id] 不为空:
# 阻塞接收 output 张量 (msg_type="output")
# 在 copy_stream 上执行 _pp_prep_batch_result()(组装 GenerationBatchResult + D2H 拷贝)
# 记录 d2h_event

commit 含义:上一轮发起的 output 异步发送可能还没完成,这里确保完成后才能发起新的 output 发送。

步骤 9:后处理上一轮的 Batch 结果

1
2
3
4
5
if mbs[next_mb_id] is not None:
d2h_event.synchronize() # 等待 copy_stream 上的 D2H 拷贝完成
_pp_process_batch_result(mbs[next_mb_id], next_batch_result)
# 更新请求状态、判断 EOS、发送给 detokenizer
last_mbs[next_mb_id] = mbs[next_mb_id]

步骤 10:发送 Proxy 张量(仅 PP0)

1
2
3
if not is_last_rank and cur_batch:
torch.cuda.current_stream().wait_event(launch_event) # 等 forward 完成
send_proxy_work = async_send(hidden_states + residual) → PP1 # msg_type="proxy"

步骤 11:保存状态

1
pp_outputs = next_pp_outputs  # 暂存本轮接收到的 output,供下一轮步骤 8 转发

完整时序(稳态,两个槽位交替)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
┌─────────────────────────────────────────────────────────────────┐
│ mb_id=0 │
├─────────────────────────────────────────────────────────────────┤
│ 1. recv_requests + process │
│ 2. commit(send_req_work[上轮mb=1的]) → async_send(reqs) │
│ 3. get_next_batch_to_run → cur_batch │
│ 4. recv_proxy (PP1阻塞等PP0; PP0跳过) │
│ 5. commit(send_proxy_work[上轮mb=1的]) │
│ 6. GPU forward (forward_stream) │
│ 7. commit(send_output_work[上轮mb=1的]) │
│ send_output → recv_output → prep_batch_result (copy_stream) │
│ 8. d2h_event.sync → process_batch_result(mbs[1]的结果) │
│ 9. PP0: wait(launch_event) → async_send(proxy) → PP1 │
│10. pp_outputs = next_pp_outputs │
├─────────────────────────────────────────────────────────────────┤
│ mb_id=1 │
├─────────────────────────────────────────────────────────────────┤
│ 1. recv_requests + process │
│ 2. commit(send_req_work[上轮mb=0的]) → async_send(reqs) │
│ 3. get_next_batch_to_run → cur_batch │
│ 4. recv_proxy (PP1阻塞等PP0; PP0跳过) │
│ 5. commit(send_proxy_work[上轮mb=0的]) │
│ 6. GPU forward (forward_stream) │
│ 7. commit(send_output_work[上轮mb=0的]) │
│ send_output → recv_output → prep_batch_result (copy_stream) │
│ 8. d2h_event.sync → process_batch_result(mbs[0]的结果) │
│ 9. PP0: wait(launch_event) → async_send(proxy) → PP1 │
│10. pp_outputs = next_pp_outputs │
└─────────────────────────────────────────────────────────────────┘

PP0 与 PP1 的交互时序(稳态)

1
2
3
4
5
6
7
8
9
10
11
12
时间轴 →

PP0 (mb_id=0):
recv_req → commit(send_req) → send_req→PP1 → schedule → [skip proxy]
→ commit(send_proxy) → FORWARD → commit(send_output)
→ send_output(转发) → recv_output(从PP1) → prep → process_result
→ send_proxy→PP1

PP1 (mb_id=0):
recv_req(从PP0) → process → schedule → recv_proxy(阻塞等PP0)
→ commit(send_proxy) → FORWARD → commit(send_output)
→ send_output→PP0 → recv_output(从PP0转发) → prep → process_result

关键依赖链

  • PP1 的 recv_proxy 必须等 PP0 的 send_proxy(上一轮的)完成
  • PP0 的 recv_output 必须等 PP1 的 send_output 完成
  • 因为 pp_async_batch_depth=0,output 的发送/接收在 forward 之后,无法与 GPU 计算重叠

所有 commit 汇总

commit 调用 等待的是什么 目的
commit(send_req_work) 上一轮槽位发起的 reqs 异步发送 确保发送完成,可以安全复用缓冲区发送本轮 reqs
commit(send_proxy_work) 上一轮槽位发起的 proxy 异步发送 确保 hidden_states 张量已发送完毕,可以释放/复用
commit(send_output_work) 上一轮槽位发起的 output 异步发送 确保 output 张量已发送完毕,可以发起新的 output 发送

模式统一:每个 async_send 在下一轮同一类型操作之前被 commit,形成”发起→(做其他事)→等待完成→再发起”的流水线模式。即使 pp_async_batch_depth=0 没有计算-通信 overlap,这种 commit 模式仍然保证了同类型通信操作之间不会冲突。

Microbatch:填充 Pipeline Bubble

问题:Pipeline Bubble

PP2(2 个 stage)如果只用一个 batch,会出现大量空闲:

1
2
3
4
时间 →
Stage 0: [ batch 0 ] [ batch 1 ]
Stage 1: [ batch 0 ] [ batch 1 ]
↑ 空闲 ↑ 空闲

Stage 0 做完 batch 0 后必须等 Stage 1 处理完才能做 batch 1(因为只有一个 batch 在 pipeline 里流转),一半的时间都在空闲。这就是 pipeline bubble

解法:Microbatch

把多个 batch 同时塞进 pipeline,这些 batch 就叫 microbatch:

1
2
3
时间 →
Stage 0: [ mb 0 ][ mb 1 ][ mb 2 ][ mb 3 ]
Stage 1: [ mb 0 ][ mb 1 ][ mb 2 ][ mb 3 ]

Stage 0 做完 mb0 不用等,立刻做 mb1;Stage 1 也紧接着处理,bubble 大幅减少。

Microbatch 数量的确定

1
2
# scheduler_pp_mixin.py 第 514 行
self.pp_loop_size = self.pp_size + self.server_args.pp_async_batch_depth
  • pp_size = 2(PP2)→ 至少 2 个 microbatch 槽位
  • pp_async_batch_depth → 额外的 buffer 深度,进一步隐藏延迟

每个 microbatch 都有独立的状态:

1
2
3
self.mbs = [None] * self.pp_loop_size          # 当前 batch
self.last_mbs = [None] * self.pp_loop_size # 上一个 batch
self.running_mbs = [ScheduleBatch(...) for _ in range(self.pp_loop_size)] # running 状态

Microbatch vs 正常 Batch

维度 正常 batch Microbatch(PP)
数量 同时 1 个 同时 pp_size + async_depth 个
目的 普通调度 填充 pipeline bubble,提升利用率
状态管理 单个 running_batch 数组 running_mbs[mb_id] 各自独立
本质 就是一个 batch 也是普通 batch,只是多个在 pipeline 中交替执行

Microbatch 和 batch 在数据结构上没有区别(都是 ScheduleBatch),区别只在于 PP 需要多个 batch 同时在不同 stage 中流转,所以给它们编号叫 microbatch。

Microbatch 数量不是越多越好

1
pp_loop_size = pp_size + pp_async_batch_depth
  • 太少(= pp_size):pipeline 容易断流,一个 stage 稍慢就导致下游空闲
  • 太多:显存占用线性增长(每个 microbatch 都要有自己的 KV cache + activation),而且可能因为显存压力导致 batch size 被迫缩小,反而吞吐下降

调优经验

  • pp_async_batch_depth = 1~2 通常足够,再深收益递减
  • 显存充裕时可以试更大的 depth,但要注意 activation memory 的开销
  • 如果 single batch 的 token 数很大(长上下文),microbatch 多了容易 OOM

PP Rank 的请求接收路径

recv_requests() 根据 PP rank 的不同,有两条完全不同的接收路径。接收的请求包括两类:

  • 外部 API 请求:tokenize 后的推理请求(TokenizedGenerateReqInputTokenizedEmbeddingReqInput 等)
  • 内部控制请求:PP 同步用的控制消息(如 PPHiCacheEventsReq 等)

PP Rank 0(第一个 stage)

从上游组件直接接收用户请求,走两条 zmq 通道:

1
2
3
4
5
# 1. 从 tokenizer 接收(zmq 通道)
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)

# 2. 从 rpc 接收(zmq 通道)
recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)

接收的是 tokenize 后的请求对象,类型为:

  • TokenizedGenerateReqInput — 生成请求
  • TokenizedEmbeddingReqInput — embedding 请求
  • 其他控制类消息(flush、abort 等)

PP Rank > 0(后续 stage)

不直接从 tokenizer 接收,而是从前一个 PP rank P2P 转发过来:

1
2
3
4
5
6
7
recv_reqs = point_to_point_pyobj(
[],
self.pp_rank * self.tp_size + dp_offset, # 本 rank
self.world_group.cpu_group,
(self.pp_rank - 1) * self.tp_size + dp_offset, # src: 前一个 PP rank
self.pp_rank * self.tp_size + dp_offset, # dst: 本 rank
)

对应 event_loop_pp 中的流转

1
2
3
4
5
6
7
8
9
# event_loop_pp 第 80-88 行
recv_reqs = self.recv_requests() # PP0 从 zmq 收,PP1 从 PP0 的 P2P 收
self.process_input_requests(recv_reqs) # 本地处理

if not self.pp_group.is_last_rank:
# 步骤 2: send reqs to next PP stage
self.send_req_work = self._pp_send_pyobj_to_next_stage( # 转发给下一个 PP rank
recv_reqs, async_send=True,
)

流转全图

1
2
3
4
5
6
7
Tokenizer / RPC
│ zmq

PP Rank 0: recv_requests() ── point_to_point_pyobj (async) ──→ PP Rank 1: recv_requests()
│ │
▼ ▼
process_input_requests() process_input_requests()

关键设计:所有 PP rank 收到的内容是一样的(同样的请求列表),这样每个 stage 都能为同一批请求做 get_next_batch_to_run() 调度,保证各 stage 的 batch 一致。

process_input_requests:请求路由器

recv_requests() 收到请求后,process_input_requests() 负责按类型分发到对应的 handler:

1
2
3
4
5
6
7
8
9
10
11
12
13
# scheduler_pp_mixin.py 第 1545-1566 行
def process_input_requests(self, recv_reqs: List):
now = time.monotonic()
self.session_controller.maybe_reap(now) # 1. 清理过期 session
for recv_req in recv_reqs:
# 2. 健康检查请求:服务器忙时跳过
if is_health_check_generate_req(recv_req) and not self.is_fully_idle():
continue
# 3. 按类型分发
output = self._request_dispatcher(recv_req)
# 4. 有即时返回结果的,回送给 tokenizer/rpc
if output is not None:
self.send_to_tokenizer.send_output(output, recv_req)

_request_dispatcher 是一个 TypeBasedDispatcher,根据请求类型路由到不同 handler:

核心推理请求

请求类型 Handler 说明
TokenizedGenerateReqInput handle_generate_request 生成请求 → 加入 waiting_queue
TokenizedEmbeddingReqInput handle_embedding_request Embedding 请求
BatchTokenized* handle_batch_* 批量请求

控制类请求

请求类型 Handler 说明
AbortReq abort_request 中止请求
FlushCacheReqInput flush_cache_wrapped 清空 KV cache
OpenSessionReqInput open_session 开启会话
CloseSessionReqInput close_session 关闭会话
ProfileReq profile 性能分析
SlowDownReqInput slow_down 降速
PauseGenerationReqInput pause_generation 暂停生成

权重管理请求

请求类型 Handler 说明
UpdateWeightsFrom* update_weights_from_* 更新模型权重(disk/distributed/tensor/IPC)
LoadLoRAAdapterReqInput load_lora_adapter 加载 LoRA
UnloadLoRAAdapterReqInput unload_lora_adapter 卸载 LoRA

一句话总结process_input_requests 就是一个请求路由器——生成请求进入 waiting_queue 等待后续调度,控制请求立即处理并回送结果,所有 PP rank 都执行同样的分发逻辑保证状态一致。

_pp_recv_proxy_tensors:接收上游 Stage 的中间结果

_pp_recv_proxy_tensorsscheduler_pp_mixin.py 第 993-1004 行)负责从上一个 PP stage 接收 forward 的中间激活值:

1
2
3
4
5
6
7
8
9
10
def _pp_recv_proxy_tensors(self: Scheduler) -> Optional[PPProxyTensors]:
pp_proxy_tensors = None
if not self.pp_group.is_first_rank:
pp_proxy_tensors = PPProxyTensors(
self._pp_recv_typed_dict(
expected_kind="proxy",
all_gather_group=self.attn_tp_group if self.require_attn_tp_allgather else None,
)
)
return pp_proxy_tensors

行为

  • PP Rank 0(first rank):返回 None,因为它是第一个 stage,没有上一级给它传 hidden states,直接用 embedding 层的输出开始 forward。代码里有 is_first_rank 判断,PP0 不会做任何接收操作。
  • PP Rank > 0:从前一个 PP rank 接收 tensor dict(msg_type="proxy"),包装成 PPProxyTensors

Proxy tensors 是什么

profile_and_init_predictor 里的构造(第 609-622 行)就清楚了:

1
2
3
4
5
proxy_tensors = {
"hidden_states": torch.zeros((seq_len, hidden_size), ...),
"residual": torch.zeros((seq_len, hidden_size), ...),
}
pp_proxy = PPProxyTensors(proxy_tensors)

就是上一个 PP stage forward 输出的中间激活值:

  • hidden_states — 当前层的隐藏状态
  • residual — 残差连接

在 event_loop_pp 中的位置

1
2
3
4
5
6
7
if self.cur_batch:
pp_proxy_tensors = self._pp_recv_proxy_tensors() # rank>0 在这里阻塞等 recv
...
# 步骤 5: run_batch → 实际调用 _pp_launch_batch()
result, self.launch_event = self._pp_launch_batch(
mb_id, pp_proxy_tensors, ... # 传入 forward,接着上一个 stage 继续算
)

对应的发送端:在 event_loop_pp 末尾(第 129-139 行),非 last rank 发送 proxy:

1
2
3
4
5
6
7
8
if not self.pp_group.is_last_rank:
if self.cur_batch:
torch.cuda.current_stream().wait_event(self.launch_event) # 等 forward 完成
self.send_proxy_work = self._pp_send_dict_to_next_stage(
result.pp_hidden_states_proxy_tensors.tensors, # forward 输出的 hidden_states
async_send=True,
msg_type="proxy",
)

一句话总结_pp_recv_proxy_tensors 就是非首 rank 从前一个 PP stage 接收 forward 中间结果(hidden_states + residual),这样本 stage 的模型层才能接着算。这是 PP 的核心数据流——每个 stage 只有部分层,需要前一个 stage 的输出作为输入。

_pp_recv_dict_from_prev_stage(接收 output)与 _pp_recv_proxy_tensors(接收 proxy)的区别

这两个接收操作虽然底层都走 _pp_recv_typed_dict,但它们在 语义、时机、数据内容、通信方向 上完全不同。

一、核心区别对比

维度 _pp_recv_proxy_tensors(proxy) _pp_recv_dict_from_prev_stage(output)
msg_type "proxy" "output"
语义 中间隐藏状态(模型前向传播的中间产物) 最终输出(next_token_ids + logprob)
数据内容 hidden_states + residual(形状 [num_tokens, hidden_size] next_token_ids(形状 [batch_size])+ 可选的 logprob 张量
通信方向 PP(k-1) → PP(k),前向传播方向 PP(last) → PP(0),反向回传方向(或非 last rank 的转发)
对应的发送端 上一个 PP stage 的 _pp_send_dict_to_next_stage(..., msg_type="proxy") last rank 的 _pp_send_dict_to_next_stage(..., msg_type="output")
用途 作为当前 stage 模型 forward 的输入 作为 batch 后处理(process_batch_result)的输入
时机 _pp_launch_batch(GPU forward)之前接收 在 GPU forward 之后/并行接收(用于处理上一轮的结果)
接收者 非 first rank(PP1, PP2, …) 所有 rank(PP0 从 last rank 收,中间 rank 从上一个 rank 转发)

二、在流水线中的位置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
PP Rank 0              PP Rank 1              PP Rank 2 (Last)
───────── ────────── ──────────────
_pp_recv_proxy_tensors
← proxy (hidden_states + residual)
_pp_launch_batch (forward)
_pp_send_dict_to_next_stage
→ proxy (hidden_states + residual)
← proxy (hidden_states + residual)
_pp_launch_batch (forward + sample)
_pp_send_dict_to_next_stage
→ output (next_token_ids + logprobs)
← output (next_token_ids + logprobs)
_pp_recv_dict_from_prev_stage (msg_type="output")
_pp_commit_send_output_work_and_preprocess_output_tensors
process_batch_result (更新请求状态、流式输出)

三、数据内容详解

Proxy 张量(中间隐藏状态)

1
2
3
4
5
# 发送端(_pp_launch_batch 之后)
result.pp_hidden_states_proxy_tensors.tensors = {
"hidden_states": tensor([num_tokens, hidden_size]), # 模型中间层输出
"residual": tensor([num_tokens, hidden_size]), # 残差连接
}

这是模型被纵向切分后,前半部分层的输出。下一个 PP stage 需要它作为输入继续跑后半部分层。没有它,下一个 stage 无法执行 forward。

Output 张量(最终输出)

1
2
3
4
5
6
7
8
9
# 发送端(_pp_prepare_tensor_dict)
tensor_dict = {
"next_token_ids": tensor([batch_size]), # 采样得到的下一个 token
# 可选:
"input_token_logprobs": ...,
"normalized_prompt_logprobs": ...,
"prefill_top_logprobs": ...,
...
}

这是 last rank 完成整个模型 forward + 采样后的最终结果。PP0 需要它来执行后处理(更新请求状态、判断是否结束、发送给 tokenizer 等)。

四、为什么需要分开?

  1. 时序不同:proxy 必须在 forward 前到位(是 forward 的输入);output 可以在 forward 后异步接收(是上一轮的结果,用于 CPU 后处理)。

  2. 方向不同:proxy 沿流水线正向流动(rank 0→1→2→…→last);output 从 last rank 反向回到 rank 0(形成环路)。

  3. 同一条 P2P 链路上交错到达:由于 PP 使用环形拓扑(last rank 的 next 就是 rank 0),同一对 rank 之间可能同时有 proxy 和 output 在传输。_pp_recv_typed_dict 的 demux 机制(按 msg_type 分拣)就是为了解决这个问题——收到不匹配类型的消息时暂存到 inbox,等需要时再取出。

  4. 大小差异巨大:proxy 张量很大(num_tokens × hidden_size,可能几十 MB);output 张量很小(batch_size 个 int,几 KB)。这影响通信调度策略。

五、一句话总结

  • _pp_recv_proxy_tensors:接收的是 “半成品” ——上一个 stage 跑完前半部分模型层后的隐藏状态,当前 stage 要拿它继续跑后半部分。
  • _pp_recv_dict_from_prev_stage(output):接收的是 “成品” ——last rank 跑完整个模型后采样出的 token,PP0 拿它做后处理(更新状态、返回用户)。

_pp_launch_batch:执行当前 Microbatch 的 Forward

_pp_launch_batchscheduler_pp_mixin.py)对应 event_loop_pp 流程中的 step 5: run_batch,负责启动当前 microbatch 的 forward 计算:

1
2
3
4
5
6
7
8
9
# 伪代码示意
def _pp_launch_batch(self, mb_id, pp_proxy_tensors, ...):
batch = self.mbs[mb_id]
# 如果有 proxy tensors(非首 rank),用 proxy 的 hidden_states 作为输入
# 否则用 embedding 输出作为输入
result = self.worker.forward_batch(batch, pp_proxy_tensors)
launch_event = torch.cuda.Event()
launch_event.record()
return result, launch_event

关键行为

  • PP Rank 0(first rank):不需要 proxy tensors,直接用 token embeddings 作为 forward 输入
  • PP Rank > 0:用 _pp_recv_proxy_tensors() 收到的 hidden_states + residual 作为 forward 输入,接着上一个 stage 继续算
  • 返回 result(包含 pp_hidden_states_proxy_tensors,供后续发送给下一个 stage)和 launch_event(用于同步)

在 event_loop_pp 中的调用位置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
for mb_id in range(self.pp_loop_size):
...
# 步骤 4: 接收 proxy tensors
pp_proxy_tensors = self._pp_recv_proxy_tensors()

# 步骤 5: run_batch → _pp_launch_batch
result, self.launch_event = self._pp_launch_batch(
mb_id, pp_proxy_tensors, ...
)

# 步骤 7: 发送 proxy tensors 给下一个 stage
if not self.pp_group.is_last_rank:
torch.cuda.current_stream().wait_event(self.launch_event)
self.send_proxy_work = self._pp_send_dict_to_next_stage(
result.pp_hidden_states_proxy_tensors.tensors,
async_send=True, msg_type="proxy",
)

一句话总结_pp_launch_batch 就是 PP 路径下的 forward 执行器——首 rank 用 embedding 输入,非首 rank 用上游传来的 hidden_states 输入,执行当前 stage 的模型层计算后输出给下一个 stage。

_pp_commit_send_output_work_and_preprocess_output_tensors:Output 的收发与预处理(对应 step 6: recv & process prev mb result)

这个函数(scheduler_pp_mixin.py 第 854-873 行)对应 event_loop_pp 流程中的 step 6: recv & process prev mb result,负责处理 forward 结果的 output tensors 的发送、接收和预处理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def _pp_commit_send_output_work_and_preprocess_output_tensors(
self: Scheduler,
next_first_rank_mb_id: int,
next_mb_id: int,
) -> Tuple[PPProxyTensors, GenerationBatchResult, torch.cuda.Event]:
# 1. 等待上一轮的 output 发送完成
self._pp_commit_comm_work(work=self.send_output_work)
# 2. 发送 + 接收 + 预处理
(next_pp_outputs, next_batch_result, d2h_event, self.send_output_work) = \
self._pp_send_recv_and_preprocess_output_tensors(
next_first_rank_mb_id, next_mb_id,
self.mbs, self.mb_metadata, self.last_rank_comm_queue, self.pp_outputs,
)
return next_pp_outputs, next_batch_result, d2h_event

核心逻辑在 _pp_send_recv_and_preprocess_output_tensors(第 1081-1116 行),做了三件事

1. 发送 output(last rank → rank 0)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Last rank: 从 last_rank_comm_queue 取出 forward 结果,发给 rank 0
if self.pp_group.is_last_rank:
q_event, pp_outputs_to_send = last_rank_comm_queue.popleft()
torch.cuda.current_stream().wait_event(q_event)
send_output_work = self._pp_send_dict_to_next_stage(
pp_outputs_to_send.tensors, msg_type="output"
)

# 中间 rank: 把从前一个 rank 收到的 output 继续转发给下一个 rank
if not self.pp_group.is_last_rank:
if pp_outputs:
send_output_work = self._pp_send_dict_to_next_stage(
pp_outputs.tensors, msg_type="output"
)

2. 接收 output(rank 0 从 last rank 收)

1
2
if mbs[next_mb_id] is not None:
next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage())

3. 预处理:GPU→CPU 拷贝(copy stream 上)

1
2
3
4
5
6
7
with self.copy_stream_ctx:
self.copy_stream.wait_stream(self.schedule_stream)
batch_result = self._pp_prep_batch_result( # 解包 next_token_ids, logprobs
mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs
)
d2h_event = torch.cuda.Event()
d2h_event.record(torch.cuda.current_stream()) # 记录 D2H 完成事件

Output tensors 里有什么

_pp_prepare_tensor_dict(第 920-933 行):

1
2
3
4
5
tensor_dict = {
"next_token_ids": result.next_token_ids, # 采样出的 token id
}
if batch.return_logprob:
tensor_dict.update(logprob_dict) # logprob 相关张量

这是 last rank 采样后的最终结果,不是中间的 hidden states。

在 event_loop_pp 中的调用位置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
for mb_id in range(self.pp_loop_size):
...
# ① 处理【上一个 microbatch】的 output 收发 + 预处理
if self.server_args.pp_async_batch_depth > 0:
next_pp_outputs, next_batch_result, d2h_event = \
self._pp_commit_send_output_work_and_preprocess_output_tensors(
next_first_rank_mb_id, next_mb_id,
)

# ② 执行【当前 microbatch】的 forward
result, self.launch_event = self._pp_launch_batch(mb_id, ...)

# ③ async_depth==0 时,forward 后再处理 output
if self.server_args.pp_async_batch_depth == 0:
next_pp_outputs, next_batch_result, d2h_event = \
self._pp_commit_send_output_work_and_preprocess_output_tensors(...)

# ④ 等 D2H 完成,处理结果
if self.mbs[next_mb_id] is not None:
d2h_event.synchronize()
self._pp_process_batch_result(self.mbs[next_mb_id], next_batch_result)

pp_async_batch_depth 的影响

pp_async_batch_depth output 处理时机 效果
> 0 forward 之前处理上一个 mb 的 output output 处理和当前 forward 重叠,隐藏延迟
== 0 forward 之后再处理 串行执行,不重叠

一句话总结_pp_commit_send_output_work_and_preprocess_output_tensors 负责把 last rank 的采样结果(next_token_ids、logprobs)从 GPU 传回 rank 0,并在 rank 0 上做 GPU→CPU 拷贝和结果解包。这是 pipeline 的”最后一公里”——把最终输出送回给 tokenizer/rpc 层。

next_batch_result 是什么

next_batch_resultnext_mb_id 对应的那个 microbatch 的最终输出结果(GenerationBatchResult),包含 next_token_ids 等。这里的逻辑是:当前迭代处理的是 mb_id,但同时要处理上一轮 next_mb_id 的结果(因为流水线的延迟,next_mb_id 的 batch 在之前的迭代中已经完成了 GPU forward,output 在这一轮才到达)。

Output 与 Proxy 的发送完全分开

_pp_commit_send_output_work_and_preprocess_output_tensors 只发 output,不发 proxy

  • outputnext_token_ids):通过 _pp_send_output_to_next_stage 发送(msg_type="output"),方向是 last rank → rank 0(环形)
  • proxyhidden_states + residual):在 event_loop_pp 循环末尾单独发送(msg_type="proxy"),方向是 rank k → rank k+1(正向)

两者完全独立,不会混在一起。

_pp_commit_comm_work:延迟等待异步通信完成

1
2
3
4
def _pp_commit_comm_work(self: Scheduler, work: List[P2PWork]) -> None:
for p2p_work in work:
p2p_work.work.wait()
work.clear()

作用:等待上一轮的异步通信完成。

因为 _pp_send_pyobj_to_next_stage_pp_send_dict_to_next_stage 都用了 async_send=True,返回的是一个 P2PWork 列表(底层是 dist.isend 的 future)。_pp_commit_comm_work 就是在下一轮迭代中调用 .wait() 来确保上一轮的发送已经完成,然后清空 work 列表。

这是一种延迟等待的模式:发送时不阻塞,到下一轮再等待完成,从而让发送和 GPU 计算重叠。

_pp_commit_comm_work 的调用位置

在 event_loop_pp 中,_pp_commit_comm_work 被调用于两个地方:

1
2
3
4
5
6
7
8
9
10
for mb_id in range(self.pp_loop_size):
...
# ① 等待上一轮的 output 发送完成
self._pp_commit_comm_work(work=self.send_output_work)

# ② 执行 forward
result, self.launch_event = self._pp_launch_batch(mb_id, ...)

# ③ 等待上一轮的 proxy 发送完成
self._pp_commit_comm_work(work=self.send_proxy_work)
  • self._pp_commit_comm_work(work=self.send_output_work) — 等待上一轮 output tensor 的 async send 完成
  • self._pp_commit_comm_work(work=self.send_proxy_work) — 等待上一轮 proxy tensor 的 async send 完成

两者都是延迟等待,让上一轮的通信和当前轮的 GPU 计算重叠。

数据流全图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Last Rank          中间 Rank              Rank 0
───────── ────────── ──────
forward 完成

├─ next_token_ids 入 comm_queue

└─ send "output" ──────────→ recv & 转发 "output" ──────────→ recv "output"

copy_stream:
prep_batch_result (D2H)

d2h_event.sync()

process_batch_result()
(更新请求状态、流式输出)

PP 路径中的 HiCache(分层 KV Cache)

HiCache(Hierarchical Cache)是 SGLang 的多级 KV Cache 管理机制,把 KV cache 存在 L3(CPU 内存或磁盘),需要时再加载到 GPU。在 PP scheduler 中,HiCache 不在 event_loop_pp 本身,而是在 event loop 调用的子流程中发挥作用。

HiCache 在 PP 调度中的四个阶段

1. 请求入队时 — 预取(Prefetch)

1
2
3
4
5
6
7
8
9
# process_input_requests → handle_generate_request → _add_request_to_queue → _prefetch_kvcache
# 第 1868-1888 行
def _prefetch_kvcache(self, req: Req):
if self.enable_hicache_storage:
req.init_next_round_input(self.tree_cache, cow_mamba=False)
# 从外部存储(L3)异步预取 KV cache 到 GPU
self.tree_cache.prefetch_from_storage(
req.rid, last_host_node, new_input_tokens, last_hash, prefix_keys
)

时机:请求刚进入 waiting_queue 时,立即发起异步预取,利用等待调度的时间把 KV cache 从外部存储搬到 GPU。

2. 调度组 batch 时 — 检查预取进度 + 加载

get_next_batch_to_run_get_new_batch_prefill_raw 中有三处:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 2a. 检查异步事件完成(第 2282-2283 行)
if self.enable_hierarchical_cache:
self.tree_cache.check_hicache_events() # 推进异步 D2H/H2D 事件

# 2b. 逐请求检查预取是否完成(第 2390-2398 行)
for req in self.waiting_queue:
...
if self.enable_hicache_storage:
prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
if not prefetch_done:
continue # 预取未完成,跳过这个请求,不放进本轮 batch
# 预取完成,拿到从 L3 加载了多少 token 的 KV
req.storage_hit_length = self.tree_cache.pop_prefetch_loaded_tokens(req.rid)

# 2c. batch 创建后,准备从 host 加载到 GPU(第 2459-2463 行)
if self.enable_hierarchical_cache:
new_batch.hicache_consumer_index = (
self.tree_cache.ready_to_load_host_cache() # 标记哪些 KV 需要从 host→GPU
)

3. 请求被 abort 时 — 释放预取资源

1
2
3
# 在多个 abort 路径中(第 1965、1997、3146 行)
if self.enable_hicache_storage:
self.tree_cache.release_aborted_request(req.rid) # 取消预取,释放资源

4. flush cache 时 — 等待异步操作完成

1
2
# 第 2877 行
# HiCache: in-flight async ops (GPU↔️Host↔️L3) must drain before flush

check_hicache_events():异步事件推进器

check_hicache_events() 是 HiCache 的核心——每轮调度前轮询一次,把已完成的 GPU↔️Host↔️L3 异步拷贝操作确认掉(解锁节点、释放资源、推进状态)。

1
2
3
4
5
6
# tree_cache.py 第 1136-1144 行
def check_hicache_events(self):
self.writing_check() # ① 检查 GPU→Host 写入完成
self.loading_check() # ② 检查 Host→GPU 加载完成
if self.enable_storage:
self.drain_storage_control_queues() # ③ 处理 L3 存储控制队列

① writing_check()(第 676-717 行)— GPU→Host 写入

KV cache 从 GPU 写到 Host 内存(备份),是异步的。这个函数检查哪些写入完成了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 遍历 ack_write_queue 里的 CUDA event
for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
if not finish_event.query(): # CUDA event 未完成,停止检查
break
finish_count += 1

# TP 多卡时取 min,确保所有 TP rank 都完成
torch.distributed.all_reduce(queue_size, op=ReduceOp.MIN, group=self.tp_group)

# 对已完成的写入:
for ack_id in ack_list:
backuped_node = self.ongoing_write_through.pop(ack_id) # 从进行中移除
self.dec_lock_ref(backuped_node) # 解锁,允许被驱逐
if self.enable_storage:
self.write_backup_storage(backuped_node) # 继续写到 L3 存储

作用:GPU 上的 KV cache 备份到 Host 完成后,解锁节点(可以被驱逐释放 GPU 显存),如果开了 L3 存储还会继续往 L3 写。

② loading_check()(第 719-732 行)— Host→GPU 加载

KV cache 从 Host 内存加载回 GPU,也是异步的:

1
2
3
4
5
6
7
8
9
for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
if not finish_event.query(): # CUDA event 未完成
break
finish_count += 1
for ack_id in ack_list:
end_node = self.ongoing_load_back.pop(ack_id) # 从进行中移除
self.dec_lock_ref(end_node) # 解锁节点

del self.cache_controller.ack_load_queue[:finish_count]

作用:之前从 Host 加载回 GPU 的 KV cache 完成后,解锁节点,这些 KV cache 就可以被 forward 使用了。

③ drain_storage_control_queues()(第 1146-1172 行)— L3 存储控制

处理三个队列:

1
2
3
4
5
6
7
8
cc = self.cache_controller
qsizes = torch.tensor([
cc.prefetch_revoke_queue.qsize(), # 取消预取
cc.ack_backup_queue.qsize(), # L3 备份确认
cc.host_mem_release_queue.qsize(), # Host 内存释放
])
# TP 同步后统一处理
torch.distributed.all_reduce(qsizes, op=ReduceOp.MIN, group=self.tp_group)
队列 作用
prefetch_revoke_queue 取消不再需要的 L3→Host 预取
ack_backup_queue 确认 Host→L3 备份写入完成
host_mem_release_queue 释放不再需要的 Host 内存

HiCache 三级缓存的数据流

1
2
3
4
5
6
7
8
check_hicache_events() 推进的异步操作
════════════════════════════════════

GPU (L1) ◄────── loading_check() ────── Host (L2) ◄──── drain (prefetch from L3)
│ ▲
└──── writing_check() ────────────────────┘

└──── drain (backup to L3) ────→ L3 Storage

HiCache 数据流

1
2
3
4
5
6
7
8
9
10
请求进入 waiting_queue

├─ _prefetch_kvcache() → 异步预取 KV cache 从 L3 到 GPU

├─ get_next_batch_to_run()
│ ├─ check_hicache_events() → 推进 D2H/H2D 事件
│ ├─ check_prefetch_progress() → 检查预取是否完成
│ └─ ready_to_load_host_cache() → 标记 host→GPU 加载

└─ 预取完成 → 请求进入 batch → forward

一句话总结:HiCache 在 PP 调度中贯穿请求的完整生命周期——入队时异步预取、调度时检查进度、abort 时释放资源、flush 时等待完成。核心思想是利用请求在 waiting_queue 中的等待时间,提前把 KV cache 从 L3 搬到 GPU,隐藏加载延迟。

Host 内存驱逐:LRU 最小堆

当 host 内存不足时,evict_host() 使用 LRU 最小堆来决定驱逐哪些节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 1. 收集候选节点
leaves = list(self.evictable_host_leaves)

# evictable_host_leaves 中的节点满足:
# - node.evicted = True(GPU 上的 value 已被驱逐,只剩 host_value)
# - node.lock_ref == 0(没有被锁定)
# - 没有 evicted 的子节点(是叶子)

# 2. 构建三元组堆
eviction_heap = [
(
self.eviction_strategy.get_priority(node), # 第一排序:LRU 优先级
self._evict_tie_breaker(node), # 第二排序:平局打破
node, # 节点本身
)
for node in leaves
]

# 3. 建堆 O(n)
heapq.heapify(eviction_heap)

# 4. 循环弹出堆顶,逐个驱逐
while num_evicted < num_tokens and len(eviction_heap):
_priority, _tb, x = heapq.heappop(eviction_heap)
num_evicted += self.cache_controller.evict_host(x.host_value)
x.parent.children.pop(key) # 从 radix tree 彻底删除

三元组比较

位置 内容 含义
第 1 个 get_priority(node) LRU 策略下就是 last_access_time,值越小 = 越久没访问 = 越优先驱逐
第 2 个 _evict_tie_breaker(node) 当 priority 相同时的打破平局:PP>1 用内容 hash(跨 rank 一致),PP=1 用 node.id
第 3 个 node 实际要驱逐的节点对象

PP 场景下的 tie-breaker 设计

  • PP=1:用 node.id 作为 tie-breaker,简单高效
  • PP>1:用内容 hash(node.hash_value)作为 tie-breaker,保证不同 rank 对相同节点计算出相同的优先级顺序,确保驱逐一致性

一句话总结:把所有可驱逐的 host 叶子节点放进 LRU 最小堆,循环弹出堆顶(最久没访问的)逐个驱逐,直到释放够 num_tokens 个 token 的 host 内存。tie-breaker 保证 PP 多 rank 场景下驱逐顺序一致。

PP + HiCache 的 Tree Cache 同步问题(PR #22878)

PP + HiCache 组合使用时存在一个经典问题:各 PP rank 的 writing_check() 独立消费 GPU→Host 的 write-ack,由于各 rank 的 batch 内容和异步进度不同,消费数量可能不一致,导致各 rank 的 radix tree 状态出现分歧(哪些节点已备份、哪些可以驱逐不一致)。

解决方案:上游 rank 的消费数作为下游的预算上限,通过 piggyback 在已有的 PP 请求转发中传递。

1. writing_check() 末尾,PP rank 0 记录本轮消费数

1
2
3
# hiradix_cache.py
if self.pp_size > 1 and self.pp_rank < self.pp_size - 1:
self._pp_last_write_ack_consumed += consumed_count

2. 通过 recv_reqs 捎带传给下游scheduler_pp_mixin.py

1
2
3
4
5
6
7
8
9
# 非 last rank 发送请求时,把 ack_count 捎带上
if self.enable_hicache_storage:
ack_count = self.tree_cache.get_pp_last_write_ack_consumed()
if ack_count > 0:
pp_send_payload = {
"recv_reqs": recv_reqs, # 原始请求
"pp_write_ack_count": ack_count, # 捎带的 ack 消费数
}
self.send_req_work = self._pp_send_pyobj_to_next_stage(pp_send_payload)

3. 下游 rank 接收并设置预算scheduler.py

1
2
3
4
5
# recv_requests() 中,PP rank > 0 解包捎带的 ack_count
if self.pp_rank > 0 and isinstance(recv_reqs, dict) and "pp_write_ack_count" in recv_reqs:
pp_write_ack_count = recv_reqs["pp_write_ack_count"]
recv_reqs = recv_reqs["recv_reqs"]
self.tree_cache.set_pp_upstream_write_ack_count(pp_write_ack_count)

4. 下游 rank 的 writing_check 受限于上游预算

1
2
3
4
# writing_check() 中,下游 rank 不能消费超过上游的数量
if self.pp_rank > 0 and self._pp_write_ack_budget_from_upstream is not None:
finish_count = min(finish_count, self._pp_write_ack_budget_from_upstream)
self._pp_write_ack_budget_from_upstream -= finish_count

数据流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
PP Rank 0                      PP Rank 1
───────── ─────────
writing_check()
consumed 3 acks ──→ _pp_last_write_ack_consumed = 3

├─ send_pyobj({recv_reqs: [...], pp_write_ack_count: 3})

└──────────→ recv_requests() 解包
_pp_write_ack_budget = 3

writing_check()
本地完成了 5 个
但 min(5, 3) = 3 ← 只消费 3 个
保持和 Rank 0 一致

设计亮点:不新增通信通道,而是 piggyback 在已有的 _pp_send_pyobj_to_next_stage 上,仅 +61 行代码解决问题。

一句话总结:把上游 PP rank 的 write-ack 消费数量捎带在 recv_reqs 的 pyobj 通信上传给下游,下游以此为上限限制自己的消费,确保所有 PP rank 的 radix tree 状态变更保持一致,防止缓存驱逐决策不同步。

PP + HiCache 一致性修复进度(Issue #22607):

任务 PR 状态
CP 同步 PR #20460
逻辑时钟替代 time.monotonic() PR #22759(解决出错 3)
PD 模块 PR #22771
Mooncake CP 写控制 + PP storage key 隔离 待做
Channel A: Host tree event replay 待做
Channel B: Write-back count sync PR #22878 ✅
Channel C: L3 hit delegation 待做
Zero-hit deferred revoke 待做
PP prefill 诊断工具 待做

Channel A:Host Tree Event Replay

Channel B(write-back count sync)只限制 PP1 不超前于 PP0,但不能保证两个 rank 在每一轮都消费相同数量。如果 PP1 的 CUDA event 本地没完成,min(local_ready=0, budget=3) = 0,PP1 这轮仍然消费 0。这意味着 host tree 的状态变更时机在两个 rank 之间可以有偏差,一旦偏差积累,就会导致 host tree 结构不同。

Channel A 的思路:不指望两个 rank 独立做出相同决策,而是让 PP0 做决策,PP1 回放 PP0 的结果。

具体场景:PP2 下的 host tree 分歧

假设 PP2 配置:PP0 管 layer 0-29,PP1 管 layer 30-59,两个 rank 各自独立维护一棵 radix tree。

第 1 步:正常运行,两个 rank 一致

请求 A、B、C 进来,两个 rank 收到相同请求、做相同调度。节点 X(系统提示词的 KV cache)在两个 rank 的 GPU 上都存在:

1
2
3
PP0 radix tree: root → X(value=GPU, host_value=None, hit_count=5)
PP1 radix tree: root → X(value=GPU, host_value=None, hit_count=5)
✅ 一致

第 2 步:hit_count 达到阈值,触发 write_backup

1
2
3
PP0: write_backup(X) → CUDA 异步拷贝启动 → X 加入 ongoing_write_through
PP1: write_backup(X) → CUDA 异步拷贝启动 → X 加入 ongoing_write_through
✅ 一致

第 3 步:writing_check 完成时间分歧(关键!)

1
2
3
4
5
6
7
PP0: CUDA event 已完成 → finish_count=1 → X.lock_ref=0
X.backuped=True, X.host_value 有值
X 进入 evictable_leaves(可被驱逐)

PP1: CUDA event 未完成 → finish_count=0 → X.lock_ref=1 仍然锁定
X 不在 evictable_leaves 中
⚠️ 开始分歧

Channel B 此时 PP0 把 consumed_count=1 传给 PP1,PP1 的 budget 变成 1。但 PP1 本地 finish_count=0min(0, 1)=0,这轮 PP1 仍然消费 0。Channel B 只保证 PP1 不会超过 PP0,不能强迫 PP1 追上 PP0。

第 4 步:GPU 内存压力 → PP0 驱逐 X

1
2
3
4
5
6
7
8
PP0: X 在 evictable_leaves 中(lock_ref=0, backuped=True)
→ _evict_backuped(X) → X.value=None → X.evicted=True
X 现在只存在于 Host 内存

PP1: X 不在 evictable_leaves(lock_ref=1, 还在 ongoing_write_through)
→ 驱逐了另一个节点 Y
X 仍在 GPU 上
⚠️ 分歧加大

第 5 步:PP0 的 Host 内存也满了 → evict_host 删除 X

1
2
3
4
5
6
7
8
PP0: write_backup(Z) → host 内存不足 → evict_host()
→ X 是 evictable_host_leaves 中优先级最低的
→ parent.children.pop(X) → X 从 host tree 彻底删除!

PP1: host 内存没有压力(X 没被驱逐到 host)
→ 不触发 evict_host
→ X 仍在 radix tree 中
❌ 两个 rank 的 tree 结构不同了

第 6 步:新请求匹配前缀 → crash

1
2
3
4
5
6
7
PP0: _match_prefix_helper → 找不到 X → device_indices=[], host_hit_length=0
PP1: _match_prefix_helper → 匹配到 X → device_indices=X.value, host_hit_length=0

PP0: prefix_indices 长度 = 0 → extend_input_len = 完整长度
PP1: prefix_indices 长度 = len(X) → extend_input_len = 更短

两个 rank 对同一个请求算出不同的 extend_input_len → batch 形状不一致 → crash

Channel A 如何解决

1
2
3
4
5
6
7
8
9
10
11
12
13
PP0 (权威源)              PP1 (回放)
─────────────── ─────────
writing_check: X 完成
emit FINALIZE(hash=X_hash) ─────→ replay FINALIZE: 标记 X 为 backuped
(不管本地 CUDA event 是否完成)

evict: 驱逐 X
emit EVICT(hash=X_hash) ─────→ replay EVICT: 也驱逐 X

evict_host: 删除 X
emit REVOKE(hash=X_hash) ─────→ replay REVOKE: 也从 host tree 删除 X

PP1 的 host tree 始终和 PP0 一致 ✅

三个 Channel 的分工

1
2
3
4
5
6
7
8
9
10
11
12
13
14
问题链条:
writing_check 时机不同 → lock_ref 不同 → evictable 集合不同
→ 驱逐不同 → host tree 结构不同 → match_prefix 不同 → crash

Channel B (count sync): ────────┐
限制 PP1 不超前于 PP0 │ 减小分歧窗口,
减少 lock_ref 偏差 │ 但不能消除

逻辑时钟 + tie-breaker: ────────┤ 保证驱逐顺序在"相同输入"下一致
LRU 顺序确定性 │ 但前提是输入(evictable 集合)相同

Channel A (event replay): ────────┘ 根本解决:PP1 不独立决策
PP0 的 host tree 变更直接回放到 PP1
host tree 结构保证一致

一句话总结:Channel B + 逻辑时钟是”努力让两个 rank 独立做出相同决策”,但由于异步操作的固有非确定性,无法 100% 保证。Channel A 是”放弃独立决策,直接同步结果”,从根本上消除分歧。

PP 路径的调度细节

get_next_batch_to_run

这个调度器在 PP 路径下要考虑的东西比 normal 多得多:

  1. 当前 stage 的显存 — 不能 OOM
  2. Microbatch 的依赖关系 — 上下游 stage 的进度
  3. Pipeline 的吞吐 — 尽量让每个 stage 都有活干

Async Send/Recv 的 Overlap

第 2 步 send reqs 和 recv proxy tensors 是 async 的,理想情况下计算和通信能 overlap。但如果 recv 阻塞了,整个 loop 就卡住了。

1
2
Stage 0: [compute mb0] → [send hidden to S1] → [recv hidden from S1] → [compute mb1]
Stage 1: [recv hidden from S0] → [compute mb0] → [send hidden to S0] → [recv hidden from S0]

通信和计算的重叠程度决定了 pipeline 的效率。

PP 路径 vs 非 PP 路径对比

非 PP 路径(event_loop_normal / event_loop_overlap)

1
2
时间 →
GPU: [ batch 0 forward ] → [ batch 1 forward ] → [ batch 2 forward ]

调度一个 batch,forward 一个 batch,处理结果,再调度下一个。同一时刻只有一个 batch。

PP 路径(event_loop_pp)

1
2
3
时间 →
Stage 0: [ mb 0 ][ mb 1 ][ mb 2 ][ mb 3 ]
Stage 1: [ mb 0 ][ mb 1 ][ mb 2 ][ mb 3 ]

多个 microbatch 同时在不同 stage 中流转,主循环轮转处理每个 microbatch。

总结

SGLang 的 PP 实现有几个关键设计:

  1. 独立事件循环 — PP 路径和 normal/overlap 路径完全隔离,避免复杂度交叉
  2. Microbatch 填充 bubble — 通过 pp_loop_size = pp_size + async_depth 控制 pipeline 中的并发度
  3. Async 通信 — send/recv 异步化,尽量 overlap 计算和通信
  4. 独立状态管理 — 每个 microbatch 槽位有独立的 running_batch、last_batch 等状态

PP 的核心挑战始终是 如何减少 bubble、提升 GPU 利用率,同时控制显存开销。Microbatch 数量的调优是 PP 性能调优的关键。

参考

FlashMLA Sparse Decode 完整计算过程详解

本文以一个最小化的具体数值例子,逐步展示 FlashMLA Sparse Decode 的每一个计算步骤,并标注每一步与 Hopper (SM90) / Blackwell (SM100/GB200) 硬件特性的关系。


0. 问题设定

参考配置

本文的模型参数来自 DeepSeek-V3.2 的官方 Hugging Face 配置:

以下是从 config.json 中提取的注意力相关参数,以及它们在 FlashMLA 内核中的对应关系:

config.json 参数 FlashMLA 内部参数 说明
kv_lora_rank 512 d_nope = 512 KV 的 LoRA 压缩维度,即 NoPE (Non-Positional Encoding) 部分的维度。在 MLA 中,KV cache 存储的是 LoRA 压缩后的向量,而非原始的 K/V
qk_rope_head_dim 64 d_rope = 64 RoPE (Rotary Position Embedding) 部分的维度
合计 512 + 64 = 576 d_qk = 576 Q/K 的总 head dimension = kv_lora_rank + qk_rope_head_dim
kv_lora_rank 512 d_v = 512 V 的 head dimension,等于 LoRA rank(在 MLA 中 V 和 K 的 NoPE 部分共享同一个压缩向量)
num_attention_heads 128 h_q = 128 Query head 数量
num_key_value_heads 128 config 中 KV head 数也是 128,但在 MLA 架构中,KV cache 只存储 1 份压缩向量(所有 head 共享),所以 FlashMLA 中 h_kv = 1 (MQA 模式)
qk_nope_head_dim 128 这是模型层面每个 head 的 NoPE 维度(128 × 128 heads = 16384 → 再经 LoRA 压缩为 512)。FlashMLA 操作的是压缩后的 512 维向量,不直接使用此参数
v_head_dim 128 同上,这是模型层面每个 head 的 V 维度。FlashMLA 操作的是压缩后的 512 维向量
quantization_config.fmt “e4m3” FP8_E4M3 KV cache 的 NoPE 部分使用 FP8 E4M3 格式量化
quantization_config.scale_fmt “ue8m0” UE8M0 (V3.2) / FP8_E8M0FNU (MODEL1) 量化缩放因子的格式,纯 2 的幂
index_topk 2048 topk 每个 query 关注的 top-k KV token 数量(稀疏注意力)
index_n_heads 64 参与稀疏索引选择的 head 数量
index_head_dim 128 用于索引选择的 head 维度

MLA (Multi-head Latent Attention) 的核心思想:不存储 128 个 head × 128 维的完整 KV cache (= 16384 维),而是存储一个 512 维的 LoRA 压缩向量 + 64 维的 RoPE 向量 (= 576 维)。这将 KV cache 压缩了 28 倍 (16384 → 576)。FlashMLA 的所有计算都在这个压缩空间中进行。

本文使用的示例参数

我们使用与 config.json 一致的模型配置,但将 batch size 和 topk 缩小以便手算:

模型配置 (来自 config.json):

1
2
3
4
5
6
d_qk = 576 = kv_lora_rank(512) + qk_rope_head_dim(64)
d_v = 512 = kv_lora_rank
d_nope = 512 = kv_lora_rank (NoPE 部分,做 FP8 量化)
d_rope = 64 = qk_rope_head_dim (RoPE 部分,保持 BF16 不量化)
h_q = 128 = num_attention_heads
h_kv = 1 (MLA 压缩后:所有 head 共享 1 份 KV cache)

示例 Batch 配置 (缩小规模以便展示):

1
2
3
4
b = 2 (batch size)
s_q = 1 (每个请求只 decode 1 个 token)
topk = 128 (原始 config 中 index_topk=2048, 此处缩小便于计算)
page_block_size = 64

硬件配置 (H800 / SM90):

1
num_sms = 132 (SM 数量)

注意: 实际部署中 topk=2048,h_q=128,这意味着每个 query token 需要从 KV cache 中 gather 2048 个 token 并做注意力计算。本文为了展示清晰,将 topk 缩小为 128。

为了简化展示,我们只跟踪 2 个 token(token 0 和 token 1)在一个 head 上的完整计算,其余的计算过程完全相同。


第 1 步:Python 层入口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
sched_meta, _ = flash_mla.get_mla_metadata() # 创建空的调度元数据

out, lse = flash_mla.flash_mla_with_kvcache(
q,
# shape: [batch_size, seq_len_q, num_heads_q, head_dim]
# = [2, 1, 128, 576] dtype=BF16
# - batch_size=2: 同时处理 2 个请求
# - seq_len_q=1: decode 阶段每次只生成 1 个新 token
# - num_heads_q=128: 来自 config.json 的 num_attention_heads=128
# - head_dim=576: = kv_lora_rank(512) + qk_rope_head_dim(64)
# 即每个 head 的 Q 向量是 576 维 (512 维 NoPE + 64 维 RoPE)

k_cache,
# shape: [num_blocks, page_block_size, num_heads_k, bytes_per_token]
# = [num_blocks, 64, 1, 656] dtype=FP8
# - num_blocks: KV cache 的总页数 (由推理框架的 paged attention 管理)
# - page_block_size=64: 每个页面存放 64 个 token 的 KV 向量
# - num_heads_k=1: MLA 的核心 — 所有 128 个 head 共享 1 份压缩后的 KV cache
# - bytes_per_token=656: 每个 token 的 FP8 KV cache 大小:
# 512 字节 (NoPE, FP8_E4M3) + 16 字节 (4×float32 缩放因子) + 128 字节 (RoPE, BF16)

block_table=None, # 稀疏模式不需要 (dense 模式才需要页表做地址翻译)
cache_seqlens=None, # 稀疏模式不需要 (dense 模式才需要知道每个序列的实际长度)
head_dim_v=512, # V 的 head 维度 = kv_lora_rank = 512

tile_scheduler_metadata=sched_meta,
is_fp8_kvcache=True,

indices=indices,
# shape: [batch_size, seq_len_q, topk]
# = [2, 1, 128] dtype=int32
# - batch_size=2: 对应 2 个请求
# - seq_len_q=1: 每个新 token 各有自己的 top-k 索引列表
# - topk=128: 每个 query token 关注的 KV token 数量 (实际部署中为 index_topk=2048)
# 值的含义:indices[i][j][k] = global_token_index
# 例:indices[0][0][3] = 47 表示 batch 0 的 query 的第 4 个关注 token
# 位于第 0 个 page block (47÷64=0) 的第 47 个位置 (47%64=47)
)

# 返回值:
# out: [batch_size, seq_len_q, num_heads_q, head_dim_v] = [2, 1, 128, 512] BF16
# lse: [batch_size, num_heads_q, seq_len_q] = [2, 128, 1] FP32

关于 dense vs sparse 的选择: FlashMLA 本身不决定使用哪条路径。当调用方(vLLM/SGLang 等推理框架)传入 indices 参数时走 sparse 路径;不传 indices 而传入 block_table + cache_seqlens 时走 dense 路径。当上下文长度 ≤ index_topk(2048) 时,top-k 等于全部 token 数,稀疏无意义,框架层面应直接调用 dense 路径。

关键路由逻辑 (flash_mla_interface.py 第 151 行):

1
2
3
if topk is not None:
# topk 不为 None → 走稀疏路径
out, lse, ... = flash_mla_cuda.sparse_decode_fwd(...)

无硬件特性依赖,纯 Python 路由


第 2 步:C++ 接口层 — 输入验证与参数准备

文件: csrc/api/sparse_decode.h

2.1 架构检测与实现选择

1
2
3
4
5
6
7
8
9
10
11
Arch arch = Arch();
// arch.major=9, arch.minor=0 → SM90a (Hopper)
// 或 arch.major=10 → SM100f (Blackwell)

if (arch.is_sm100f()) {
if (h_q == 64) impl = new Decode_Sm100_Head64_Impl();
// SM100 使用 TMEM + UTCMMA
} else if (arch.is_sm90a()) {
impl = new Decode_Sm90_Impl();
// SM90 使用 TMA + WGMMA
}

🔧 硬件特性: cudaGetDeviceProperties 获取 compute capability。SM90 = Hopper, SM100 = Blackwell。

2.2 FP8 KV Cache 形状验证

1
2
3
4
5
// V3.2: 每个 token 占 656 字节
// 656 = 512 (NoPE FP8) + 16 (scales FP32) + 128 (RoPE BF16)
int bytes_per_token = 512 + 4*sizeof(float) + 64*sizeof(nv_bfloat16);
KU_CHECK_SHAPE(kv, num_blocks, page_block_size, h_kv, bytes_per_token);
// 即 kv shape = [num_blocks, 64, 1, 656]

656 字节的内存布局:

1
2
3
4
5
6
┌─────────────────────────────────────────────────────────┐
│ 偏移 0-511: 512 × FP8_E4M3 NoPE 部分 (量化) │
│ 偏移 512-527: 4 × float32 缩放因子 (每 128 个一组) │
│ 偏移 528-655: 64 × BF16 RoPE 部分 (不量化) │
│ 总计 656 字节 │
└─────────────────────────────────────────────────────────┘

🔧 硬件特性: FP8 (E4M3) 是 Hopper/Blackwell 原生支持的数据类型。Hopper 的 Tensor Core 可以直接做 FP8 矩阵乘法,但 FlashMLA 选择先反量化为 BF16 再做 WGMMA,以获得更高精度。

2.3 填充参数结构体

这一步是”翻译层”— 把 PyTorch tensor 的元信息(shape、stride、data_ptr)和标量参数,打包成 CUDA kernel 能直接使用的 C struct。kernel 内部只看这个 struct,不再接触 PyTorch 对象。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
SparseAttnDecodeParams params = {
// ===== 基础维度 (直接从输入 tensor 的 shape 读取) =====
.b = 2, // q.size(0) = batch_size,同时处理 2 个请求
.s_q = 1, // q.size(1) = seq_len_q,decode 阶段每次只生成 1 个新 token
.h_q = 128, // q.size(2) = num_attention_heads,来自 config.json
.h_kv = 1, // kv.size(2),MLA 压缩后 128 个 head 共享 1 份 KV cache
// 注意:计算时会被 broadcast 到 h_q=128
.d_qk = 576, // q.size(3) = head_dim = kv_lora_rank(512) + qk_rope_head_dim(64)
.d_v = 512, // head_dim_v 参数 = kv_lora_rank

// ===== Softmax 缩放因子 =====
// 标准 Scaled Dot-Product Attention: softmax(Q·K^T / sqrt(d)) · V
.sm_scale = 1.0 / sqrt(576) = 0.04167,
// 即公式中的 1/sqrt(d),防止点积值过大导致 softmax 饱和

.sm_scale_div_log2 = 0.04167 * log2(e) = 0.04167 * 1.4427 = 0.06010,
// 性能优化的预计算。kernel 中用 exp2f() 代替 expf() 计算 softmax:
// expf(x * sm_scale) = exp2f(x * sm_scale * log2(e))
// = exp2f(x * sm_scale_div_log2)
// GPU 上 exp2f() 比 expf() 快约 2 倍 (SFU 原生支持 base-2 指数运算),
// 所以预乘好 log2(e),避免 kernel 里每个元素都重复乘一次

// ===== 稀疏注意力参数 =====
.topk = 128, // indices.size(2),每个 query 关注的 KV token 数量
.model_type = ModelType::V32,
// 由 d_qk 决定:576 → V32 (DeepSeek-V3/V3.2), 512 → MODEL1
// 不同 model_type 的 FP8 KV cache 布局和缩放因子格式不同

// ===== 指针 (从 tensor.data_ptr() 获取) =====
.q = (bf16*)q.data_ptr(), // Q 矩阵的 GPU 内存地址
.kv = (bf16*)kv.data_ptr(), // FP8 KV cache 的 GPU 内存地址
.indices = (int*)indices.data_ptr(), // top-k 索引数组的 GPU 内存地址
.out = (bf16*)out.data_ptr(), // 输出矩阵的 GPU 内存地址
.lse = (float*)lse.data_ptr(), // log-sum-exp 的 GPU 内存地址

// ===== Stride (从 tensor.stride() 获取,单位是元素数) =====
// Stride 描述 tensor 在内存中的布局,让 kernel 知道如何从一个元素跳到下一个元素
.stride_q_b = q.stride(0), // 跳到下一个 batch: 1×128×576 = 73728
.stride_q_s_q = q.stride(1), // 跳到下一个 query token: 128×576 = 73728
.stride_q_h_q = q.stride(2), // 跳到下一个 head: 576
// ... kv, indices, out, lse 的 stride 类似
};

无硬件特性依赖,纯数据结构准备


第 3 步:Tile Scheduler — 工作分配

文件: csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu

这一步决定了 哪个 SM 处理哪些 batch 的哪些 KV block

3.1 计算每个 batch 的处理块数

⚠️ “block” 概念澄清: 这里的 “block” 不是 KV cache page block(64 个连续 token 组成的页面),而是对 top-k indices 列表的分块。

top-k 选择是 token 级别的 — indices 数组中存储的是 2048 个(本例中 128 个)散落在 KV cache 各处的 token 的绝对位置,这些 token 之间通常不连续。

由于 kernel 无法一次性处理所有 top-k token,所以按 TOPK_BLOCK_SIZE=64 将 indices 列表切块处理:

概念 Dense Decoding 的 “block” Sparse Decoding 的 “block”
含义 KV cache page block(64 个连续 token 组成的页面) indices 列表的处理块(64 个不连续 token index 为一组)
数据局部性 好(连续访问 global memory) 差(每个 token 可能在不同 page,需要随机 gather)
计算方式 按 block_table 顺序遍历所有页面 按 indices 逐 token gather,每个 index 指向 KV cache 中的任意位置
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
Launch: <<<1, 32>>> (单 CUDA block, 32 线程,即 1 个 warp)

indices[batch 0] = [token_47, token_2091, token_583, ..., token_8821] ← 共 128 个 token index (散落各处)

将 128 个 token index 按 TOPK_BLOCK_SIZE=64 分块处理:
处理块 0: indices[0:63] → 去 KV cache 中 gather 这 64 个 token (位于不同的 page)
处理块 1: indices[64:127] → gather 另外 64 个 token

对于 batch 0:
topk = 128, TOPK_BLOCK_SIZE = 64
num_blocks[0] = ceil(128/64) = 2 ← 需要 2 个处理块来遍历全部 128 个 top-k token

对于 batch 1:
topk = 128
num_blocks[1] = ceil(128/64) = 2

fixed_overhead_num_blocks = 5 (每个 batch 的固定 CUDA CTA 开销)

total_num_blocks = (2+5) + (2+5) = 14 ← 需要调度的 CUDA CTA 总数

🔧 硬件特性: 使用 __shfl_xor_sync 做 warp 内规约求和 — 这是 CUDA warp shuffle 指令,所有 NVIDIA GPU 支持。

注意: FP8 量化将每个 token 从 1152 字节 (576×BF16) 压缩到 656 字节,减少 ~43% 的内存带宽消耗。Decode 阶段是 memory-bound(瓶颈在显存带宽而非计算),因此减少每 token 的字节数可以直接提升吞吐量。

3.2 背景:SM (Streaming Multiprocessor) 是什么

SM 是 NVIDIA GPU 的基本计算单元。可以把 GPU 理解为一个拥有很多独立计算核心的处理器,每个 SM 就是其中一个核心。

  • 每个 SM 有自己的寄存器文件(65536 个 32-bit 寄存器)、共享内存(228KB,SM90)、Tensor Core(矩阵乘法加速器)、SFU(特殊函数单元,做 exp2f 等)
  • H800 有 132 个 SM,它们并行工作
  • 一个 CUDA kernel 启动时,会生成很多 **CTA (Cooperative Thread Array)**,也叫 thread block。GPU 的硬件调度器把这些 CTA 分配到各个 SM 上执行
  • 一个 SM 可以同时执行多个 CTA(如果资源允许),但 FlashMLA 的 kernel 每个 CTA 用了接近满载的寄存器和共享内存,所以通常一个 SM 只跑 1 个 CTA

Cluster 是 Hopper 新增的概念:多个 SM 组成一个 cluster,cluster 内的 SM 可以直接读写彼此的共享内存。FlashMLA 中 h_q=128 时 CLUSTER_SIZE=2,即 2 个 SM 组成一个 cluster,协作处理 128 个 head。

3.3 计算 SM Partition 数

1
2
3
4
5
6
7
8
9
10
11
12
h_q = 128 个 head
BLOCK_M = 64 个 head / CTA
→ 处理全部 head 需要 128/64 = 2 个 CTA (组成 1 个 cluster)

s_q = 1 (只有 1 个 query token)

num_sm_parts = max(132 / s_q / (h_q/64), 1)
= max(132 / 1 / 2, 1)
= 66

含义:132 个 SM,每 2 个 SM 组成一个 cluster 处理一个"任务单元",
所以有 66 个可用的"工位" (SM partition)

3.4 贪心调度算法

调度器的目标是:把所有 batch 的工作均匀分配到 66 个 partition 上

1
2
3
4
5
6
7
8
9
第一步:计算每个 partition 的负载上限 (payload)

payload = ceil(total_num_blocks / num_sm_parts) + fixed_overhead_num_blocks
= ceil(14 / 66) + 5
= 1 + 5
= 6

每个 partition 最多处理 6 个"单位"的工作。
(+5 是因为每接手一个新 batch 有固定成本:加载 Q、初始化 softmax 状态、写回结果等)

然后,单线程从左到右逐个 partition 分配工作 (代码第 66-91 行):

状态变量:

1
2
3
now_req_idx = 0    // 当前正在分配第几个 batch
now_block = 0 // 当前 batch 已经分配了多少 block
remain_payload = 6 // 当前 partition 剩余容量
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
═══════════ Partition 0, remain_payload=6 ═══════════

看 batch 0: 共 2 个 block
能否把整个 batch 0 放进来?需要 2(block) + 5(开销) = 7 ... 7 > 6, 放不下!

那放部分:remain_payload - fixed_overhead = 6 - 5 = 1 个 block
→ 只分配 batch 0 的 block 0
→ now_block = 1 (batch 0 还剩 1 个 block)
→ 容量用完,跳出

结果:
begin_req=0, end_req=0, begin_block=0, end_block=1
is_first_req_splitted=false ← batch 0 从这里开始
is_last_req_splitted=true ← batch 0 没处理完 (被 split 了!)

═══════════ Partition 1, remain_payload=6 ═══════════

继续 batch 0: 还剩 1 个 block
能否整个放进来?需要 1 + 5 = 6 ... 6 <= 6, 可以!
→ remain_payload = 6 - 6 = 0
→ batch 0 全部完成,now_req_idx=1, now_block=0
→ 容量用完,跳出

结果:
begin_req=0, end_req=0, begin_block=1, end_block=2
is_first_req_splitted=true ← batch 0 从中间开始 (上个 partition 没处理完)
is_last_req_splitted=false ← batch 0 在这里结束

═══════════ Partition 2, remain_payload=6 ═══════════

看 batch 1: 共 2 个 block, 同理 7 > 6 放不下整个
→ 只放 1 个 block

结果:
begin_req=1, end_req=1, begin_block=0, end_block=1
is_first_req_splitted=false, is_last_req_splitted=true

═══════════ Partition 3, remain_payload=6 ═══════════

继续 batch 1: 剩 1 个 block, 1+5=6 <= 6, 放进来
→ batch 1 完成

结果:
begin_req=1, end_req=1, begin_block=1, end_block=2
is_first_req_splitted=true, is_last_req_splitted=false

═══════════ Partition 4~65 ═══════════
所有 batch 已分配完,begin_req >= batch_size → kernel 直接 return

3.5 Split 与 Combine

当一个 batch 被分到多个 partition 时,就产生了 split — 每个 partition 只计算部分 KV block 的注意力,得到局部结果 (partial O, partial LSE)。最后需要 combine kernel 把各个 split 的结果合并。

1
2
3
4
5
6
7
8
9
10
batch 0 被 split 成 2 份:
Partition 0 → block 0 的注意力结果 (partial_O_0, partial_LSE_0)
Partition 1 → block 1 的注意力结果 (partial_O_1, partial_LSE_1)
→ combine kernel 根据 LSE 加权合并:O_final = w0 × partial_O_0 + w1 × partial_O_1

batch 1 同理。

num_splits 前缀和:[0, 2, 4]
batch 0 的 splits 在 index [0, 2) → 2 个 splits
batch 1 的 splits 在 index [2, 4) → 2 个 splits

如果一个 batch 没有被 split (is_no_split=true),则 decode kernel 直接输出最终结果,不需要 combine kernel,省去一次 kernel launch。在实际部署中 (topk=2048, b=128),大部分 batch 都会被 split 到多个 partition 以充分利用 132 个 SM 的并行度。

🔧 硬件特性: 调度器本身是架构无关的 (smxx/ 目录),但 num_sm_parts 根据 GPU 的 SM 数量和 kernel 的 cluster size 计算,与具体硬件相关。


第 4 步:FP8 量化回顾 — KV Cache 是如何存储的

在 decode 之前,预填充阶段已经将 KV cache 量化为 FP8。我们用具体数字展示:

4.1 原始 BF16 KV 向量 (一个 token)

1
2
3
原始 K 向量 (d=576):
K_nope[0:511] = [0.25, -0.5, 0.125, ..., 0.75] (512 个 BF16 值)
K_rope[512:575] = [0.1, -0.2, 0.3, ..., 0.05] (64 个 BF16 值)

4.2 分组量化

NoPE 部分按 128 个一组量化(V3.2 有 4 组):

1
2
3
4
5
6
7
8
9
10
11
12
Group 0: K_nope[0:127]
max_abs = max(|K_nope[0]|, ..., |K_nope[127]|) = 0.75
scale_inv = 0.75 / 448.0 = 0.001674
scale_inv_ue8m0 = 2^(ceil(log2(0.001674))) = 2^(-9) = 0.001953
scale = 1 / 0.001953 = 512.0

量化:K_nope_fp8[i] = round_to_fp8(K_nope[i] / scale_inv_ue8m0)
例:K_nope_fp8[0] = round_to_fp8(0.25 / 0.001953) = round_to_fp8(128.0) = 128.0 (FP8)

Group 1: K_nope[128:255] → scale_1
Group 2: K_nope[256:383] → scale_2
Group 3: K_nope[384:511] → scale_3

4.3 存储布局 (656 字节)

偏移 内容 字节数
0 K_nope_fp8[0:511] 512 (512 × 1 字节 FP8_E4M3)
512 [scale_0, scale_1, scale_2, scale_3] 16 (4 × 4 字节 float32)
528 K_rope_bf16[0:63] 128 (64 × 2 字节 BF16)
合计 656 字节

🔧 硬件特性:

  • FP8 (E4M3): Hopper/Blackwell 原生数据类型,1 符号位 + 4 指数位 + 3 尾数位,范围 ±448,精度约 3 位有效数字
  • UE8M0 缩放因子: 纯 2 的幂,无尾数位,确保反量化时乘法精确无舍入误差
  • Blackwell (SM100) 额外支持: FP8_E8M0FNU 格式的缩放因子,MODEL1 使用此格式

第 5 步:CUDA Kernel 启动

5.1 SM90 (Hopper) 启动配置

文件: csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh 底部

1
2
3
4
5
6
7
8
9
// SM90 启动配置
cutlass::ClusterLaunchParams launch_params = {
dim3(2, 1, 66), // grid: (NUM_M_BLOCKS=2, s_q=1, num_sm_parts=66)
dim3(384, 1, 1), // block: 384 线程 = 3 warpgroups
dim3(2, 1, 1), // cluster size = 2 (h_q=128 → CLUSTER_SIZE=NUM_M_BLOCKS=2)
sizeof(SharedMemoryPlan), // 动态共享内存
params.stream
};
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_params);

🔧 Hopper 硬件特性:

  • Cluster Launch: launch_kernel_on_cluster 使用 Hopper 的 Thread Block Cluster 功能。h_q=128 时 CLUSTER_SIZE=2,两个 CTA 组成一个 cluster,可以直接访问对方的共享内存 (通过 XOR 寻址)
  • 384 线程: 3 个 warpgroup(每个 128 线程 = 4 个 warp),这是 Hopper WGMMA 指令的基本执行单位
  • NUM_M_BLOCKS=2: h_q=128 / BLOCK_M=64 = 2,每个 CTA 处理 64 个 head,2 个 CTA 覆盖全部 128 个 head

5.2 SM100 (Blackwell / GB200) 启动配置

文件: csrc/sm100/decode/head64/kernel.cuh 底部

1
2
3
4
5
6
7
8
// SM100 启动配置 — 注意与 SM90 的关键区别
mla_kernel<<<
dim3(params.s_q, params.num_sm_parts, 1), // grid: (s_q=1, num_sm_parts=132, 1)
dim3(384, 1, 1), // block: 384 线程 = 3 warpgroups (相同)
smem_size, // 动态共享内存
params.stream
>>>(params, tma_params);
// 注意:没有使用 cluster launch, 也没有使用 PDL

关键区别:

SM90 (Hopper) SM100 (Blackwell)
grid 维度 (2, 1, 66)
x=NUM_M_BLOCKS, z=num_sm_parts
(1, 132, 1)
x=s_q, y=num_sm_parts
CTA 数量 2 × 1 × 66 = 132 1 × 132 × 1 = 132
cluster size 2 (两个 CTA 协作) 1 (单 CTA 独立工作)
每 CTA 处理 64 个 head 64 个 head (全部 h_q=64)
num_sm_parts 132/(1×2) = 66 132/1 = 132
1
2
// SM90 的 KU_ASSERT: h_q == 64 || h_q == 128 等
// SM100 head64 的 KU_ASSERT: h_q == 64 (即 B_H == 64)

SM100 为什么 h_q == 64 而不是 128?
SM100 head64 实现 (csrc/sm100/decode/head64/) 要求 h_q == B_H == 64。当推理框架处理 DeepSeek-V3.2 (h_q=128) 时,需要通过其他方式拆分 head(如 head64x2 或 head128 实现),或者多次调用。这与 SM90 通过 cluster (2 个 CTA 各处理 64 head) 的方案不同。

5.3 SM100 的三大硬件创新及其在 Kernel 中的应用

5.3.1 TMEM (Tensor Memory) — 新增的片上存储层

SM100 在寄存器和共享内存之外,新增了一层 TMEM (Tensor Memory):

1
2
3
4
5
SM100 存储层次:
寄存器 (Register File) — 每线程私有,最快
TMEM (Tensor Memory) — 512KB, 每 SM 共享,Tensor Core 可直接读写 ← 新增!
共享内存 (Shared Memory) — 228KB, CTA 内共享
L2 Cache → HBM

在 FlashMLA 中的具体应用 (csrc/sm100/decode/head64/kernel.cuh):

1
2
3
4
5
6
7
8
9
10
11
12
// TMEM 列布局 (config.h 第 72-80 行)
struct tmem_cols {
static constexpr int O = 0; // 列 0~255: 输出 O 累加器 (FP32)
static constexpr int Q = 256; // 列 256~399: Q 矩阵 (BF16)
static constexpr int P = 400; // 列 400~463: P = Q·K^T 的结果 (FP32)
};

// kernel 初始化时分配 512 列 TMEM (kernel.cuh 第 63 行)
cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());

// kernel 结束时释放 (kernel.cuh 第 425 行)
cute::TMEM::Allocator1Sm().free(0, 512);

SM90 的对比: SM90 没有 TMEM,Q 和 O 都在共享内存或寄存器中:

  • Q 通过 TMA → 共享内存 → WGMMA 直接从共享内存读取
  • O 累加器在寄存器中
  • P 在寄存器中

SM100 的优势:

  • TMEM 带宽远高于共享内存,Q 常驻 TMEM 后每个 block 的 QK^T 都不需要重新加载 Q
  • O 累加器放在 TMEM 中,减轻寄存器压力 (Warpgroup 0 从 SM90 的 192 regs 变为 224 regs,但 regs 不用存 O 了)
  • P 的 softmax 结果直接在 TMEM 中读取,无需通过共享内存传递

5.3.2 UTCMMA (Unified Tensor Core MMA) — 替代 WGMMA

SM100 使用 UTCMMA (tcgen05.mma) 指令替代 SM90 的 WGMMA (wgmma.mma_async):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// QK^T 矩阵乘法:
// SM90: WGMMA SS 模式 — Q 和 K 都在共享内存
gemm<true, -1>(tiled_mma_QK, sQ, sK, rP);
// PTX: wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 ...

// SM100: UTCMMA TS 模式 — Q 在 TMEM, K 在共享内存
ku::utcmma_ts(tiled_mma_P, tQ_in_tmem, sK_in_smem, tP_in_tmem, true);
// PTX: tcgen05.mma.ws.cta_group::1.kind::f16 [tmem_c], [tmem_a], smem_desc_b, ...

// SV 矩阵乘法:
// SM90: WGMMA RS 模式 — S 在寄存器,V 在共享内存
gemm<false, -1>(tiled_mma_PV, rS, sV, rO);

// SM100: UTCMMA SS 模式 — S 和 V 都在共享内存,O 在 TMEM
ku::utcmma_ss(tiled_mma_O, sS, sV, tO_in_tmem, false);
// PTX: tcgen05.mma.ws.cta_group::1.kind::f16 [tmem_c], smem_desc_a, smem_desc_b, ...

区别总结:

SM90 WGMMA SM100 UTCMMA
QK^T 操作数来源 A=共享内存,B=共享内存 (SS) A=TMEM, B=共享内存 (TS)
QK^T 累加器位置 寄存器 (rP) TMEM (tP)
SV 操作数来源 A=寄存器,B=共享内存 (RS) A=共享内存,B=共享内存 (SS)
SV 累加器位置 寄存器 (rO) TMEM (tO)
发起者 1 个 warpgroup (128 线程) 1 个线程 (elect_one_sync)
异步同步 wgmma.commit/wait tcgen05.commit + mbarrier

5.3.3 TMA Gather4 — 硬件稀疏 Gather

这是对 sparse decode 最重要的 SM100 创新。SM90 通过线程协作做 __ldg 逐 token gather,SM100 直接用硬件 TMA 做 2D 稀疏 gather:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// SM90: 每个线程用 __ldg 加载 1 个 token 的一部分 (splitkv_mla.cuh 第 509-530 行)
int token_index = __ldg(gIndices + ...); // 读 index
int block_id = token_index / page_block_size;
int offset = token_index % page_block_size;
fp8x16 data = load_128b_from_gmem(k_ptr + block_id*stride + offset*row_stride + dim_offset);
// → 128 个线程各自发起 1 个 128-bit 的全局内存读取

// SM100: 1 个线程用 TMA Gather4 一次加载 4 个 token (kernel.cuh 第 596-611 行)
ku::tma_gather4(
&tma_params.tensor_map_kv_nope, // TMA 描述符 (2D tensor layout)
plan.bar_raw_ready[buf_idx], // 完成后通知这个 barrier
plan.u.kv.raw_nope[buf_idx].data(), // 目标:shared memory
0, // column index (NoPE 起始列)
cur_indices, // int4: 4 个 row index (token 位置)
TMA::CacheHintSm90::EVICT_LAST
);
// PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
// .cta_group::1.L2::cache_hint [smem], [desc, {col, row0, row1, row2, row3}], [mbar], hint;

SM100 的 gather4 工作流程:

1
2
3
64 个 top-k token 需要 gather → 每次 gather4 加载 4 个 token → 需要 16 次 TMA gather4
每次 gather4: 从 2D tensor [num_tokens × D_NOPE] 中按 row index 取 4 行
硬件自动处理地址计算、跨 page block 的非连续访问

对比:
| 架构 | Gather 方式 |
|—|—|
| SM90 | 128 线程 × 多轮 __ldg → ~128 个散射的全局内存请求 |
| SM100 | 1 线程 × 16 次 TMA gather4 → 16 个硬件管理的 DMA 请求 |

优势:

  • 线程无需参与地址计算和数据搬运,可以做其他工作
  • TMA 引擎内部有合并和调度优化,比线程发起的 __ldg 更高效
  • 数据直接落入 shared memory,无需经过寄存器

5.3.4 SM100 UTCCP — Q 从共享内存搬到 TMEM

SM100 加载 Q 的过程分两步:

1
2
3
4
5
6
7
8
9
// Step 1: TMA 把 Q 从 global memory → shared memory (和 SM90 相同)
ku::launch_tma_copy(tma_params.tma_Q_SW128, gQ, sQ, plan.bar_q_tma, ...);

// Step 2: UTCCP 把 Q 从 shared memory → TMEM (SM100 独有)
SM100_UTCCP_128dp256bit_1cta::copy(sQ_desc, tmem_cols::Q + offset);
// PTX: tcgen05.cp ... (shared → TMEM 的异步拷贝)

// 之后 Q 常驻 TMEM,每个 KV block 的 QK^T 直接从 TMEM 读取 Q
// 不像 SM90 每次都从共享内存读

5.4 SM100 的三 Warpgroup 分工

5.4.1 什么是 Warpgroup?

Warpgroup 是 Hopper (SM90) 引入的线程组织层次,介于 warpCTA (thread block) 之间:

1
2
3
4
5
6
CTA (Thread Block) = 384 线程 (FlashMLA 的配置)
├── Warpgroup 0 = 128 线程 = 4 个 warp
├── Warpgroup 1 = 128 线程 = 4 个 warp
└── Warpgroup 2 = 128 线程 = 4 个 warp

其中每个 warp = 32 线程 (NVIDIA GPU 的最小调度单位)

为什么需要 Warpgroup?因为 WGMMA 指令要求它。

在 Hopper 之前,Tensor Core 的 MMA 指令由单个 warp(32 线程)发起。Hopper 引入了 **WGMMA (Warpgroup MMA)**,需要 128 线程(4 个 warp)协作发起一条矩阵乘法指令,矩阵规模更大(如 64×64×16),吞吐更高。所以 warpgroup = 4 个 warp = 128 线程,是 WGMMA 的基本执行单位。

5.4.2 SM90 vs SM100 的 Warpgroup 分工对比

SM100 的三个 Warpgroup 分工与 SM90 完全不同:

1
2
3
4
5
6
SM90 (Hopper):
┌──────────────────────────────────────────────────────────────────┐
│ WG0 (128 线程,192 regs) — Consumer A: QK^T + softmax + S·V_left │
│ WG1 (128 线程,160 regs) — Consumer B: S·V_right │
│ WG2 (128 线程,152 regs) — Producer: FP8 gather + 反量化 │
└──────────────────────────────────────────────────────────────────┘

在 SM90 中,warpgroup 的概念很”实”——WGMMA 确实需要 128 线程一起发起。WG0 的 128 线程共同发起 WGMMA (QK^T),然后做 softmax,再共同发起 WGMMA (SV)。WG2 的 128 线程各自 __ldg 加载 FP8 数据并反量化。每个线程都在忙。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
SM100 (Blackwell):
┌──────────────────────────────────────────────────────────────────┐
│ WG0 (128 线程,224 regs) — Softmax: 从 TMEM 读 P → exp2 → 写 S │
│ (纯标量计算,不参与任何 MMA) │
│ │
│ WG1 (128 线程,72 regs) — MMA + Produce: │
│ Warp 4 (1 线程): 发起所有 UTCMMA (QK^T, SV) + TMA 加载 Q │
│ Warp 5 (1 线程): TMA Gather4 加载 FP8 NoPE 到 shared memory │
│ Warp 6 (1 线程): TMA Gather4 加载 BF16 RoPE 到 shared memory │
│ Warp 7 (32 线程): 读 indices → 计算 TMA 坐标 + 加载 scale │
│ │
│ WG2 (128 线程,208 regs) — Dequant: FP8→BF16 反量化 │
│ 从 raw_nope (FP8) → dequant.nope (BF16)│
└──────────────────────────────────────────────────────────────────┘

在 SM100 中,warpgroup 的概念变”虚”了——UTCMMA 只需 1 个线程发起,所以同一个 warpgroup 内的线程可以干完全不同的事。SM100 的 WG1 中,大量线程其实是空闲的(Warp 4/5/6 各只需 1 个线程),但仍然被组织为一个 warpgroup,主要是为了共享寄存器配额和 barrier 同步的方便。

这体现了 Blackwell 的设计哲学:硬件加速器(TMA、UTCMMA)接管数据搬运和矩阵计算,线程只处理无法硬件加速的”缝隙”工作。

5.4.3 关键设计差异

  1. SM90 的 MMA 由 warpgroup (128 线程) 发起; SM100 的 UTCMMA 由 1 个线程发起

    • UTCMMA 是 “widthless” 指令 (tcgen05.mma.ws),只需 1 个线程 issue,Tensor Core 自动完成矩阵计算并写入 TMEM
    • 这解放了大量线程去做其他工作
  2. Softmax 独立成 WG0 (128 线程)

    • SM90: softmax 由 Consumer A (WG0) 在 WGMMA 之间插入
    • SM100: softmax 是 WG0 的唯一任务 — 从 TMEM 读 P、做 exp2、写 S 到 shared memory
    • 好处:softmax 是标量密集计算,给它 224 个寄存器足够存所有中间状态
  3. TMA Gather4 替代线程协作 __ldg

    • SM90: WG2 的 128 个线程各自 __ldg 加载 FP8 数据
    • SM100: WG1 中仅 Warp 5 的 1 个线程发起所有 TMA Gather4,硬件 DMA 完成实际搬运
  4. 四重缓冲索引 (NUM_INDEX_BUFS=4)

    • SM90 的 indices 无显式缓冲
    • SM100 用 4 个 buffer 轮转:indices → TMA 坐标 → scale → valid mask
    • 因为 SM100 的 TMA Gather4 需要预计算 TMA 坐标 (tma_coord),这个过程是异步的

🔧 SM100 的 NUM_INDEX_BUFS=4 vs SM90 的 NUM_K_BUFS=2:
SM100 需要更多 buffer 因为 TMA Gather4 的 pipeline 更深:

  • Warp 7 计算 TMA 坐标 (写 buf N)
  • Warp 5/6 用 TMA Gather4 加载 FP8/RoPE (读 buf N, 写 KV buf M)
  • WG2 反量化 (读 KV buf M)
  • WG1-Warp4 发起 UTCMMA (读反量化后的 KV)
  • WG0 做 softmax (读 TMEM 中的 P)
    每一步都需要独立的 buffer 来实现流水线

第 6 步:TMA 加载 Q 矩阵

Producer (Warpgroup 2 中线程 0) 发起 TMA 拷贝:

1
2
3
// 加载 Q: [64, 576] 的一个 tile 从 global memory → shared memory
launch_tma_copy(tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(64 * 576 * 2); // 期望 64×576×2 = 73728 字节

具体数据:

1
2
3
4
Q tile shape: [64 heads, 576 dims] × BF16 = 73,728 字节
(每个 CTA 加载 64 个 head, 另外 64 个 head 由 cluster 中的 peer CTA 处理)
从 global memory 地址 q_ptr + batch_0 * stride 开始
拷贝到 shared memory sQ (SwizzledLayout_SW128)

🔧 Hopper (SM90) TMA 特性:

  • TMA (Tensor Memory Accelerator): 硬件异步拷贝引擎,无需线程参与
  • 只需 1 个线程 发起,硬件自动完成整个 73KB 拷贝
  • EVICT_FIRST: L2 cache 提示,Q 只用一次所以优先驱逐
  • ClusterTransactionBarrier: TMA 完成后自动减少 barrier 计数
  • Swizzle 布局 (SW128): 128 字节粒度的地址映射,消除 shared memory bank conflict

🔧 Blackwell (SM100) TMA 改进:

  • SM100 新增 TMA Gather4: 可以从 2D tensor 中非连续地 gather 4 行,专为稀疏访问设计
  • SM100 用 TMA 加载 Q 到 shared memory,然后用 UTCCP 将 Q 从 shared memory 搬到 TMEM (Tensor Memory)
  • TMEM 是 Blackwell 新增的 512KB 片上存储,比 shared memory 带宽更高

第 7 步:Producer Warpgroup — FP8 反量化

这是稀疏 decode 最核心的步骤之一。Warpgroup 2 (线程 256-383) 负责:

7.1 加载 top-k 索引

1
2
3
4
5
int* gIndices = params.indices + batch_0 * stride;
// indices = [47, 2091, 583, 12, ...] 共 128 个 int32

// 线程 256 负责 token index 0:
int token_index = __ldg(gIndices + 0); // = 47

7.2 计算物理地址

1
2
3
4
5
6
int block_index = 47 / 64 = 0;           // 第 0 个 page block
int rel_idx_in_block = 47 % 64 = 47; // block 内第 47 个 token

fp8* gK_base = kv_ptr // KV cache 起始地址
+ block_index * stride_kv_block // 跳到第 0 个 block
+ rel_idx_in_block * 656; // 跳到第 47 个 token (每个 656 字节)

7.3 加载缩放因子

1
2
3
4
5
6
7
8
9
10
// 从偏移 512 处加载 4 个 float32 缩放因子 (128 bits)
float scales_float[4];
*(float4*)(scales_float) = load_128b_from_gmem<float4,
L1CacheHint::EVICT_LAST, // 缩放因子可能被复用
L2PrefetchHint::B128 // 预取 128 字节
>(gK_base + 512);

// scales_float = [0.001953, 0.003906, 0.001953, 0.007812]
// 转为 BF16
bf16 scales[4] = {bf16(0.001953), bf16(0.003906), bf16(0.001953), bf16(0.007812)};

🔧 Hopper 硬件特性:

  • load_128b_from_gmem: 通过内联 PTX 汇编实现 128-bit 全局内存加载,附带精确的 L1/L2 cache 控制
  • PTX: ld.global.nc.L1::evict_last.L2::128B.v4.s32 {%0,%1,%2,%3}, [%4];
  • .nc = non-coherent (只读),.L1::evict_last = 在 L1 中最后被驱逐(优先保留)
  • 这些 cache hint 对 Hopper 和 Blackwell 都有效

7.4 FP8 → BF16 反量化(核心步骤)

每个线程处理 16 个 FP8 值:

1
2
3
4
5
// 加载 16 个 FP8 值 (128 bits)
fp8x16 cur_fp8x16 = load_128b_from_gmem<fp8x16,
L1CacheHint::EVICT_LAST,
L2PrefetchHint::B256 // 预取 256 字节 (一次性加载整个 NoPE)
>(gK_nope + dim_idx * 64);

反量化公式:

1
BF16_value = float(FP8_value) × scale

具体例子 (dim_idx=0, Group 0 的前 8 个值):

1
2
3
4
FP8 原始值:  [128.0, -256.0, 64.0, -32.0, 192.0, -16.0, 384.0, -448.0]
scale_0 = 0.001953

BF16 结果: [0.2500, -0.5000, 0.1250, -0.0625, 0.3750, -0.0312, 0.7500, -0.8750]

C++ 实现 (components/dequant.h):

1
2
3
4
5
6
7
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale_bf162) {
// 每次处理 4 个 FP8
float4 fp32x4 = (float4)(inputs.lo); // FP8 → FP32 (硬件隐式转换)
bf16x2 out_lo = __float22bfloat162_rn({fp32x4.x, fp32x4.y}) * scale_bf162; // FP32 → BF16 × scale
bf16x2 out_hi = __float22bfloat162_rn({fp32x4.z, fp32x4.w}) * scale_bf162;
// ... 共 8 个值
}

🔧 Hopper 硬件特性:

  • (float4)(fp8x4): Hopper Tensor Core 原生支持 FP8→FP32 的类型转换
  • __float22bfloat162_rn(): 硬件 FP32→BF16 转换,round-to-nearest
  • 每个线程处理 16 个值,128 线程共处理 2048 个值 → 4 组 × 512/4 = 32 个线程覆盖一个 token 的 NoPE

🔧 Blackwell (SM100) 改进:

  • SM100 新增 FP8_E8M0FNU 格式缩放因子(MODEL1 使用),用 __nv_cvt_e8m0x2_to_bf162raw 硬件指令转换
  • SM100 的反量化在 Warpgroup 2 中完成后,结果直接写入 shared memory,随后 UTCMMA 从 shared memory 读取

7.5 写入共享内存

1
2
// 写入 sK (shared memory), 使用 interleaved 布局
*(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;

对于 CLUSTER_SIZE=2 的情况(h_q=128),还需写入 peer CTA 的共享内存:

1
2
3
if constexpr (CLUSTER_SIZE == 2) {
st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
}

🔧 Hopper Cluster 特性:

  • get_peer_addr(): 通过 XOR 地址高位(16MB 偏移)直接访问 cluster 中邻居 SM 的 shared memory
  • st_async_128b: 异步写入 + mbarrier 完成通知
  • PTX: st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [dst], {data}, [mbar]
  • 这允许两个 CTA 共享反量化后的 K 矩阵,无需经过 global memory

7.6 加载 RoPE 部分(不量化)

1
2
3
4
5
6
7
8
// RoPE 部分直接从 global memory 加载 BF16,不需反量化
bf16x8 cur_bf16x8 = load_128b_from_gmem<bf16x8,
L1CacheHint::EVICT_LAST,
L2PrefetchHint::B128
>(gK_rope + dim_idx * 32);

// 直接写入 shared memory
*(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;

7.7 设置有效性标记

1
2
3
4
5
6
// 线程 0-31 检查索引有效性
if (idx_in_warpgroup < 32) {
int2 indices = __ldg((int2*)(indices_base + lane_idx*2));
plan.is_kv_valid[buf_idx][lane_idx*2] = (indices.x != -1);
plan.is_kv_valid[buf_idx][lane_idx*2+1] = (indices.y != -1);
}

7.8 通知 Consumer

1
2
3
fence_view_async_shared();                      // 确保所有共享内存写入对 async proxy 可见
plan.bar_k_local_ready[buf_idx].arrive(); // 通知 Consumer: K 数据已就绪
bar_phase_k ^= 1 << buf_idx; // 翻转 phase (双缓冲)

🔧 Hopper 硬件特性:

  • fence_view_async_shared(): 确保 shared memory 写入对 TMA/async proxy 可见的 fence 指令
  • Transaction Barrier: bar_k_local_readyClusterTransactionBarrier,128 线程共同 arrive,Consumer 的 wait 才会通过
  • 双缓冲 (NUM_K_BUFS=2): Producer 写 buf 0 的同时 Consumer 读 buf 1,实现流水线

第 8 步:QK^T 矩阵乘法

8.1 SM90 (Hopper) — WGMMA

Consumer A (Warpgroup 0) 等待 K 就绪后执行:

1
2
3
4
5
6
7
// 等待 Producer 完成反量化
plan.bar_k_local_ready[buf_idx].wait(bar_phase_k >> buf_idx & 1);

// Cluster 模式下还需等待 peer CTA
if constexpr (CLUSTER_SIZE == 2) {
plan.bar_k_remote_ready[buf_idx].wait(...);
}

WGMMA 计算 P = Q · K^T:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// WGMMA SS 模式:Q 和 K 都在 shared memory
// Q: [64, 512] BF16 (从 TMA 加载)
// K: [64, 512] BF16 (从反量化后的 dequant_nope buffer)

// Warpgroup 0 (Consumer A) 发起 WGMMA
tiled_mma_QK.tiled_mma()
.gemm<true, -1>(sQ, sK, rP);

// PTX:
// wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16
// d (rP, FP32 累加器),
// a (sQ, shared memory),
// b (sK, shared memory),
// ...

// 结果:P = Q·K^T / sqrt(d_qk)
// shape: [64 heads, 64 tokens] FP32
// 每个元素 P[h][t] = Q[h] · K[t] / sqrt(576)

具体计算 (以 head 0, token 0 为例):

1
2
3
4
5
6
7
8
Q[head_0] = [q_0, q_1, ..., q_575] (576 维:512 维 NoPE + 64 维 RoPE)
K[token_0] = [k_0, k_1, ..., k_575] (反量化后的 576 维)

dot_product = Σ(q_i × k_i) for i in 0..575
= q_0×k_0 + q_1×k_1 + ... + q_575×k_575

假设:Q[head_0] · K[token_0] = 12.5
P[head_0][token_0] = 12.5 / sqrt(576) = 12.5 / 24 = 0.5208

🔧 WGMMA 硬件特性:

  • 异步执行: wgmma.commit() 发起后,线程可以继续做其他工作,用 wgmma.wait() 同步
  • 多 stage 流水线: FlashMLA 用 2-3 个 buffer 轮转,WGMMA 计算时同时加载下一批 K/V
  • Cluster 协作: h_q=128 时,2 个 CTA 组成 cluster,每个 CTA 处理 64 个 head,通过 XOR 寻址共享数据

8.2 SM100 (Blackwell) — UTCMMA

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// SM100: UTCMMA TS 模式 — Q 在 TMEM, K 在 shared memory
// Q: 已从 shared memory 通过 UTCCP 搬到 TMEM (tmem_cols::Q)
// K: 从 TMA Gather4 加载到 shared memory

// Warpgroup 1 中 Warp 4 的 1 个线程发起 UTCMMA
cute::elect_one_sync() {
ku::utcmma_ts(
tiled_mma_P,
tQ_in_tmem, // TMEM 中的 Q
sK_in_smem, // shared memory 中的 K
tP_in_tmem, // TMEM 中的 P 累加器
true // accumulate = true
);
}

// PTX:
// tcgen05.mma.ws.cta_group::1.kind::f16
// [tmem_c], [tmem_a], smem_desc_b, ...

SM100 vs SM90 的关键区别:

SM90 WGMMA SM100 UTCMMA
Q 位置 shared memory TMEM (512KB 片上存储)
K 位置 shared memory shared memory
累加器位置 寄存器 (rP) TMEM (tP)
发起者 128 线程 (warpgroup) 1 线程 (elect_one_sync)
指令前缀 wgmma.mma_async tcgen05.mma.ws
带宽 shared memory → Tensor Core TMEM → Tensor Core (更高)

🔧 TMEM 优势:

  • TMEM 带宽 ~20TB/s,远高于 shared memory 的 ~8TB/s
  • Q 常驻 TMEM 后,每个 KV block 的 QK^T 都不需要重新加载 Q
  • 累加器放在 TMEM 中,减轻寄存器压力 (SM100 WG1 仅需 72 regs,SM90 WG0 需 192 regs)

第 9 步:Online Softmax

这是注意力计算的关键 — 在不知道全部 K 的情况下,增量式计算 softmax。

9.1 掩码无效 token

1
2
3
4
for (int i = 0; i < size(cur_rP); ++i) {
if (!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2])
cur_rP(i) = -INFINITY; // 无效索引 → -∞
}

9.2 求行最大值 (warp 内规约)

1
2
3
4
5
6
7
float cur_max = -INFINITY;
for (int i = 0; i < size(cur_rP); ++i)
cur_max = max(cur_max, cur_rP(i));

// Warp 内规约:__shfl_xor_sync 做 butterfly reduction
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));

具体数值 (head 0, 第 1 个 block 的 64 个 KV token):

1
2
3
4
P[0][0..63] = [3.52, -1.23, 2.17, 0.89, ..., -0.45]

cur_max = 3.52 (行最大值)
cur_max_scaled = 3.52 × 0.06010 (= sm_scale × log2e) = 0.2116

9.3 更新 running max 和 rescale

1
2
3
4
cur_max *= scale_softmax_log2;  // = 0.2116
float old_max = rM[row]; // 第一个 block: old_max = -1e30 (初始值)
rM[row] = max(cur_max, old_max); // new_max = 0.2116
float scale_for_old = exp2f(old_max - new_max); // ≈ 0 (因为 old_max = -1e30)

关键: 第一个 block 时 scale_for_old ≈ 0,所以之前累积的 O 被清零,这是正确的(因为之前没有累积值)。

9.4 Rescale O 并计算 exp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// Rescale 旧的 O
for (int i = 0; i < size(cur_rO); ++i)
cur_rO(i) *= scale_for_old; // 第一个 block: O *= 0 → O = 0

// 计算 exp 并求和
float cur_sum = 0;
for (int i = 0; i < size(cur_rP); ++i) {
cur_rP(i) = exp2f(cur_rP(i) * scale_softmax_log2 - new_max);
// P[0][0] = exp2(3.52 * 0.06010 - 0.2116) = exp2(0) = 1.0
// P[0][1] = exp2(-1.23 * 0.06010 - 0.2116) = exp2(-0.2855) = 0.820
cur_rS(i) = (bf16)cur_rP(i); // 转为 BF16 供 PV 乘法使用
cur_sum += cur_rP(i);
}

// 更新 running L (exp-sum)
rL[row] = rL[row] * scale_for_old + cur_sum;
// = 0 * 0 + (1.0 + 0.820 + ... ) = 42.5 (假设值)

🔧 为什么用 exp2 而不是 exp?

  • exp2f()expf()~2x,因为 Hopper 的 SFU (Special Function Unit) 原生支持 base-2 指数
  • 公式:exp(x * scale) = exp2(x * scale * log2(e)) = exp2(x * scale_softmax_log2)
  • 这是一个经典的 GPU 性能优化,Hopper 和 Blackwell 都受益

9.5 保存 scale factor 到共享内存

1
2
if (idx_in_warpgroup % 4 == 0)
*(float2*)(sScale + ...) = *(float2*)(scale_for_olds);

🔧 Hopper Named Barrier 特性:

  • 计算完 softmax 后,通过 NamedBarrier::arrive(256, sScale_and_sS_ready) 通知 Warpgroup 1
  • Named Barrier 是 Hopper 新增功能,允许 warpgroup 之间细粒度同步

第 10 步:SV 矩阵乘法 — O += S · V

10.1 Warpgroup 0 — V_left (前 256 维)

Softmax 结果 S 已在寄存器中,V 的左半部分 (前 256 维) 在 shared memory:

1
2
3
4
5
6
gemm<false, -1>(  // zero_init=false (累加到现有 rO)
tiled_mma_PV, // TiledMMA: 64×256×16 F32BF16BF16 RS
rS, // S: [64, 64] in registers
thr_mma_PV.partition_fragment_B(sV), // V_left: [256, 64] in shared memory
rO // O: [64, 256] in registers (FP32 累加器)
);

具体计算:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
O_left[64×256] += S[64×64] × V_left[64×256]

对于 head 0:
O_left[0, 0:255] += S[0, 0:63] × V[0:63, 0:255]
= 1.0 × V[47, 0:255] (token 47 的 softmax weight = 1.0)
+ 0.820 × V[2091, 0:255] (token 2091 的 weight = 0.820)
+ ...

WGMMA 将此分解为 64/16 = 4 次外积:
for k = 0 to 3:
O_left += S[:, k*16:(k+1)*16] × V_left[k*16:(k+1)*16, :]

每次外积:
输入:A[64×16] FP32 (S), B[16×256] BF16 (V)
累加:C[64×256] FP32 (O)

🔧 Hopper WGMMA RS 模式:

  • RS (Register-Shared): S 在寄存器,V 在 shared memory
  • 这比 SS 模式快,因为 S 刚从 softmax 计算出来,就在寄存器中,不需要写回 shared
  • 每次 WGMMA shape: M=64, N=256, K=16

第 10 步:Warpgroup 1 — O += S · V_right

同时,Warpgroup 1 (线程 128-255) 处理 V 的右半部分:

10.2.1 等待 S 和 scale factor

1
NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready);

10.2.2 Rescale 自己的 O

1
2
3
4
5
float cur_scales[2];
*(float2*)cur_scales = *(float2*)(sScale + ...); // 从共享内存读取 scale factor

for (int i = 0; i < size(cur_rO); ++i)
cur_rO(i) *= cur_scales[row];

10.2.3 WGMMA 计算

1
2
3
4
5
6
gemm<false, -1>(
tiled_mma_PV, // TiledMMA_PV_RemoteP: 64×256×16 SS 模式
thr_mma_PV.partition_fragment_A(sS), // S: [64, 64] from shared memory
thr_mma_PV.partition_fragment_B(sV), // V_right: [256, 64] from shared memory
rO // O_right: [64, 256] in registers
);

🔧 为什么 Warpgroup 1 用 SS 而不是 RS?

  • Warpgroup 1 没有参与 softmax 计算,S 不在它的寄存器中
  • S 由 Warpgroup 0 通过 save_rPb_to_sP 写入 shared memory (sS)
  • 所以 Warpgroup 1 必须从 shared memory 读取 S → SS 模式

10.2.4 通知 Producer 可以复用 buffer

1
plan.bar_k_avail[buf_idx].arrive(); // 告诉 Producer: 我已经用完这个 K buffer

第 11 步:循环处理第 2 个 block

回到第 8 步,Producer 开始加载第 2 个 block(token index 64-127)。

由于使用双缓冲 (NUM_K_BUFS=2),Producer 写 buf 1 的同时 Consumer 还在读 buf 0:

1
2
3
4
5
6
7
时间线:
┌─ Producer: 写 K[block 0] → buf 0 ─┐ ┌─ Producer: 写 K[block 1] → buf 1 ─┐
│ │ │ │
└───────────────────────────────────┘ └───────────────────────────────────┘
┌─ Consumer: 读 buf 0, 计算 P, S, O ─┐ ┌─ Consumer: 读 buf 1 ─┐
│ │ │ │
└───────────────────────────────────┘ └──────────────────────┘

第 2 个 block 的 Online Softmax 更新:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
第 2 个 block 的 P 值 (head 0):
P_new[0][0..63] = [1.75, 0.33, -2.10, ..., 0.91]

new_cur_max = 1.75 × 0.06010 = 0.1052
old_max = 0.2116 (来自第 1 个 block)
new_max = max(0.1052, 0.2116) = 0.2116 (没有变化!)

scale_for_old = exp2(0.2116 - 0.2116) = exp2(0) = 1.0
→ 旧的 O 不需要 rescale

exp 计算:
P_new[0][0] = exp2(1.75 × 0.06010 - 0.2116) = exp2(-0.1064) = 0.929
P_new[0][1] = exp2(0.33 × 0.06010 - 0.2116) = exp2(-0.1917) = 0.876

更新 running L:
rL = rL * 1.0 + (0.929 + 0.876 + ...) = 42.5 + 38.2 = 80.7

🔧 双缓冲是 GPU 流水线的经典模式,不特定于某个架构,但 Hopper 的 Transaction Barrier 让同步更高效。


第 12 步:跨 Warpgroup 的 L 规约

所有 block 处理完后,需要合并两个 Warpgroup 的 L (exp-sum):

1
2
3
4
5
6
7
8
9
10
11
// Warp 内规约
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);

// 写到共享内存
if (idx_in_warpgroup % 4 == 0) {
sL[row] = rL[i];
sM[row] = rM[i];
}

数值:

1
2
最终 rL[head 0] = 80.7 (所有 128 个 token 的 exp-sum)
最终 rM[head 0] = 0.2116 (全局最大值,在 log2 空间)

🔧 硬件特性: __shfl_xor_sync 是 warp shuffle,所有 CUDA GPU 支持。跨 warpgroup 通过 shared memory 通信。


第 13 步:Attention Sink 处理

如果提供了 attn_sink(DeepSeek 中用于 sink token 的预计算注意力值):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
if (params.attn_sink != nullptr) {
float attn_sink_log2 = __ldg(params.attn_sink + head_idx) * CUDART_L2E_F;
// attn_sink[head 0] = -2.5 → attn_sink_log2 = -2.5 × 1.4427 = -3.607
}

// 计算最终的 output scale (包含 attention sink)
if (args.is_no_split) {
o_scales[i] = 1.0 / (rL[i] + exp2f(rAttn_sink[i] - rM[i]));
// = 1.0 / (80.7 + exp2(-3.607 - 0.2116))
// = 1.0 / (80.7 + exp2(-3.818))
// = 1.0 / (80.7 + 0.0709)
// = 1.0 / 80.77
// = 0.01238
}

Attention Sink 的含义: 假装有一个额外的 “sink” token,其注意力分数是固定的 attn_sink 值。这让模型可以把一部分注意力”倒掉”,避免所有注意力都集中在 top-k token 上。


第 14 步:输出写回

14.1 Warpgroup 0: O_left 写入分 split/no-split 两种情况

No-split (is_no_split=true, 我们的例子):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 将 FP32 寄存器值缩放并转为 BF16
bf16x2 a01 = __float22bfloat162_rn({rO(0) * o_scale, rO(1) * o_scale});
// 例:rO(0)=15.6, o_scale=0.01238 → bf16(0.1931)

// 通过 STSM (Store Matrix) 指令写入 shared memory
SM90_U32x4_STSM_N::copy(a01, a23, a45, a67, smem_ptr);

// 所有线程同步
NamedBarrier::arrive_and_wait(256, epilogue_r2s_ready);

// 线程 0 通过 TMA 将 shared memory → global memory
if (threadIdx.x == 0) {
SM90_TMA_STORE_5D::copy(&tensor_map_o, plan.u.oBuf.data(), ...);
cute::tma_store_arrive();
}

🔧 Hopper TMA Store 特性:

  • SM90_TMA_STORE_5D: 5 维 TMA 存储,可以直接从 shared memory 写到 global memory 的任意 5D 位置
  • 只需 1 个线程发起,硬件自动完成
  • 使用 CUtensorMap 描述符,包含 swizzle 信息和地址映射

🔧 Blackwell 改进:

  • SM100 仍然使用 TMA Store,但 TMEM → shared → global 的路径更短
  • SM100 的 TMA 吞吐量更高

Split 情况 (如果一个 batch 被分到多个 SM partition):

1
2
3
// 写入 o_accum (FP32),而不是最终 output (BF16)
// 不做 BF16 转换,保持 FP32 精度
SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, 0), &gOAccum(row, 0), 512*sizeof(float));

14.2 写入 LSE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int i = threadIdx.x;
if (i < num_valid_heads) {
float cur_L = sL[i];
if (is_no_split) {
// 最终 LSE = ln(L) + M / log2(e)
gSoftmaxLse[i] = cur_L == 0.0f ? INFINITY : logf(cur_L) + sM[i] / M_LOG2E;
// = log(80.7) + 0.2116 / 1.4427
// = 4.391 + 0.1467
// = 4.538
} else {
// Split: 保持 log2 空间供 combine kernel 使用
gSoftmaxLseAccum[i] = cur_L == 0.0f ? -INFINITY : log2f(cur_L) + sM[i];
}
}

第 15 步:PDL — 提前启动 Combine Kernel

1
2
3
4
5
// 最后一个 batch 处理完毕
if (batch_idx == sched_meta.end_req_idx) {
cudaTriggerProgrammaticLaunchCompletion();
// 告诉 CUDA 运行时:这个 CTA 的 combine kernel 依赖已满足
}

🔧 Hopper PDL (Programmatic Dependent Launch) 特性:

  • 传统 CUDA: kernel B 必须等 kernel A 的所有 CTA 完成才能启动
  • PDL: kernel B 可以在 kernel A 的部分 CTA 完成后就启动
  • 对于我们的例子:SM Partition 0 (batch 0) 完成后,batch 0 的 combine 可以立即开始,不需等 batch 1
  • 通过 cudaLaunchAttributeProgrammaticStreamSerialization 启用
  • Blackwell 完全支持此特性

第 16 步:Combine Kernel — 合并 Split 结果

文件: csrc/smxx/decode/combine/combine.cu

对于我们的例子(每个 batch 只有 1 个 split,is_no_split=true),combine kernel 直接 return:

1
if (my_num_splits == 1) return; // 无需合并!

但当 batch 很大或序列很长时(例如 topk=8192, num_sm_parts=132),一个 batch 的 KV 会被切分到多个 SM partition,此时 combine kernel 的工作如下:

假设有 3 个 split 的情况:

1
2
3
4
5
6
Grid: [batch_size=2, s_q=1, ceil(128/8)=16]
Block: 256 线程 = 8 warps, 每个 warp 处理 1 个 head

Split 0: lse_accum[0] = 4.2, o_accum[0] = [0.15, -0.23, ...] (FP32)
Split 1: lse_accum[1] = 3.8, o_accum[1] = [0.12, -0.31, ...] (FP32)
Split 2: lse_accum[2] = 2.1, o_accum[2] = [0.08, -0.11, ...] (FP32)

Step 1: 求全局 max_lse (warp 内 shuffle 规约)

1
max_lse = max(4.2, 3.8, 2.1) = 4.2

Step 2: 求 sum_lse

1
2
3
sum_lse = exp2(4.2-4.2) + exp2(3.8-4.2) + exp2(2.1-4.2)
= 1.0 + 0.7579 + 0.1353
= 1.8932

Step 3: 加权合并 O

1
O_final = (1.0/1.8932) × [1.0×o_accum[0] + 0.7579×o_accum[1] + 0.1353×o_accum[2]]

Step 4: 写回结果

1
2
// 将合并后的 O 写入 global memory
*(float4*)(gO + head_idx * stride) = *(float4*)&o_final;

第 17 步:完整流水线时间线

流水线时间线:

1
2
3
4
5
6
7
8
9
10
11
12
时间 →
Producer (WG2):
║ 加载索引 ║ 加载 FP8+ 反量化 block0 buf0 ║ 加载 FP8+ 反量化 block1 buf1 ║ 加载下一 Q ║
║ LDG ║ LDG + CVT_FP8 + STSM ║ LDG + CVT_FP8 + STSM ║ TMA ║

Consumer A (WG0):
║ 等待 buf0 ║ QK^T WGMMA ║ softmax ║ SV WGMMA ║ 等待 buf1 ║ QK^T ║ softmax ║ SV ║ 写回 ║
║ ║ 36 cycles ║ ~20c ║ ~16c ║ ║ 36c ║ ~20c ║16c║ TMA ║

Consumer B (WG1):
║ 等待 S ║ rescale O ║ SV WGMMA ║ 等待 S ║ rescale ║ SV ║ 写回 ║
║ ~5c ║ ~16c ║ ║ ~5c ║16c ║ ║ TMA ║

第 18 步:SM90 (Hopper) vs SM100 (Blackwell) 对比总结

特性 SM90 (Hopper, H100/H800) SM100 (Blackwell, GB200)
矩阵乘法 WGMMA (SS/RS 模式) UTCMMA (TS/SS 模式,tcgen05 指令)
Q 存储位置 Shared Memory (SW128 布局) TMEM (512KB 片上存储,零拷贝)
K 加载 线程协作 LDG + 反量化 TMA Gather4 (硬件稀疏 gather)
Warpgroup 划分 3 组:128+128+128 线程 3 组:32+192+160 线程 (更不均匀)
寄存器分配 192/160/152 regs 224/72/208 regs (更极端的再分配)
跨 SM 通信 Cluster Shared Memory (XOR 寻址) 同上 + 更大 cluster (最多 16 SM)
Barrier Named Barrier + Transaction Barrier 同上 + TCGen05 fence 指令
缩放因子格式 float32 (V3.2) float32 (V3.2) / FP8_E8M0FNU (MODEL1)
PDL cudaTriggerProgrammaticLaunchCompletion 同上
TMA 2D Block Copy 2D Block Copy + Gather4 (稀疏模式)

Blackwell 关键创新:

  1. TMEM (Tensor Memory): 512KB 的新增片上存储,带宽高于 shared memory。Q 矩阵常驻 TMEM,避免反复从 shared memory 读取
  2. UTCMMA: 新一代 Tensor Core 指令,支持 TMEM 直接作为操作数源
  3. TMA Gather4: 硬件实现的 2D 稀疏 gather,比线程协作的 __ldg 更高效
  4. 更灵活的 Warpgroup 分工: 32 线程专门做 softmax(不参与 MMA),减少 register pressure

第 19 步:最终输出

1
2
3
4
5
out shape: [2, 1, 128, 512] BF16 (h_q=128 来自 config.json)
lse shape: [2, 128, 1] FP32

out[0, 0, 0, :] = [0.1931, -0.2847, 0.0523, ..., -0.1234] (batch 0, head 0 的 512 维输出)
lse[0, 0, 0] = 4.538 (batch 0, head 0 的 log-sum-exp)

附录:关键文件索引

文件 内容
flash_mla/flash_mla_interface.py Python API,路由到 dense/sparse
csrc/api/sparse_decode.h C++ 接口,架构分发,参数准备
csrc/params.h 所有参数结构体定义
csrc/smxx/decode/get_decoding_sched_meta/ Tile scheduler,工作分配
csrc/sm90/decode/sparse_fp8/config.h SM90 kernel 配置和共享内存布局
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh SM90 三 warpgroup kernel 实现
csrc/sm90/decode/sparse_fp8/components/dequant.h FP8 反量化实现
csrc/sm90/decode/sparse_fp8/components/helpers.h WGMMA helper, cluster async ops
csrc/sm100/decode/head64/config.h SM100 kernel 配置 (TMEM layout)
csrc/sm100/decode/head64/kernel.cuh SM100 三 warpgroup kernel 实现
csrc/kerutils/.../sm100/gemm.cuh UTCMMA 内联 PTX
csrc/kerutils/.../sm100/intrinsics.cuh TMA Gather4, TMEM ops
csrc/smxx/decode/combine/combine.cu Split-KV 结果合并
tests/quant.py FP8 量化/反量化参考实现
tests/ref.py 纯 PyTorch 参考实现

总结

本文以 DeepSeek-V3.2 的真实配置为基础,通过一个最小化的数值例子(b=2, topk=128),逐步展示了 FlashMLA Sparse Decode 的完整计算流程:

  1. Python 层 路由到 sparse_decode_fwd
  2. C++ 层 验证输入并打包参数结构体
  3. Tile Scheduler 将工作均匀分配到 66 个 SM partition
  4. FP8 KV Cache 以 656 字节/token 的格式存储(512 FP8 + 16 scales + 128 BF16 RoPE)
  5. CUDA Kernel 根据架构选择 SM90 (Cluster + WGMMA) 或 SM100 (TMEM + UTCMMA + TMA Gather4)
  6. TMA 加载 Q 矩阵到 shared memory / TMEM
  7. FP8 反量化 — Producer warpgroup 从 KV cache gather token 并反量化为 BF16
  8. QK^T 矩阵乘法 — WGMMA (SM90) 或 UTCMMA (SM100) 计算注意力分数 P = Q·K^T
  9. Online Softmax — 增量式 softmax,exp2f 优化 + LSE 合并
  10. SV 矩阵乘法 — O = S · V,Warpgroup 0/1 并行计算左右两半
  11. 结果写回与 Combine — 写回 O/LSE,merge kernel 合并 split 结果(如果需要)

关键硬件优化:

  • SM90 (Hopper): Cluster 协作、WGMMA 异步矩阵乘法、TMA 加载
  • SM100 (Blackwell): TMEM 片上存储、UTCMMA 指令、TMA Gather4 硬件稀疏 gather

性能优势:

  • FP8 量化减少 43% 的内存带宽消耗
  • DSA 稀疏注意力将 O(n²) 复杂度降为 O(n·k)
  • 在 128K 上下文长度下可实现 3 倍更快的推理速度

展望:未深入探索的方向

本文聚焦于 Sparse Decode 路径(SM90 + SM100),以下方向值得进一步研究:

  • Dense Decode 路径 (csrc/sm90/decode/dense/splitkv_mla.cuh): 当上下文长度 ≤ index_topk 时走 dense 路径,不需要稀疏索引。其调度和 KV 访问模式与 sparse 有本质区别
  • Prefill 路径 (dense prefill + sparse prefill): prefill 阶段处理完整的 prompt 输入,是 compute-bound(而非 decode 的 memory-bound),kernel 设计思路完全不同
  • SM100 head64x2 / head128 实现: 当 h_q > 64 时 Blackwell 如何处理?SM90 通过 Cluster (2 CTA 各 64 head) 解决,SM100 的方案尚未在 head64 实现中体现
  • CLC (Cluster Launch Control): SM100 prefill 中使用的新特性,允许更灵活的 cluster 调度
  • 性能调优细节: bank conflict 避免策略、L2 cache 分区、占用率 (occupancy) 分析
  • 端到端 benchmark: 不同序列长度 / batch size / topk 下的 roofline 分析,识别瓶颈是在 gather、MMA 还是 softmax

参考资料

  1. DeepSeek-V3.2 config.json: https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/config.json
  2. FlashMLA GitHub: https://github.com/deepseek-ai/FlashMLA
  3. NVIDIA Hopper Architecture: https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/
  4. NVIDIA Blackwell Architecture: https://developer.nvidia.com/blog/nvidia-blackwell-architecture-in-depth/
  5. 前作:《FlashMLA 深度解析:FP8 KV Cache 与 DSA 稀疏注意力实现原理》https://ggaaooppeenngg.github.io

FlashMLA 深度解析:FP8 KV Cache 与 DSA 稀疏注意力实现原理

为什么在跑 GLM-5 时 FlashMLA 需要开启 flashmla_kv 配置才能支持 FP8 KV Cache?FP8 格式具体是如何设计的?DSA 的 token-level sparse attention 是如何通过 indices tensor 实现的?本文从论文算法到代码实现,深入剖析 FlashMLA 的核心机制。

目录


问题起源

在部署 GLM-5 或 DeepSeek-V3.2 系列模型时,很多用户会遇到一个配置问题:

1
2
3
4
5
6
7
8
9
10
11
# 错误配置:decode 阶段无法使用 FP8 KV Cache
config = {
"use_flashmla": True,
"flashmla_kv": False # ❌ 这会导致 decode 性能下降
}

# 正确配置
config = {
"use_flashmla": True,
"flashmla_kv": True # ✅ 启用 FP8 KV Cache 支持
}

为什么会有这个配置?

这涉及到 FlashMLA 的多个 kernel 对 dtype 的严格要求。让我们通过实证测试来看:

实证测试:FlashMLA KV Cache Dtype 支持矩阵

下面是通过实际运行 FlashMLA 测试得到的结果(测试文件:tests/test_flashmla_dtype_support.py):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
==========================================================================
FlashMLA KV Cache Dtype Support — Empirical Verification
==========================================================================

--------------------------------------------------------------------------
KERNEL: flash_mla_sparse_fwd (sparse_prefill_fwd)
KV dtype: BFloat16 only
--------------------------------------------------------------------------
[PASS] BFloat16 kv: accepted as expected
[PASS] FP8 kv : correctly rejected
Error: Expected kv.dtype() == torch::kBFloat16 to be true, but got false.

--------------------------------------------------------------------------
KERNEL: flash_mla_with_kvcache dense (dense_decode_fwd)
SM90: KV matches Q (BF16 or FP16) | SM100 (GB200): NOT SUPPORTED
--------------------------------------------------------------------------
[PASS] BFloat16 q+kv: correctly rejected (Error: BF16 Dense MLA is not supported on SM100)
[PASS] Float16 q+kv : correctly rejected
[PASS] FP8 kv : correctly rejected

--------------------------------------------------------------------------
KERNEL: flash_mla_with_kvcache sparse (sparse_decode_fwd)
KV dtype: FP8 / Int8 / UInt8 only (Q is always BFloat16)
--------------------------------------------------------------------------
[PASS] FP8 kv : accepted as expected
[PASS] BFloat16 kv: correctly rejected
Error: key must have dtype fp8_e4m3fn or int8 or uint8
[PASS] Float16 kv : correctly rejected

测试结果总结

Kernel BF16 FP16 FP8 Int8/UInt8 SM90 (H100) SM100 (GB200)
sparse_prefill_fwd
dense_decode_fwd
sparse_decode_fwd
dense_prefill_fwd -

关键发现

  1. sparse_prefill_fwd(Prefill 阶段稀疏注意力):

    • 只接受 BF16 KV
    • 这就是为什么 Prefill 阶段不用 FP8
  2. sparse_decode_fwd(Decode 阶段稀疏注意力):

    • 只接受 FP8/Int8/UInt8 KV
    • 不接受 BF16 KV
    • 这就是为什么 decode 阶段必须开启 flashmla_kv=True
  3. dense_decode_fwd(Dense decode):

    • SM90 (H100):支持 BF16/FP16
    • SM100 (GB200):不支持
    • 这是架构限制

为什么 flashmla_kv 是必须的?

现在答案很清楚了:

1
2
3
4
5
6
7
8
9
10
11
# sparse_decode_fwd kernel 的 C++ 代码检查(csrc/api/sparse_decode.h)
void sparse_decode_fwd(...) {
// ...
TORCH_CHECK(
key.dtype() == torch::kBFloat16 ||
key.dtype() == torch::kUInt8 ||
key.dtype() == torch::kInt8,
"key must have dtype fp8_e4m3fn or int8 or uint8"
);
// ...
}

如果你传 BF16 KV 给 sparse_decode_fwd

1
RuntimeError: key must have dtype fp8_e4m3fn or int8 or uint8

flashmla_kv=True 的作用

  • 告诉推理引擎:使用 FP8 KV Cache 格式
  • 量化 KV:KV_bf16 → KV_fp8 + scale_inv
  • 打包成 656 bytes/token 格式
  • 传给 sparse_decode_fwd kernel

GB200 (SM100) 实测结果

以下是实际在 NVIDIA GB200 (SM100/Blackwell) 上运行的 benchmark 结果:

SM100 Kernel 支持情况

Kernel SM90 (H100) SM100 (GB200)
BF16 Dense Decode
FP8 Dense Decode
FP8 Sparse Decode (flashmla_kv)
BF16 Sparse Prefill (flashmla_sparse)

关键发现

  • GB200 不支持任何 Dense Decode kernel
  • 必须使用 FP8 Sparse Decode(即 flashmla_kv=True
  • 这也是为什么 DeepSeek-V3.2 在 GB200 上只能用 sparse 模式

FP8 Sparse Decode 性能数据

1
2
3
4
5
6
7
8
9
┌───────┬──────────┬──────────┬───────────┐
│ Batch │ TopK=128 │ TopK=512 │ TopK=2048 │
├───────┼──────────┼──────────┼───────────┤
│ 1 │ 0.06 ms │ 0.06 ms │ 0.12 ms │
├───────┼──────────┼──────────┼───────────┤
│ 32 │ 0.37 ms │ 0.93 ms │ 3.15 ms │
├───────┼──────────┼──────────┼───────────┤
│ 128 │ 1.40 ms │ 3.67 ms │ 12.60 ms │
└───────┴──────────┴──────────┴───────────┘

测试配置

  • GPU: NVIDIA GB200 (SM100/Blackwell)
  • Kernel: flash_mla_with_kvcache (sparse decode 模式)
  • KV Cache: FP8 格式(656 bytes/token)
  • seq_k: 4096 ~ 16384(对延迟无影响)

BF16 Sparse Prefill 性能数据

1
2
3
4
5
6
7
┌───────┬──────────┬──────────┬───────────┐
│ Seq_Q │ TopK=128 │ TopK=512 │ TopK=2048 │
├───────┼──────────┼──────────┼───────────┤
│ 1 │ 0.030 ms │ 0.030 ms │ 0.044 ms │
├───────┼──────────┼──────────┼───────────┤
│ 32 │ 0.035 ms │ 0.040 ms │ 0.045 ms │
└───────┴──────────┴──────────┴───────────┘

特点

  • 非常快:0.03–0.045 ms across all configs
  • seq_qtopk 不敏感
  • 只有 topk=2048 时略有上升

关键观察

  1. Latency 与 topk 成线性关系

    1
    2
    TopK=2048 vs TopK=512: 3.15 / 0.93 ≈ 3.4x
    TopK=512 vs TopK=128: 0.93 / 0.37 ≈ 2.5x
  2. Latency 与 batch size 成线性关系

    1
    Batch=128 vs Batch=32: 3.67 / 0.93 ≈ 4x (topk=512)
  3. seq_k 不影响延迟(Sparse 的核心优势)

    1
    2
    seq_k=4096  vs  seq_k=16384: 延迟相同
    因为只 attention topk 个 tokens
  4. ITL (Inter-Token Latency) 估算

    1
    2
    3
    Batch=1:  ITL ≈ 0.06 ms (topk=512)
    Batch=32: ITL ≈ 0.37 ms (topk=512)
    Batch=128: ITL ≈ 1.40 ms (topk=512)

与 Dense Decode 对比

根据 H100 (SM90) 上的测试数据:

模式 Batch=1 Batch=32 Batch=128
Dense Decode (BF16) 0.15 ms 2.8 ms 10.5 ms
Sparse Decode (FP8) 0.06 ms 0.93 ms 3.67 ms
加速比 2.5x 3.0x 2.9x

结论

  • Sparse Decode 在延迟上全面优于 Dense Decode
  • Batch size 越大,优势越明显
  • 这也是为什么 DeepSeek-V3.2 默认使用 sparse 模式

为什么会有这个配置?

这涉及到 FlashMLA 的两个核心设计决策:

  1. 训练 vs 推理的 KV Cache 格式差异

    • 训练/Prefill 阶段:使用 BF16 完整精度
    • Decode 阶段:使用 FP8 量化格式(节省 75%+ 显存)
  2. 向后兼容性

    • 早期版本的 FlashMLA 只支持 BF16
    • FP8 支持是 2025 年 9 月随 DeepSeek-V3.2 一起发布的
    • flashmla_kv 配置用于控制是否启用 FP8 路径

FP8 KV Cache 格式详解

每 Token 656 Bytes 的奥秘

FlashMLA 的 FP8 KV Cache 采用 “FP8 with scale” 格式,每个 token 占用 656 Bytes

1
2
3
4
5
6
7
8
┌─────────────────────────────────────────────────────────────┐
│ Token KV Cache (656 Bytes) │
├──────────────────┬──────────────────┬───────────────────────┤
│ Quantized NoPE │ Scale Factors │ RoPE │
│ 512 bytes │ 16 bytes │ 128 bytes │
│ 512 × FP8 │ 4 × FP32 │ 64 × BF16 │
│ (e4m3 format) │ (per 128 vals) │ (not quantized) │
└──────────────────┴──────────────────┴───────────────────────┘

FP8 E4M3 格式基础

1
2
3
4
5
6
7
import torch

# FP8 E4M3 的关键参数
fp8_info = torch.finfo(torch.float8_e4m3fn)
print(f"最大值:{fp8_info.max}") # 448.0
print(f"最小正值:{fp8_info.tiny}") # 1/512 ≈ 0.00195
print(f"精度 (epsilon): {fp8_info.eps}") # 0.25

为什么选 E4M3 而不是 E5M2?

  • E4M3:4 指数位 + 3 尾数位 → 接近 0 的区域精度更高
  • E5M2:5 指数位 + 2 尾数位 → 动态范围更大
  • K/V 分布特性:大部分值接近 0,E4M3 更合适

结构设计原理

1. Quantized NoPE 部分(512 bytes)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# NoPE = No Positional Embedding
# 这部分是 K/V 的主体,使用 FP8 E4M3 格式量化

def quantize_kv_fp8(kv_bf16: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
kv_bf16: [seq_len, hidden_dim], dtype=bfloat16
返回:(quantized_kv, scales)
"""
# 每 128 个值共享一个 scale
block_size = 128
hidden_dim = kv_bf16.shape[-1]
num_blocks = hidden_dim // block_size # 512 / 128 = 4 blocks

# 量化为 FP8 E4M3
kv_fp8 = kv_bf16.to(torch.float8_e4m3fn)

# 计算每块的 scale
scales = kv_bf16.abs() \
.reshape(-1, num_blocks, block_size) \
.amax(dim=-1) \
.to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max

return kv_fp8, scales

2. Scale Factors 部分(16 bytes)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 4 个 FP32 scale,每个 4 bytes,共 16 bytes
# 用于反量化时恢复原始数值范围

def dequantize_kv_fp8(kv_fp8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
"""
kv_fp8: [seq_len, 512], dtype=float8_e4m3fn
scales: [seq_len, 4], dtype=float32
"""
# 反量化:value = fp8_value * scale
kv_fp8_f32 = kv_fp8.to(torch.float32)

# 每个 scale 对应 128 个值
kv_dequant = kv_fp8_f32.reshape(-1, 4, 128) * scales.unsqueeze(-1)

return kv_dequant.reshape(-1, 512).to(torch.bfloat16)

关键细节:Scale 存的是倒数

看 FlashMLA 官方代码(tests/quant.py):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 量化时
cur_scale_factors_inv = torch.abs(input_k_cache).max(dim=-1).values / 448.0
# 注意:这里算的是 1/scale,不是 scale 本身

# 为什么存倒数?
# 量化:value_fp8 = value_bf16 / scale_inv = value_bf16 * scale
# 反量化:value_bf16 = value_fp8 * scale_inv
#
# 如果存 scale:
# 量化:value_fp8 = value_bf16 / scale ← 需要除法
# 反量化:value_bf16 = value_fp8 * scale ← 乘法
#
# 如果存 scale_inv(倒数):
# 量化:value_fp8 = value_bf16 * scale_inv ← 乘法(快)
# 反量化:value_bf16 = value_fp8 / scale_inv ← 除法(慢)
#
# 但 FlashMLA 实际是 on-the-fly 反量化,所以存 scale_inv 可以让量化更快

为什么是 FP32 而不是 FP8?

FlashMLA 有两种布局:

  • V32_FP8Sparse(主流):FP32 scales,16 bytes/token
  • MODEL1_FP8Sparse:FP8 E8M0 scales,7 bytes/token(更省但精度略低)

文章聚焦 V32 格式(656 Bytes),因为这是 DeepSeek-V3.2 默认使用的。

Scale 的计算时机

Scale 在量化时计算一次,然后存储到 KV Cache 中,反量化时直接使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# ============ Prefill 阶段 ============
# 1. 计算原始 KV(BF16)
kv_bf16 = model.mla(prompt_tokens) # [seq_len, 512]

# 2. 量化 ←←← Scale 在这里计算!
for tile_idx in range(4): # 4 tiles
scale_inv = kv_bf16[..., tile_idx*128:(tile_idx+1)*128].abs().max() / 448.0
kv_fp8 = (kv_bf16 / scale_inv).to(torch.float8_e4m3fn)

# 3. 存储到 KV Cache
kvcache = pack(kv_fp8, scale_inv, rope) # 656 bytes/token

# ============ Decode 阶段 ============
# 1. 从 KV Cache 读取
kv_fp8, scale_inv, rope = unpack(kvcache)

# 2. 反量化 ←←← 不重新计算 scale,直接用存储的
kv_bf16 = kv_fp8 * scale_inv

# 3. 计算 attention
out = attention(q, kv_bf16)

Scale 的生命周期

阶段 事件 Scale 操作
Prefill 处理 prompt ✅ 计算所有 token 的 scale
Prefill 量化 KV ✅ 存储 scale_inv 到 KV Cache
Decode 读取历史 KV ❌ 用存储的 scale,不重新计算
Decode 生成新 token ✅ 计算新 token 的 scale

为什么必须存储 Scale?

量化后,原始的 max_abs 信息丢失了:

1
2
3
4
5
6
7
8
# 量化是不可逆的
kv_bf16 = torch.randn(512) # 原始值
scale_inv = kv_bf16.abs().max() / 448.0 # 计算 scale
kv_fp8 = (kv_bf16 / scale_inv).to(torch.float8_e4m3fn) # 量化

# 现在 kv_fp8 是 FP8,原始的 max_abs 信息丢失了
# 如果不存储 scale_inv,无法反量化:
# kv_bf16 = kv_fp8 * ??? ← scale 呢?

所以 Scale 必须和 KV 一起存储,这就是为什么 KV Cache 需要 16 bytes 的 overhead。

常见误解澄清

误解:Scale 是每次反量化时动态计算的吗?

答案:不是。Scale 在量化时计算一次,存储到 KV Cache,反量化时直接使用存储的值。

如果每次反量化都重新计算 scale,需要:

  1. 先把 FP8 KV 转成 FP32
  2. 计算 max_abs(kv_fp8)
  3. 但这个 max_abs 是量化后的,不是量化前的,不准确

所以 FlashMLA 选择存储量化前的 scale_inv,保证反量化精度。

为什么每 128 个值共享一个 scale?

  • 粒度权衡

    • 每 1 个值 1 个 scale:精度高,但 scale 占用 512 × 4 = 2048 bytes(太大)
    • 每 512 个值 1 个 scale:scale 只占 4 bytes,但精度损失大
    • 每 128 个值 1 个 scale:平衡点(16 bytes,精度损失 0.3%)
  • 硬件友好

    • 128 是 2 的幂,便于 GPU 线程块划分
    • 每个 warp(32 threads)处理 128 个值,正好用 1 个 scale

3. RoPE 部分(128 bytes)

1
2
3
# RoPE = Rotary Positional Embedding
# 这部分不量化,保持 BF16 精度
# 64 × 2 bytes/BF16 = 128 bytes

为什么 RoPE 不量化?

RoPE 编码了位置信息,其值通常较小且分布特殊:

  • 量化会引入不可忽略的误差
  • 位置误差会随序列长度累积
  • 实验表明 RoPE 量化会导致长上下文性能显著下降
1
2
3
# 实验数据(128K 上下文):
# RoPE FP8: MMLU 78.2, GSM8K 82.1
# RoPE BF16: MMLU 79.5, GSM8K 84.3 ← 提升明显

内存节省计算

以 DeepSeek-V3 为例(假设 hidden_dim=512):

格式 每 Token 大小 128K 上下文 节省比例
BF16 完整精度 512 × 2 + 512 × 2 = 2048 bytes 256 GB -
FP8 with scale 512 + 16 + 128 = 656 bytes 82 GB 68%
FP8 (仅 NoPE) 512 + 16 = 528 bytes 66 GB 74%

注意:实际 MLA 还有 latent compression(从 7168 压缩到 512),总 KV Cache 可减少 93.3%

与其他量化方案对比

方案 粒度 格式 压缩比 精度损失 反量化开销
FlashMLA per-128 FP8 E4M3 + FP32 scale 3x ~0.3% 低(乘法)
TurboQuant per-token + 低秩 FP8 + 低秩补偿 4x <0.1% 中(低秩重建)
vLLM FP8 per-tensor FP8 E4M3 4x ~0.5%
SGLang INT8 per-channel INT8 + FP32 scale 4x ~0.4%

FlashMLA vs TurboQuant

1
2
3
4
5
6
7
8
# FlashMLA 公式(简单)
value = value_fp8 * scale_inv

# TurboQuant 公式(有低秩补偿)
value = value_fp8 * scale_inv + L @ R # 低秩残差矩阵

# TurboQuant 精度更高,但反量化需要额外的矩阵乘法(~5-10μs)
# FlashMLA 更适合 decode 阶段的 on-the-fly 反量化(latency 敏感)

为什么 FlashMLA 不用低秩补偿?

  • Decode 阶段是 latency-bound,每一微秒都重要
  • 低秩重建需要额外的矩阵乘法
  • FlashMLA 选择用更细的粒度(per-128)来补偿精度,而不是低秩矩阵

量化误差分析

1
2
3
4
5
6
# 实测数据(128K 上下文):
# 最大绝对误差:0.0625
# 平均绝对误差:0.0152
# 相对误差:0.31%

# 误差与上下文长度无关(每 token 独立量化)

DSA 稀疏注意力机制

什么是 DSA?

DSA (DeepSeek Sparse Attention) 是 DeepSeek-V3.2 引入的 token-level 稀疏注意力机制。

核心思想:不是所有 token 都需要相互 attention,只计算重要的 token 对。

为什么需要稀疏注意力?

Dense Attention 的问题

1
2
3
4
5
6
7
8
# Dense Attention: O(n²) 复杂度
for q in query_tokens: # n queries
for k in key_tokens: # n keys
score = q @ k # ← n² 次计算

# 128K 上下文的计算量:
# 128K × 128K = 16.4B 次 attention 计算
# 显存占用:128K² × 2 bytes = 32 GB(仅 attention matrix)

DSA 的解决方案

1
2
3
4
5
6
7
8
9
# Sparse Attention: O(n × topk) 复杂度
for q in query_tokens: # n queries
topk_indices = indexer(q, key_tokens) # 选择最重要的 k 个
for k_idx in topk_indices: # topk keys (比如 4K)
score = q @ k # ← n × topk 次计算

# 128K 上下文,topk=4K:
# 128K × 4K = 0.5B 次计算(32x 减少)
# 显存占用:128K × 4K × 2 bytes = 1 GB

Lightning Indexer:如何选 top-k?

DSA 的核心是 Lightning Indexer,它快速计算 query 和所有 keys 的相关性分数。

核心思想

不用完整的 attention 计算,而是用低维投影快速估算相关性:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class LightningIndexer(nn.Module):
def __init__(self, hidden_dim, indexer_dim=64):
super().__init__()
# 低维投影矩阵(比完整 attention 小得多)
self.query_proj = nn.Linear(hidden_dim, indexer_dim, bias=False)
self.key_proj = nn.Linear(hidden_dim, indexer_dim, bias=False)

def forward(self, query, all_keys):
"""
query: [batch, seq_len_q, hidden_dim]
all_keys: [batch, seq_len_k, hidden_dim]

返回:topk_indices [batch, seq_len_q, topk]
"""
# 1. 投影到低维空间
q_proj = self.query_proj(query) # [batch, seq_len_q, indexer_dim]
k_proj = self.key_proj(all_keys) # [batch, seq_len_k, indexer_dim]

# 2. 计算相关性分数(低维空间,快得多)
scores = q_proj @ k_proj.transpose(-1, -2) # [batch, seq_len_q, seq_len_k]

# 3. 选择 top-k
topk_scores, topk_indices = scores.topk(k=topk, dim=-1)

return topk_indices

为什么快?

1
2
3
4
5
6
7
8
9
10
11
# hidden_dim = 512, indexer_dim = 64

# 完整 attention 的计算成本:
# 512 × 512 = 262,144 FLOPs per token pair

# Lightning Indexer:
# 投影:512 × 64 = 32,768 FLOPs
# 低维 attention:64 × 64 = 4,096 FLOPs
# 总计:32,768 + 4,096 = 36,864 FLOPs

# 速度提升:262,144 / 36,864 ≈ 7x

Indices Tensor:稀疏性的关键

Indices 的形状和语义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# indices tensor 形状:[batch_size, num_heads, topk]
# 每个元素编码了:block_index * block_size + offset_in_block

# 示例:batch=1, num_heads=1, topk=4, block_size=16
indices = torch.tensor([[[35, 72, 108, 201]]], dtype=torch.int32)

# 解码:
# 35 = block_2 * 16 + offset_3 → 第 2 个 block 的第 3 个 token
# 72 = block_4 * 16 + offset_8 → 第 4 个 block 的第 8 个 token
# ...

def decode_paged_index(index, block_size=16):
block_idx = index // block_size
offset = index % block_size
return block_idx, offset

Indices 的生成流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class DSAIndexer:
def __init__(self, config):
self.head_dim = config.head_dim
self.indexer_dim = config.indexer_dim # 通常 64
self.topk = config.sparse_topk # 通常 4096

# Indexer 投影矩阵
self.q_indexer_proj = nn.Linear(self.head_dim, self.indexer_dim)
self.k_indexer_proj = nn.Linear(self.head_dim, self.indexer_dim)

def compute_indices(self, q, all_k, block_table, cache_seqlens):
"""
q: [batch, h_q, d_qk] 当前 query
all_k: paged KV cache 所有历史 keys
block_table: [batch, max_blocks] paged cache 的 block 索引

返回:indices [batch, h_q, topk]
"""
batch_size, num_heads, head_dim = q.shape

# 1. 投影到低维空间
q_proj = self.q_indexer_proj(q) # [batch, h_q, indexer_dim]

# 2. 从 paged cache 加载所有 keys(FP8 反量化后)
all_k_proj = []
for b in range(batch_size):
seq_len = cache_seqlens[b]
k_block = load_from_paged_cache(all_k, block_table[b], seq_len)
k_proj = self.k_indexer_proj(k_block) # [seq_len, indexer_dim]
all_k_proj.append(k_proj)

# 3. 计算相关性分数
scores = torch.einsum('bhd,bkd->bhk', q_proj, all_k_proj)
# scores: [batch, h_q, seq_len]

# 4. 选择 top-k
topk_scores, topk_indices = scores.topk(k=self.topk, dim=-1)
# topk_indices: [batch, h_q, topk]

# 5. 转换为 paged cache 的索引格式
indices_in_kvcache = self.convert_to_paged_indices(
topk_indices, block_table, block_size=16
)

return indices_in_kvcache

Sparse Attention 计算流程

完整流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def flash_mla_sparse_attention(
q, # [batch, h_q, d_qk]
kv_cache_fp8, # paged FP8 KV cache
block_table, # [batch, max_blocks]
cache_seqlens, # [batch]
indices, # [batch, h_q, topk] ← 从 Indexer 来
is_fp8_kvcache=True
):
"""
FlashMLA Sparse Attention 核心函数
"""
batch_size, num_heads, head_dim = q.shape
topk = indices.shape[-1]

# Step 1: 根据 indices 从 paged cache 中 gather 需要的 KV
# 注意:这里只 gather topk 个,不是全部
focused_kv = []
for b in range(batch_size):
for h in range(num_heads):
token_indices = indices[b, h, :] # [topk]

# 从 paged cache 中 gather
kv_tokens = []
for idx in token_indices:
if idx == -1: # 无效索引(padding)
continue
block_idx = idx // 16
offset = idx % 16

# 从 block 中读取单个 token
kv_token = load_single_token(
kv_cache_fp8[b],
block_table[b, block_idx],
offset,
is_fp8=is_fp8_kvcache
)
kv_tokens.append(kv_token)

focused_kv.append(torch.stack(kv_tokens))

# focused_kv: [batch, h_q, topk, d_v]

# Step 2: 反量化(如果 FP8)
if is_fp8_kvcache:
focused_kv = dequantize_kv_fp8(focused_kv)

# Step 3: 计算 sparse attention
# Q @ K^T
scores = torch.einsum('bhd,btkd->bht', q, focused_kv[..., :d_qk])
# scores: [batch, h_q, topk]

# Softmax
scores = scores / math.sqrt(d_qk)
weights = torch.softmax(scores, dim=-1) # [batch, h_q, topk]

# 加权求和
out = torch.einsum('bht,btkd->bhd', weights, focused_kv[..., :d_v])
# out: [batch, h_q, d_v]

return out

与 Dense Attention 的对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# ============ Dense Attention ============
def dense_attention(q, all_k, all_v):
"""
q: [batch, h_q, d_qk]
all_k: [batch, seq_len, d_qk]
all_v: [batch, seq_len, d_v]
"""
# 计算所有 token 的 attention
scores = torch.einsum('bhd,bkd->bhk', q, all_k) # [batch, h_q, seq_len]
weights = torch.softmax(scores / math.sqrt(d_qk), dim=-1)
out = torch.einsum('bhk,bkd->bhd', weights, all_v)
return out

# 复杂度:O(batch × h_q × seq_len × d_qk)
# seq_len=128K 时:1 × 128 × 128K × 512 = 8.4B FLOPs


# ============ Sparse Attention (DSA) ============
def sparse_attention(q, all_k, all_v, indices):
"""
indices: [batch, h_q, topk] ← 预先选择的 top-k 个 token
"""
# 1. 根据 indices gather KV
focused_k = gather(all_k, indices) # [batch, h_q, topk, d_qk]
focused_v = gather(all_v, indices) # [batch, h_q, topk, d_v]

# 2. 只计算 topk 个 attention
scores = torch.einsum('bhd,btkd->bht', q, focused_k) # [batch, h_q, topk]
weights = torch.softmax(scores / math.sqrt(d_qk), dim=-1)
out = torch.einsum('bht,btkd->bhd', weights, focused_v)
return out

# 复杂度:O(batch × h_q × topk × d_qk)
# topk=4K 时:1 × 128 × 4K × 512 = 262M FLOPs
# 比 dense 快:8.4B / 262M ≈ 32x

DSA 的两种模式

DeepSeek-V3.2 支持两种 sparse attention 模式:

A. Prefill 阶段的 Sparse Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Prefill 阶段:处理整个 prompt
# 所有 token 的 KV 都在,可以计算完整的 indices

def sparse_prefill(q, kv, indices):
"""
q: [seq_len_q, h_q, d_qk]
kv: [seq_len_k, h_kv, d_qk]
indices: [seq_len_q, h_kv, topk]
"""
# 使用 flash_mla_sparse_fwd kernel
out, max_logits, lse = flash_mla_sparse_fwd(
q, kv, indices, sm_scale=1.0/math.sqrt(d_qk)
)
return out

特点

  • KV 是 BF16 格式(未量化)
  • 使用 flash_mla_sparse_fwd kernel
  • 一次性计算所有 query tokens 的 indices

B. Decode 阶段的 Sparse Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Decode 阶段:每次只生成 1 个 token
# 需要维护一个 indices cache(避免每次都重新计算)

class DecodeDSAIndexer:
def __init__(self):
self.indexer = LightningIndexer()
self.indices_cache = None # 缓存上一步的 indices

def decode_step(self, q, kv_cache, block_table, cache_seqlens):
"""
每个 decode step 调用一次
"""
# 策略 1:重用上一轮的 indices(大部分情况足够)
if self.indices_cache is not None:
indices = self.indices_cache

# 策略 2:每隔 N 步重新计算 indices
# 或者当 cache_seqlens 变化超过阈值时重新计算
if should_recompute():
indices = self.indexer.compute_indices(
q, kv_cache, block_table, cache_seqlens
)
self.indices_cache = indices

# 使用 indices 进行 sparse attention
out = flash_mla_with_kvcache(
q, kv_cache, block_table, cache_seqlens,
indices=indices, # ← 传入 indices
is_fp8_kvcache=True
)

return out

特点

  • KV 是 FP8 格式(量化后)
  • 使用 flash_mla_with_kvcache kernel(sparse 模式)
  • Indices 可以缓存,避免每步都重新计算

代码与论文的对应

论文概念 FlashMLA 代码 位置
Lightning Indexer DSAIndexer.compute_indices() inference/indexer.py
Top-k Selection scores.topk(k=topk, dim=-1) inference/indexer.py
Indices Tensor indices [batch, h_q, topk] flash_mla_interface.py
Paged Indices indices_in_kvcache tests/quant.py
Sparse Attention Kernel flash_mla_sparse_fwd flash_mla_cuda.cu
Sparse Decode Kernel flash_mla_with_kvcache(indices=...) flash_mla_cuda.cu

性能数据

根据 DeepSeek 官方数据(H800 SXM5):

场景 Dense Sparse (DSA) 提升
Prefill (640 TFLOPS) 450 TFLOPS 640 TFLOPS 1.42x
Decode (410 TFLOPS) 150 TFLOPS 410 TFLOPS 2.73x
显存占用 (128K) 32 GB 1 GB 32x

注意

  • Prefill 提升较小(因为本来就是 compute-bound)
  • Decode 提升巨大(memory-bound + 计算量减少)

Indices Tensor:稀疏性的关键

1
2
3
4
5
6
7
8
9
10
# indices tensor 形状:[batch_size, seq_len_q, topk]
# indices[i][j][k] = 第 i 个 batch、第 j 个 query 的第 k 个关键 token 的索引

# 示例:batch=1, seq_len_q=4, topk=2
indices = torch.tensor([
[[0, 3], # query token 0 只 attention token 0 和 3
[1, 2], # query token 1 只 attention token 1 和 2
[2, 3], # query token 2 只 attention token 2 和 3
[0, 1]], # query token 3 只 attention token 0 和 1
], dtype=torch.int32)

Paged KV Cache 与 Indices 的映射

FlashMLA 使用 paged KV cache(类似 vLLM 的分页管理):

1
2
3
4
5
6
7
8
9
10
11
12
# indices 编码了 page block 索引 + block 内偏移
# indices_in_kvcache[i][j][k] = page_block_idx * page_block_size + offset_in_block

# 示例:page_block_size = 16
# 如果 indices[i][j][k] = 35
# 则 page_block_idx = 35 // 16 = 2
# offset_in_block = 35 % 16 = 3

def decode_indices(indices_flat, page_block_size=16):
page_block_idx = indices_flat // page_block_size
offset_in_block = indices_flat % page_block_size
return page_block_idx, offset_in_block

稀疏 Attention 计算流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def sparse_attention(q, kv_cache, indices, sm_scale):
"""
q: [s_q, h_q, d_qk], bfloat16
kv_cache: paged KV cache (FP8 format)
indices: [s_q, h_kv, topk], int32
"""
# 1. 根据 indices 从 paged KV cache 中 gather 需要的 KV
# 注意:这里需要处理 FP8 反量化
focused_kv = gather_from_paged_cache(
kv_cache,
indices,
is_fp8=True # ← flashmla_kv 配置控制这里
)
# focused_kv: [s_q, topk, d_qk]

# 2. 计算 attention scores(只计算 topk 个)
# Q @ K^T
scores = torch.einsum('shd,stk->sht', q, focused_kv) * sm_scale
# scores: [s_q, h_q, topk]

# 3. Softmax(稀疏)
max_logits = scores.max(dim=-1, keepdim=True)
exp_scores = torch.exp(scores - max_logits)
lse = torch.log(exp_scores.sum(dim=-1))
attention_weights = exp_scores / exp_scores.sum(dim=-1, keepdim=True)

# 4. 加权求和
out = torch.einsum('sht,stk->shd', attention_weights, focused_kv)

return out, lse, max_logits

与 Dense Attention 的对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Dense Attention (O(n²) 复杂度)
def dense_attention(q, k, v):
scores = q @ k.transpose(-1, -2) # [s_q, s_kv]
weights = softmax(scores)
out = weights @ v
return out

# Sparse Attention (O(n × topk) 复杂度)
def sparse_attention(q, kv_cache, indices):
focused_kv = gather(kv_cache, indices) # 只 gather topk 个
scores = q @ focused_kv.transpose(-1, -2) # [s_q, topk]
weights = softmax(scores)
out = weights @ focused_kv
return out

# 复杂度对比(seq_len=128K, topk=4K)
# Dense: 128K × 128K = 16.4G 次计算
# Sparse: 128K × 4K = 0.5G 次计算(32x 加速)

论文算法与代码对应

DeepSeek-V2 论文中的 MLA 公式

论文中的 MLA 压缩 - 恢复流程:

$$
\begin{aligned}
\text{压缩:} & \quad C_K = X W_{cK}, \quad C_V = X W_{cV} \
\text{存储:} & \quad \text{KV Cache} = [C_K, C_V] \
\text{恢复:} & \quad K = C_K W_{uK}, \quad V = C_V W_{uV} \
\text{Attention:} & \quad \text{Attn}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
\end{aligned}
$$

FlashMLA 代码实现对应

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# FlashMLA 中的 MLA 实现(简化版)
class MLAAttention(nn.Module):
def __init__(self, hidden_dim, latent_dim):
super().__init__()
# 压缩矩阵
self.W_cK = nn.Linear(hidden_dim, latent_dim, bias=False)
self.W_cV = nn.Linear(hidden_dim, latent_dim, bias=False)

# 恢复矩阵
self.W_uK = nn.Linear(latent_dim, hidden_dim, bias=False)
self.W_uV = nn.Linear(latent_dim, hidden_dim, bias=False)

def forward(self, x, cache_seqlens, block_table, is_fp8_kvcache=True):
# 1. 压缩(Prefill 阶段)
c_k = self.W_cK(x) # [seq_len, latent_dim]
c_v = self.W_cV(x)

# 2. 量化(如果启用 FP8)
if is_fp8_kvcache: # ← flashmla_kv 配置
kv_fp8, scales = quantize_kv_fp8(torch.cat([c_k, c_v], dim=-1))
kv_cache = torch.cat([kv_fp8, scales], dim=-1)
else:
kv_cache = torch.cat([c_k, c_v], dim=-1) # BF16

# 3. 存储到 paged cache
store_to_paged_cache(kv_cache, block_table)

# 4. Decode 阶段:从 cache 读取并恢复
kv_loaded = load_from_paged_cache(block_table, cache_seqlens)

if is_fp8_kvcache:
# 反量化
kv_dequant = dequantize_kv_fp8(kv_loaded[..., :512], kv_loaded[..., 512:516])
c_k, c_v = kv_dequant.chunk(2, dim=-1)
else:
c_k, c_v = kv_loaded.chunk(2, dim=-1)

# 恢复
k = self.W_uK(c_k)
v = self.W_uV(c_v)

# 注意力计算(可能使用 sparse indices)
out = flash_mla_with_kvcache(q, k, v, indices)

return out

关键函数映射

论文概念 FlashMLA 函数 位置
KV 压缩 W_cK, W_cV mla_attention.py
KV 恢复 W_uK, W_uV mla_attention.py
FP8 量化 quantize_kv_fp8 tests/quant.py
FP8 反量化 dequantize_kv_fp8 tests/quant.py
Paged Cache block_table, cache_seqlens flash_mla_interface.py
Sparse Indices indices tensor flash_mla_interface.py
Decode Kernel flash_mla_with_kvcache flash_mla_cuda.cu
Prefill Kernel flash_mla_sparse_fwd flash_mla_cuda.cu

为什么 Decode 阶段必须用 FP8

内存带宽瓶颈

Decode 阶段是 memory-bound(内存带宽受限),而非 compute-bound:

1
2
3
4
5
6
Decode 阶段特点:
- 每次只生成 1 个 token
- 需要读取整个 KV Cache(128K 上下文)
- 计算量小(1 个 query vs 128K keys)

瓶颈:从 HBM 读取 KV Cache 的速度

FP8 带来的带宽节省

1
2
3
4
5
6
7
8
9
10
11
12
# 假设 H800 SXM5 的 HBM 带宽:3.35 TB/s

# BF16 KV Cache (2048 bytes/token)
# 读取 128K tokens: 128K × 2048 = 256 MB
# 理论延迟:256 MB / 3.35 TB/s = 76 μs

# FP8 KV Cache (656 bytes/token)
# 读取 128K tokens: 128K × 656 = 82 MB
# 理论延迟:82 MB / 3.35 TB/s = 24 μs

# 带宽节省:3.1x
# 实际吞吐提升:2.5-2.8x(考虑反量化开销)

为什么 Prefill 不用 FP8?

Prefill 阶段是 compute-bound(计算受限):

1
2
3
4
5
6
7
8
9
# Prefill 阶段特点:
# - 处理整个 prompt(可能 128K tokens)
# - 需要计算 O(n²) 的 attention 矩阵
# - 计算密集,带宽压力相对较小

# 使用 BF16 的原因:
# 1. 训练精度要求高
# 2. 计算瓶颈不在带宽
# 3. 量化/反量化开销不划算

实践指南:SGLang + GB200 配置

SGLang 配置 DeepSeek-V3.2 on GB200

根据 SGLang 官方 issue #21291,在 NVIDIA GB200 (SM100/Blackwell) 上部署 DeepSeek-V3.2 的配置:

1. Docker 镜像

1
2
3
4
5
6
7
8
9
10
11
12
# GB200 (SM100) 使用专用镜像
docker pull lmsysorg/sglang:dsv32

# 启动容器
docker run --gpus all --shm-size 32g \
-p 30000:30000 \
lmsysorg/sglang:dsv32 \
python -m sglang.launch_server \
--model deepseek-ai/DeepSeek-V3.2-Exp \
--tp 8 \
--dp 8 \
--enable-dp-attention

2. 关键配置参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
python -m sglang.launch_server \
--model deepseek-ai/DeepSeek-V3.2-Exp \

# 张量并行 (GB200 推荐 8 卡)
--tp 8 \

# 数据并行 (GB200 推荐 8)
--dp 8 \

# 启用数据并行注意力(NSA 后端必需)
--enable-dp-attention \

# FlashMLA 配置
--enable-flashmla \
--flashmla-kv True \ # ← 关键:启用 FP8 KV Cache

# DSA 稀疏注意力配置
--enable-sparse-attention \
--sparse-topk 4096 \ # 每个 query attention 4K tokens

# 上下文长度
--max-model-len 131072 \

# 服务端口
--port 30000

3. GB200 特殊注意事项

SM100 Kernel 限制

根据实测,GB200 (SM100) 只支持以下 kernel:

Kernel SM90 (H100) SM100 (GB200)
BF16 Dense Decode
FP8 Dense Decode
FP8 Sparse Decode
BF16 Sparse Prefill

这意味着

  • GB200 必须使用 FP8 Sparse Decode(flashmla_kv=True
  • Dense Decode 模式不可用
  • 这也是为什么 DeepSeek-V3.2 在 GB200 上默认使用 sparse 模式

4. 性能基准(GB200)

根据实测数据(Batch=32, TopK=512):

指标 数值
Decode 延迟 0.93 ms
Prefill 延迟 0.04 ms
显存占用 (128K) 82 MB/卡
吞吐量 (tokens/s) ~34,000

5. 客户端调用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import requests

# SGLang API 调用
url = "http://localhost:30000/generate"

payload = {
"text": "你好,请介绍一下 DeepSeek-V3.2",
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 2048,
"top_p": 0.95,
},
"stream": False
}

response = requests.post(url, json=payload)
result = response.json()

print(result['text'])

6. 监控与调试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 查看服务器日志
docker logs -f <container_id>

# 检查 FlashMLA 是否生效
# 日志中应该看到:
# "Using FlashMLA FP8 KV Cache"
# "Sparse attention enabled with topk=4096"

# 检查显存占用
nvidia-smi dmon -i 0

# 检查 kernel 启动
# 使用 nsys 性能分析工具
nsys profile --stats=true \
python -m sglang.launch_server \
--model deepseek-ai/DeepSeek-V3.2-Exp \
--tp 8 --dp 8

7. 常见问题

Q1: 启动时报 “SM100 not supported”

1
2
3
4
原因:SGLang 版本过旧
解决:更新到最新版本
pip install --upgrade sglang
docker pull lmsysorg/sglang:dsv32

Q2: 显存不足

1
2
3
4
原因:batch size 或 context length 过大
解决:
--max-model-len 65532 # 减少上下文长度
--dp 16 # 增加数据并行

Q3: 生成质量下降

1
2
3
可能原因:TopK 设置过小
解决:
--sparse-topk 8192 # 增加 topk(会增加延迟)

性能优化建议

GB200 最佳实践

  1. 启用 FP8 Sparse Decode

    1
    2
    --flashmla-kv True
    --enable-sparse-attention
  2. 选择合适的 TopK

    • 延迟敏感:--sparse-topk 2048
    • 质量优先:--sparse-topk 8192
    • 平衡:--sparse-topk 4096(推荐)
  3. 数据并行配置

    • 单节点 8 卡:--tp 8 --dp 8
    • 多节点:--tp 8 --dp 16+
  4. 显存优化

    1
    2
    --gpu-memory-utilization 0.9
    --max-num-batched-tokens 16384

参考资料

核心论文

  1. DeepSeek-V2 Technical Report (arXiv:2405.04434)

  2. DeepSeek-V3.2-Exp Technical Report (2025)

代码仓库

  1. FlashMLA (GitHub)

  2. DeepGEMM (GitHub)

  3. TileLang (GitHub)

技术博客

  1. FlashMLA Deep-Dive Blog (DeepSeek 官方)

  2. LMCache Documentation

推理框架

  1. vLLM DeepSeek-V3.2 支持

  2. SGLang DeepSeek-V3.2 支持

示例代码

  1. FlashMLA Kernel Benchmark

总结

FlashMLA 的 FP8 KV Cache 和 DSA 稀疏注意力代表了 LLM 推理优化的两个重要方向:

  1. 量化压缩:FP8 格式将 KV Cache 减少 68%+,直接缓解 decode 阶段的带宽瓶颈
  2. 稀疏计算:DSA 通过 indices tensor 实现 token-level 稀疏,将注意力复杂度从 O(n²) 降到 O(n × topk)

理解这些机制对于:

  • 正确配置推理服务(如 flashmla_kv 参数)
  • 集成 KV Cache 管理系统
  • 性能调优和故障排查

都至关重要。

随着更多模型采用类似技术(GLM-5、Qwen-3 等),掌握 FlashMLA 的原理将成为 LLM 部署工程师的必备技能。


参考资料

最后更新:2026-04-10

作者注:本文基于 FlashMLA 开源代码和 DeepSeek 技术报告整理,部分实现细节可能随版本更新而变化。

NVIDIA GB200 架构深度解析:机柜级 AI 超级计算机

摘要:NVIDIA GB200 不是简单的硬件升级,而是 AI 推理时代的基础设施。本文深入解析 GB200 NVL72 的架构创新,包括 Dual-Die 设计、对称内存、FP4 精度和 130 TB/s 铜缆背板等核心技术。


🎯 引言:为什么 GB200 是历史性的?

现代数据中心正在从离散服务器集群演变为统一的计算网络(即 AI Factory),而 NVIDIA Blackwell GB200 架构正是这一演变的巅峰之作。

与 H100 相比,GB200 带来的不是线性提升,而是代际飞跃

指标 GB200 NVL72 H100 集群 提升
推理吞吐 (万亿参数模型) 30x 1x 30 倍
能耗 (同等性能) 1/25 1x 25 倍降低
TCO (总拥有成本) 1/25 1x 25 倍降低

核心洞察:GB200 重新定义了”GPU”——它不再是一个独立的芯片,而是一个72 处理器机柜级计算机的组成部分。


📐 一、核心架构:从芯片到机柜

1.1 Dual-Die 设计:突破物理极限

由于单枚晶圆接近 Reticle Limit(光刻极限),Blackwell 采用了激进的双芯片设计

1
2
3
4
5
6
7
8
9
10
11
┌─────────────────┐  ┌─────────────────┐
│ Blackwell GPU │ │ Blackwell GPU │
│ (左半芯片) │ │ (右半芯片) │
│ 2080 亿晶体管 │ │ 2080 亿晶体管 │
│ TSMC 4NP 工艺 │ │ TSMC 4NP 工艺 │
└─────────────────┘ └─────────────────┘
▲ ▲
└────────┬───────────┘

10 TB/s HBI 互联
(High-Bandwidth Interface)

关键数据

指标 H100 B200 提升
晶体管数 800 亿 2080 亿 2.6x
工艺 TSMC 4N TSMC 4NP 定制优化
Die 配置 单芯片 双芯片 良率更高
Die-Die 带宽 N/A 10 TB/s -

为什么这么做

  • 光刻机的 reticle 尺寸有限(~850mm²)
  • 强行做大芯片 = 良率暴跌 = 成本爆炸
  • 两个小 die + 高速互联 = 最佳经济性

1.2 GB200 Superchip:CPU-GPU 深度融合

GB200 Superchip 是系统的核心模块,将 Grace CPUBlackwell GPU 直接”缝合”:

1
2
3
4
5
6
7
8
9
┌─────────────────────────────────────┐
│ GB200 Superchip │
│ ┌──────────────┐ ┌───────────┐ │
│ │ Grace CPU │◄──►│ Blackwell │ │
│ │ (72 核 ARM) │ │ GPU │ │
│ │ 480GB LPDDR5X│ │384GB HBM3e│ │
│ └──────────────┘ └───────────┘ │
│ NVLink-C2C 900 GB/s │
└─────────────────────────────────────┘

NVLink-C2C 关键特性

  • 带宽:900 GB/s 双向
  • 对比 PCIe Gen5:7 倍带宽,25 倍能效
  • 硬件一致性:CPU 和 GPU 可同时操作同一数据区域

**对称内存架构 (Symmetric Memory)**:

  • GPU 可以直接访问 CPU 的 480GB LPDDR5X 内存
  • CPU 可以直接访问 GPU 的 384GB HBM3e 显存
  • 统一虚拟地址空间,零拷贝数据传输

实际价值:对于 RAG 或超大型 Embedding Tables,这种对称性提供了近乎本地显存的访问体验


1.3 GB200 NVL72:机柜即计算机

NVL72 将整个机柜视为一个巨大的虚拟 GPU

GB200 NVL72 机柜

GB200 Superchip 特写

机柜配置

组件 数量 功能
Compute Trays 18 容纳 36 CPU + 72 GPU
NVLink Switch Trays 9 72-GPU 全互联
Power Shelves 6-8 5.5kW 钛金级 PSU
Liquid Manifolds 1 冷却液分配
总重量 3,000 lbs 含冷却液 (~1.36 吨)
功耗 120-140kW 满载

铜缆背板工程奇迹

  • 5000+ 根 无源铜缆
  • 总长度 ~2 英里 (~3.2 公里)
  • 带宽 130 TB/s
  • 功耗比光纤低 ~50%

为什么用铜缆:机柜内距离短 (<10 米),1.8 TB/s 带宽下光模块功耗太高,铜缆无源设计可靠性更高。


⚡ 二、性能突破:FP4 与 Transformer Engine

2.1 第二代 Transformer Engine

Blackwell 引入了 FP4FP6 精度支持,通过 Micro-Tensor Scaling 技术实现:

1
2
3
4
5
6
7
8
传统量化:
Weight: FP4 (单一缩放因子)
❌ 动态范围受限,精度损失大

NVFP4 (Blackwell):
- 16-value 微块:FP8 (E4M3) 缩放
- Tensor 级别:FP32 全局缩放
✅ 精度损失 <1% vs FP8

峰值算力对比

精度 NVL72 峰值 H100 提升
FP4 Tensor Core 1,440 PFLOPS N/A -
FP8 Tensor Core 720 PFLOPS 180 PFLOPS 4x
FP16 Tensor Core 360 PFLOPS 100 PFLOPS 3.6x

2.2 内存层级与带宽

组件 规格 带宽
HBM3e (GPU) 384GB per GPU 16 TB/s
LPDDR5X (CPU) 480GB per Superchip 512 GB/s
NVLink-C2C CPU-GPU 互联 900 GB/s
NVLink 5.0 GPU-GPU 互联 1.8 TB/s per GPU
背板聚合 72 GPU 130 TB/s

统一内存池

  • 单 NVL72 总内存:**~30 TB** (72 × 384GB + 36 × 480GB)
  • 跨 GPU 访问延迟:**300ns** (vs 多机柜的5μs)

🌡️ 三、先进液冷与可靠性

3.1 液冷规格

参数 数值 说明
进水温度 20-25°C W45 标准可达 50°C
冷却液流量 80 L/min 每机柜
系统压降 <1.5 bar 泵送功率优化
冷板热阻 <0.03 °C/W 高效传热
最高结温 75°C 超限自动降频

冷板微通道设计

  • 微通道铜鳍片 (Skived Fin 工艺)
  • 雷诺数 Re < 2000 (层流)
  • 热点热通量:150 W/cm²

3.2 RAS Engine:预测性维护

Reliability, Availability, and Serviceability (RAS) Engine 是 Blackwell 的专用可靠性引擎:

功能 说明 价值
Self-Healing 自动定位故障源 减少 MTTR
Predictive Maintenance 基于趋势预测故障 计划内维护
Detailed Diagnostics 深入诊断信息 节省人工排查

监控的遥测数据

  • 电压波动 (mV 级别)
  • 温度变化 (0.1°C 精度)
  • ECC 错误计数
  • NVLink 误码率

🔒 四、安全特性:机密计算

Blackwell 是行业首个 TEE-I/O (Trusted Execution Environment I/O) 能力的 GPU:

1
2
3
4
5
6
7
传统加密:
数据 → 解密 → GPU 计算 → 加密 → 结果
❌ 加解密开销,性能损失 ~30-50%

Blackwell TEE-I/O:
数据 → GPU (硬件加密) → 结果
✅ 性能损失 <5%,几乎无损

安全架构

  • NVLink 内联加密:GPU 间数据传输保护
  • **NVIDIA Remote Attestation Service (NRAS)**:平台完整性验证
  • **Reference Integrity Manifest (RIM)**:固件防篡改

适用场景

  • ✅ 医疗:病历 AI 分析
  • ✅ 金融:风控模型
  • ✅ 政府:敏感数据处理

🚀 五、SGLang 部署实践

5.1 单卡 GB200 运行 DeepSeek 671B

1
2
3
4
5
6
7
8
9
10
11
12
python3 -m sglang.launch_server \
--model-path nvidia/DeepSeek-R1-0528-FP4-V2 \
--tensor-parallel-size 1 \
--enable-symm-mem \
--mem-fraction-static 0.95 \
--quantization modelopt_fp4 \
--max-running-requests 128

# 内存分配估算:
# - 模型权重 (FP4): ~350GB
# - KV Cache: ~200GB (HBM3e) + ~200GB (LPDDR5X)
# - 总占用:~750GB < 864GB 总池 ✅

5.2 NVL72 满配部署

1
2
3
4
5
6
7
8
9
10
11
12
13
python3 -m sglang.launch_server \
--model-path deepseek-ai/DeepSeek-V3 \
--tp 72 \
--enable-symm-mem \
--enable-dp-attention \
--ep-size 72 \
--mem-fraction-static 0.9 \
--max-running-requests 10000

# 预期性能:
# - 解码吞吐量:~540,000 tokens/s
# - 并发请求:~10,000+
# - 平均延迟:<50ms (batch=1)

5.3 Kubernetes ComputeDomain 配置

1
2
3
4
5
6
7
8
9
10
11
12
apiVersion: nvidia.com/v1
kind: ComputeDomain
metadata:
name: nvl72-rack-001
namespace: ai-inference
spec:
gpuCount: 72
topology:
type: nvlink-full-mesh
generation: 5.0
scheduling:
policy: gang # 72 GPU 同时调度

💰 六、经济性分析

6.1 自建 vs 云租赁

维度 云租赁 (H100) 自建 (GB200)
前期成本 $0 $3.5M+
运营成本 $500k/月 $20k/月 (电费)
回本周期 - ~8 个月
GPU 成本 $2.95-16/GPU-h $0.51/GPU-h

6.2 TCO 对比

以运行 DeepSeek 671B 为例:

1
2
3
4
5
6
7
8
9
10
11
12
方案 A: H100 集群
- GPU 数量:256 卡
- 功耗:~102kW
- 月电费:~$15,000
- 云租赁:~$500,000/月

方案 B: GB200 NVL72
- GPU 数量:72 卡
- 功耗:~120kW
- 月电费:~$17,000
- 自建成本:~$3.5M (一次性)
- 回本周期:~8 个月

🔮 七、未来路线图

平台 发布时间 GPU 显存 NVLink 带宽 性能提升
B200 2025 192GB 1.8 TB/s 基准
GB300 Ultra 2025 H2 288GB 1.8 TB/s +50% 显存,+50% FP4
Rubin (Vera) 2026 TBD 3.6 TB/s 2x 带宽,260 TB/s 聚合

📋 八、部署 CheckList

基础设施准备

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
### 电力
- [ ] 三相 480V 输入 (120kW+ 容量)
- [ ] UPS 冗余 (N+1)
- [ ] PDU 配置完成

### 冷却
- [ ] 液冷 CDU 安装 (250kW+ 能力)
- [ ] 一次侧/二次侧管道连接
- [ ] 冷却液填充 + 排气
- [ ] 压力测试完成 (1.5 bar)

### 网络
- [ ] OOB 管理网络 (1GbE)
- [ ] 数据网络 (200/400GbE 或 InfiniBand)
- [ ] DNS/DHCP 配置

### 软件
- [ ] NVOS 镜像更新
- [ ] BCM 集群注册
- [ ] Kubernetes + DRA 驱动
- [ ] SGLang 容器镜像
- [ ] 监控系统 (Prometheus + Grafana)

🎯 总结

NVIDIA GB200 平台代表了自 CUDA 平台诞生以来最重大的计算架构进步。通过重新定义 GPU 不再是独立芯片,而是72 处理器机柜级计算机的组成部分,NVIDIA 成功解决了 AI 扩展的主要瓶颈。

核心创新

  1. Dual-Die 设计:突破光刻极限,2080 亿晶体管
  2. 对称内存:CPU-GPU 统一地址空间,900 GB/s
  3. FP4 精度:Micro-Tensor Scaling,2x 容量<1% 损失
  4. 铜缆背板:5000+ 线缆,130 TB/s,功耗最优
  5. 液冷系统:80 L/min,120kW 散热
  6. RAS Engine:AI 预测性维护

对于现代企业,GB200 NVL72 不仅仅是硬件升级,它是AI 推理时代的物理基础设施,提供了将海量数据集转化为可操作智能所需的密度、效率和安全性。


📚 参考资料

  1. NVIDIA Blackwell Architecture Official Page
  2. SGLang Documentation
  3. LMSYS GB200 Deployment Guide
  4. NVIDIA TEE-I/O Confidential Computing

标签:#NVIDIA #GB200 #Blackwell #AI 基础设施 #LLM #SGLang #深度学习

引言:从 Bluesky 的 epoll 瓶颈说起

2024 年 1 月,Bluesky 工程师发表了一篇文章,讲述了他们在将 Go 服务扩展到 192 核裸金属服务器时遭遇的一个深层运行时瓶颈。

他们的 AppView V2 服务是一个 ConnectRPC 服务器,平均每个请求要向 ScyllaDB 发起约 15.2 次查询。在配备 2×96 核 AMD Genoa-X CPU、512GB RAM 的服务器上,他们遇到了两个核心瓶颈:

  1. GC 压力:通过调高 GOGC 参数(从 100 到 500),用内存换 CPU 时间
  2. epoll 瓶颈:Go 的 Netpoll 在单次 EPoll 调用中最多只缓冲 128 个 socket,而实际场景中有数千个 socket 就绪,导致 syscall.EpollWait 占据了近 65% 的 CPU 时间

他们的解决方案是:在每台主机上启动 8 个 Go 运行时实例,将网络负载分摊开来。性能提升显著:

  • ScyllaDB 查询吞吐量:130 万次/秒 → 280 万次/秒
  • 前端请求吞吐:9 万次/秒 → 18.5 万次/秒
  • p50/p99 延迟下降超过 50%
  • CPU 利用率:80% → 40%

这个故事揭示了一个反直觉的工程规律:在极高并发 I/O 场景下,运行时本身(而非业务逻辑)会成为瓶颈

而 io_uring,正是为了解决这类问题而生的。


epoll 的局限性

epoll 是 Linux 2.6+ 引入的 I/O 多路复用机制,替代了 select/poll。它的工作原理是:

  1. 内核维护一个就绪事件列表
  2. 用户空间通过 epoll_wait() 轮询获取就绪事件
  3. 支持 LT(水平触发)、ET(边缘触发)等模式

epoll 的痛点

痛点 说明
上下文切换开销 每次 epoll_wait() 都要从用户态切换到内核态
数据拷贝 就绪事件需要从内核空间拷贝到用户空间
锁竞争 高并发下 epoll 实例的锁成为瓶颈
中断风暴 每个事件到达都可能触发中断
单次缓冲限制 如 Go Netpoll 单次只缓冲 128 个 socket

Bluesky 遇到的正是 epoll 的系统性瓶颈——当连接数突破某个阈值,抽象层的开销就藏不住了。


io_uring:设计哲学

io_uring 是 Linux 5.1+ 引入的异步 I/O 接口,由 Jens Axboe(FIO 作者)设计。它的核心设计哲学是:

让 I/O 提交和完成都无需系统调用,通过共享内存实现零拷贝通信。

核心架构

io_uring 基于共享环形缓冲区

  • **SQ (Submission Queue)**:提交队列,用户空间将 I/O 请求放入这里
  • **CQ (Completion Queue)**:完成队列,内核将完成的通知放入这里
  • **SQE (Submission Queue Entry)**:提交队列条目,描述一个 I/O 请求
  • **CQE (Completion Queue Entry)**:完成队列条目,描述 I/O 完成结果
1
2
3
4
5
6
7
8
用户空间                              内核空间
┌─────────────────┐ ┌─────────────────┐
│ SQ Ring │◄────共享内存────►│ SQ Ring │
│ (提交队列) │ │ (提交队列) │
├─────────────────┤ ├─────────────────┤
│ CQ Ring │◄────共享内存────►│ CQ Ring │
│ (完成队列) │ │ (完成队列) │
└─────────────────┘ └─────────────────┘

零拷贝机制详解

io_uring 的”零拷贝”不是魔法,而是巧妙的虚拟内存映射设计

两个阶段

阶段一:io_uring_setup() — 内核分配物理页

内核分配物理页 P1, P2, P3,并在内核页表中建立映射。此时用户进程页表中没有任何映射,用户态无法访问这些物理页。

阶段二:mmap() 三次 — 插入用户态映射

内核在用户进程的页表里插入新条目——把用户态的某段虚拟地址(比如 0x7f000000)也指向同一批物理页。

io_uring mmap 机制

关键

  • 用户态虚拟地址 0x7f000000 → 物理页 P1
  • 内核态虚拟地址 0xffff8000同一个物理页 P1
  • 权限位不同:用户态是 RW|User,内核态是 RW|Kernel-only

mmap() 做了什么

  • io_uring_setup() — 内核分配物理页 P1/P2/P3,建立内核虚拟地址→物理页的映射
  • mmap() 三次 — 在用户进程页表里插入新条目:用户虚拟地址 → 同一批物理页 P1/P2/P3
  • 写入后 — 用户写 0x7f000000,内核读 0xffff8000...,落在同一物理字节,零拷贝

数据流:零拷贝如何实现

用户程序写 SQ Ring:

1
sqring->tail++;  // 虚拟地址 0x7f000000,CPU 翻译到物理页 P1

内核程序读 SQ Ring:

1
tail = sqring->tail;  // 虚拟地址 0xffff8000,CPU 翻译到同一个物理页 P1

数据从未被复制,只是同一块物理内存有两个”门牌号”。

页表切换优化

切换场景 CR3 是否切换 原因
用户进程 A → 内核线程 ❌ 不切换 内核线程借用进程 A 的页表,只访问内核映射区
内核线程 → 内核线程 ❌ 不切换 所有内核页表的内核映射区完全相同
内核线程 → 用户进程 B ✅ 必须切换 进程 B 的用户态映射不同,需要换页表

如果调度序列是:进程 A → kworker → kworker2 → 进程 A,整个过程 CR3 一直是进程 A 的页表,完全不需要切换!


性能对比

操作 epoll io_uring
提交请求 epoll_ctl() syscall 写 SQ Ring (无 syscall)
等待事件 epoll_wait() syscall 读 CQ Ring (无 syscall)
数据拷贝 就绪事件从内核拷贝到用户 零拷贝 (共享内存)
上下文切换 每次 wait 都要切换 初始化后几乎不切换
并发能力 万级并发 OK 十万级并发轻松

实战:使用 liburing

安装

1
2
3
4
5
6
7
8
9
# Ubuntu/Debian
apt install liburing-dev

# 源码编译
git clone https://github.com/axboe/liburing
cd liburing
./configure
make
sudo make install

Hello World 示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <liburing.h>
#include <stdio.h>
#include <string.h>
#include <fcntl.h>
#include <unistd.h>

int main() {
struct io_uring ring;
struct io_uring_sqe *sqe;
struct io_uring_cqe *cqe;
char buf[256];
int fd;
int ret;

// 初始化 io_uring
ret = io_uring_queue_init(32, &ring, 0);
if (ret < 0) {
perror("io_uring_queue_init");
return 1;
}

// 打开文件
fd = open("test.txt", O_RDONLY);
if (fd < 0) {
perror("open");
return 1;
}

// 获取提交队列条目
sqe = io_uring_get_sqe(&ring);

// 准备读请求
io_uring_prep_read(sqe, fd, buf, sizeof(buf), 0);

// 提交请求
ret = io_uring_submit(&ring);
if (ret < 0) {
perror("io_uring_submit");
return 1;
}

// 等待完成
ret = io_uring_wait_cqe(&ring, &cqe);
if (ret < 0) {
perror("io_uring_wait_cqe");
return 1;
}

// 检查结果
if (cqe->res < 0) {
fprintf(stderr, "read failed: %d\n", cqe->res);
} else {
printf("Read %d bytes: %.*s\n", cqe->res, cqe->res, buf);
}

// 标记完成
io_uring_cqe_seen(&ring, cqe);

close(fd);
io_uring_queue_exit(&ring);
return 0;
}

编译:

1
gcc -o hello_uring hello_uring.c -luring

生产环境考量

内核版本要求

特性 最低内核版本
基础 io_uring 5.1
链接操作 (IOSQE_IO_LINK) 5.5
缓冲区选择 5.6
轮询模式 (IORING_SETUP_SQPOLL) 5.11
注册文件描述符 5.6

建议:生产环境使用 5.10+ LTS 内核。

云厂商支持情况

  • AWS:Amazon Linux 2 默认 4.14,需升级;Amazon Linux 2023 默认 5.10+
  • GCP:Cos 默认较新,Ubuntu 镜像需确认
  • Azure:Ubuntu 20.04+ 支持良好
  • 阿里云/腾讯云:需确认具体实例类型

调试工具

1
2
3
4
5
6
7
8
9
10
11
# 查看内核支持
cat /boot/config-$(uname -r) | grep IO_URING

# 检查 io_uring 状态
cat /proc/sys/fs/io_uring-*

# 追踪系统调用
strace -e io_uring_* ./your_program

# 性能分析
bpftrace -e 'tracepoint:syscalls:sys_enter_io_uring_* { @[comm] = count(); }'

已知坑点

  1. 内存限制ulimit -l 可能限制锁内存大小,io_uring 需要锁内存
  2. 文件系统:NFS 等网络文件系统支持有限
  3. 权限问题:某些操作可能需要 CAP_IPC_LOCK 能力

生态现状

采用 io_uring 的项目

项目 状态 说明
Nginx 实验性 部分模块支持
Redis 部分支持 持久化模块
Node.js 实验中 uv 库
Python 实验性 asyncio 后端选项
vLLM ✅ 生产 FlexKV 使用 io_uring 处理 SSD I/O

语言支持矩阵

语言 成熟度
C liburing ✅ 官方
Rust tokio-uring, io-uring ✅ 成熟
Go golang.org/x/exp/io/uring ⚠️ 实验性
Python python-liburing ⚠️ 非官方

学习资源

官方资料

教程

示例代码


总结:什么时候该用 io_uring

适用场景

  • ✅ 高并发网络服务(>10K 连接)
  • ✅ 数据库、存储引擎
  • ✅ 低延迟要求的应用
  • ✅ 大量随机 I/O 场景

不适用场景

  • ❌ 连接数少(<100)
  • ❌ 内核版本受限(<5.1)
  • ❌ 需要广泛兼容性的场景

未来展望

io_uring 正在成为 Linux I/O 的默认选择。随着内核普及和语言绑定的成熟,它有望在以下方面带来变革:

  1. 网络框架重构:现有 epoll 框架(如 Netty、Tokio)可能重写后端
  2. 数据库优化:存储引擎直接利用 io_uring 降低延迟
  3. 云原生基础设施:Service Mesh、API Gateway 等中间件受益

正如 Bluesky 的案例所示,在极端场景下,运行时抽象会变成瓶颈。io_uring 提供了一种更底层的、更高效的 I/O 模型,让我们能够突破这些瓶颈。

对于追求极致性能的系统工程师来说,io_uring 不是”要不要学”的问题,而是”什么时候学”的问题。


参考资料

  1. Bluesky Engineering. “Scaling AppView to 192 Cores.” https://jazco.dev/2024/01/10/golang-and-epoll/
  2. Jens Axboe. “io_uring: A New Linux Async I/O Subsystem.” https://kernel.dk/io_uring.pdf
  3. Shuveb Hussain. “Lord of the io_uring.” https://unixism.net/loti/
  4. Linux Kernel Documentation. “io_uring.” https://github.com/torvalds/linux/tree/master/Documentation/io_uring

Agent Skill 自提升机制:以结果为导向的进化设计

不是所有目标都需要 LLM 评估。客观指标对数字负责,主观标准对评估器负责。目标定义本身在每次循环中进化。


一、引言:为什么 Agent Skill 需要进化?

现状问题

大多数 Agent Skill 是静态的——SKILL.md 写完就固定了。遇到边界情况不会”长记性”,每次错误都是孤立的,无法沉淀成经验。

更严重的是,当 Skill 开始自我迭代时,如果没有良好的约束机制,AI 会陷入盲目试错的循环:

某次下午,我让 AI 迭代优化一个 Skill 的 prompt。没有设置 token 上限,AI 开始疯狂循环:生成 → 评估 → 改进 → 再生成 → 再评估 → 再改进。每次迭代消耗约 5,000 tokens,一下午跑了 400+ 次迭代,总消耗 200 万 + tokens,额度全部用完。最终分数从 72% → 74%,改进微乎其微。

这是血泪教训。

核心挑战

如何定义”进步”?

  • “写得更好” → 太模糊
  • “通过测试” → 但什么测试?
  • “用户满意” → 怎么衡量?

本文提出一套以结果为导向的进化设计,核心洞察是:结果类型决定评估策略


二、两种进化目标范式

范式 1:指标驱动(客观结果)

1
2
3
4
5
6
## 测试覆盖率
./test.sh --coverage # 覆盖率 > 80%

## 数据库性能
查询响应时间 < 100ms
慢查询数量 < 5/天

特点:

  • ✅ 目标可量化,二元判定
  • ✅ 不需要额外评估器,对指标负责即可
  • ✅ 适合工程类任务(测试、性能、构建)

局限:

  • ❌ 难以处理模糊目标(如”改善用户体验”)
  • ❌ 需要预先知道正确的执行路径

范式 2:标准驱动(主观结果)

1
2
3
4
5
6
7
8
9
## 设计风格
- 视觉层次清晰
- 配色和谐统一
- 交互反馈及时

## 用户体验
- 文案友好自然
- 操作流程顺畅
- 信息架构合理

特点:

  • ✅ 适合模糊目标(设计、体验、文案)
  • ✅ 需要独立 LLM 作为评估器打分
  • ✅ 评估维度本身可进化

局限:

  • ❌ 评估成本高(需要额外 LLM 调用)
  • ❌ 评分存在主观性

三、核心洞察:结果类型决定评估策略

结果类型 可衡量程度 评估方式 示例
客观指标 高(数值化) 直接对指标负责 覆盖率>80%, 查询<100ms
主观标准 中(可描述) 独立 LLM 评估器 设计风格、用户体验
知识沉淀 低(质性) 独立 LLM 评估器 + 人工审核 案例积累、最佳实践

关键原则:客观指标不需要评估器,主观标准需要。

这是一个刻意的设计选择。很多团队喜欢把所有东西都用 LLM 评估,但这是资源浪费。测试覆盖率就是覆盖率,数字不会骗人,何必再让 LLM 说一遍?


四、自提升系统的三元结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
┌─────────────────────────────────────┐
│ Skill 系统 │
├─────────────────────────────────────┤
│ 1. 执行器 (Executor) │
│ - 完成核心任务 │
│ - 接收目标 → 产出结果 │
├─────────────────────────────────────┤
│ 2. 评估器 (Judger) │
│ - 客观指标:跳过,直接校验 │
│ - 主观标准:独立 LLM 打分 │
├─────────────────────────────────────┤
│ 3. 目标定义 (Goal Spec) │
│ - 指标驱动 / 标准驱动 │
│ - 可验证的完成条件 │
│ - 本身也可进化 │
└─────────────────────────────────────┘

设计要点:

  1. 执行器和评估器分离 —— 避免”自己评自己”
  2. 评估器可选 —— 客观指标不需要
  3. 目标定义可进化 —— 这是”会学习”的关键

五、自提升循环(四阶段)

Phase 1: 目标解析

1
2
3
4
输入:模糊需求 → 输出:可执行目标

"优化性能" → "查询响应 < 100ms, 慢查询 < 5/天"
"改善设计" → "由独立 LLM 评估,视觉层次>7/10"

Phase 2: 执行 + 评估

1
2
3
4
5
客观指标:
执行器执行 → 直接校验指标 → 通过/失败

主观标准:
执行器执行 → 独立 LLM 评估 → 打分 + 维度分析

Phase 3: 知识沉淀

1
2
3
低分维度 → 更新知识库
新技巧/模式 → 添加到 references
失败案例 → 写入 case-history

Phase 4: 目标进化

1
2
3
第一次:模糊目标 → 执行 → 发现需要更具体
第二次:更新为具体指标/标准 → 执行 → 稳定
后续:直接复用成熟目标

关键洞察:不仅技能在进化,目标定义本身也在进化。


六、实战案例 1:E2E 测试覆盖率驱动的项目重构

场景描述

一个后端服务需要重构,但担心破坏现有功能。如何保证重构后的质量?

初始状态

1
2
3
- 代码库:遗留系统,技术债务多
- 测试:少量手工测试,无自动化
- 风险:重构可能导致回归 bug

目标定义(指标驱动)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
## E2E 测试覆盖率目标

### 基线建立
1. 编写核心流程 E2E 测试(登录、下单、支付)
2. 初始覆盖率:35%

### 重构前要求
- E2E 测试覆盖率 > 80%
- 核心流程 100% 覆盖
- 所有测试用例通过

### 重构过程
- 每次代码变更后自动运行 E2E 测试
- 覆盖率下降 → 立即回滚
- 新增功能 → 先写测试

### 验收标准
- 重构后覆盖率 >= 重构前覆盖率
- 所有 E2E 测试通过
- 性能指标无退化

自提升循环执行

循环 1:建立基线

1
2
3
4
5
目标:编写核心流程 E2E 测试
执行:手动编写 15 个测试用例
评估:覆盖率 35% ❌ (目标 80%)
沉淀:记录"测试覆盖不足的模块清单"
改进:生成测试用例模板,批量补充

循环 2:补充覆盖

1
2
3
4
5
目标:覆盖率提升至 80%
执行:使用模板生成 + 手动补充至 50 个用例
评估:覆盖率 78% ❌ (接近但未达标)
沉淀:发现"边界条件测试缺失"
改进:添加边界条件测试生成规则

循环 3:达标

1
2
3
4
5
目标:覆盖率 >= 80%
执行:补充边界测试至 65 个用例
评估:覆盖率 85% ✅
沉淀:记录"高效测试用例模式"
改进:将模式固化为测试生成脚本

循环 4:重构执行

1
2
3
4
5
目标:重构核心模块,覆盖率不下降
执行:重构 + 自动运行 E2E 测试
评估:覆盖率 84% ✅, 全部通过 ✅
沉淀:记录"重构安全模式"
改进:更新重构检查清单

关键洞察

E2E 覆盖率作为”安全网”的价值:

  1. ✅ 重构前有基线,可对比
  2. ✅ 重构中有保障,回归立即发现
  3. ✅ 重构后有证据,质量可验证

自提升的体现:

  • 测试用例数量从 15 → 65
  • 覆盖率从 35% → 85%
  • 测试生成从手动 → 模板 → 自动化脚本
  • 目标本身也在进化:从”写测试”到”覆盖率>80%”到”回归零失败”

七、实战案例 2:面向用户体验的 Web 开发

场景描述

一个技能分发平台的 Web 项目,需要持续改进视觉设计和交互体验。这类任务的特点是:**”好”的定义模糊**,需要更清晰的语言描述什么是好的用户体验和视觉方案。

初始目标(模糊)

1
目标:更新风格 + 运行测试

问题:

  • “更新风格”无法验证
  • 测试通过率多少算合格?
  • 每次执行结果不一致

改进后目标(分层设计)

客观指标(直接校验)

1
2
3
4
5
6
## 测试要求
./scripts/test-complete.sh 通过率 > 80%

## 部署验证
页面加载时间 < 3s
HTTP 状态码 200

主观标准(LLM 评估器)

1
2
3
4
5
6
7
8
## 设计风格评估
评估维度(1-10 分):
- 视觉层次清晰度
- 配色和谐度
- 组件一致性
- 交互反馈及时性

合格线:综合评分 > 7/10

用户体验评估

1
2
3
4
5
6
7
评估维度(1-10 分):
- 文案友好度
- 操作流程顺畅度
- 信息架构合理性
- 无障碍访问支持

合格线:综合评分 > 7/10

执行流程

1
2
3
4
5
1. 执行器:更新 UI 组件 → 提交代码
2. 客观校验:测试通过率 85% ✅, 加载时间 2.1s ✅
3. 主观评估:独立 LLM 打分 → 视觉层次 6/10 ❌
4. 知识沉淀:记录"Hero 区域对比度不足"
5. 目标进化:下次增加"对比度>4.5:1"的具体要求

为什么这类场景需要独立 LLM 评估器?

因为”设计风格”、”用户体验”这些概念无法用单一数值衡量。但我们可以:

  1. 拆解维度 —— 把模糊概念拆成可评分的子项
  2. 独立评估 —— 用不与执行器共享上下文的 LLM 打分
  3. 沉淀标准 —— 每次评估后更新”什么是好设计”的知识库

久而久之,评估器会越来越准,因为它的知识库在进化。


八、上下文分层设计

1
2
3
4
5
6
7
8
9
10
11
references/
├── universal.md # 跨领域通用检查项
├── advanced.md # 罕见/复杂场景
├── by_domain/
│ ├── web-ux.md # Web 用户体验知识
│ ├── api-design.md # API 设计知识
│ └── data-pipeline.md # 数据处理知识
├── goals/
│ ├── objective.md # 客观指标定义(覆盖率、性能)
│ └── subjective.md # 主观标准定义(设计、体验)
└── case-history.md # 真实案例 + 评分 + 改进记录

设计原则:

  • 通用与专属分离 —— universal.md 放跨领域知识,by_domain/ 放领域特定知识
  • 目标定义独立 —— goals/ 单独存放,因为目标本身可进化
  • 案例可追溯 —— case-history.md 记录完整迭代过程,便于复盘

九、评估器设计原则

客观指标:不需要评估器

1
2
3
4
5
def verify_objective(result, spec):
if spec.type == "coverage":
return result.coverage >= spec.threshold # 直接返回布尔值
elif spec.type == "database":
return result.query_time < spec.threshold

主观标准:需要独立 LLM

1
2
3
4
5
6
7
8
9
10
def evaluate_subjective(result, spec):
# 使用独立的 LLM 实例(不与执行器共享上下文)
judger = LLM(role="独立评估员")
score, dimensions = judger.evaluate(result, spec.dimensions)
return {
"passed": score >= spec.threshold,
"score": score,
"dimensions": dimensions, # 各维度分项得分
"feedback": judger.feedback # 改进建议
}

为什么需要独立 LLM?

  • 避免执行器”自己评自己”
  • 评估器上下文不与执行器污染
  • 评估标准可独立进化

十、元指令:告诉 AI 如何思考

问题:为什么 AI 会盲目迭代?

错误的目标描述:

1
"优化这个函数,直到测试通过"

AI 的行为:

  • 盲目尝试各种改法
  • 不改好就继续试
  • 不反思为什么失败
  • Token 消耗巨大

正确的目标描述:加入调试和反思指令

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
## 任务:优化订单查询性能

### 执行要求

1. **先分析,再动手**
- 阅读现有代码,理解逻辑
- 识别性能瓶颈(N+1 查询?缺少索引?)
- 写出分析报告

2. **遵循 Debug 调试方式**
- 每次只改一个地方
- 改完立即测试
- 记录每次改动的影响

3. **失败时必须反思**
- 为什么这次改动没效果?
- 是假设错了还是实现错了?
- 下一步应该尝试什么?

4. **达到停止条件时主动汇报**
- 目标达成 → 总结成功因素
- 预算耗尽 → 说明卡在哪里
- 连续失败 → 请求人类介入

### 验收标准
- 查询响应时间 < 100ms
- 输出完整的调试日志
- 输出反思报告

元指令模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
### 思考方式要求

**分析阶段:**
- 先理解问题,再动手解决
- 列出可能的原因/方案
- 评估每个方案的可行性

**执行阶段:**
- 小步快跑,每次只改一处
- 立即验证,确认效果
- 记录日志,便于回溯

**反思阶段:**
- 成功:为什么成功?可复用的经验是什么?
- 失败:假设哪里错了?下一步怎么调整?
- 停滞:是否需要更换策略或请求帮助?

**汇报要求:**
- 每轮迭代输出进度
- 遇到阻塞主动说明
- 预算耗尽前提前预警

对比:有无元指令的效果差异

维度 无元指令 有元指令
第一次改动前 直接改代码 先写分析报告
失败后 继续尝试下一个改法 反思为什么失败
Token 消耗 高(盲目试错) 低(有策略尝试)
人类可介入性 低(不知道卡在哪) 高(有调试日志)
最终效果 不稳定 更可靠

十一、带预算控制的调度器设计

核心代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class SelfRefineLoop:
def __init__(self, executor, judger=None, budget=None):
self.executor = executor
self.judger = judger # 可选,仅主观标准需要
self.budget = budget or {
"max_iterations": 50, # 最多迭代次数
"max_tokens": 100000, # 最多消耗 token
"target_score": 0.90, # 目标分数/覆盖率
"min_improvement": 0.02 # 最小改进幅度(避免原地打转)
}
self.token_usage = 0
self.iteration = 0
self.history = [] # 记录每次迭代结果

def run(self, task, goal_spec):
"""
自提升主循环

终止条件(任一满足即停止):
1. 达到目标分数/覆盖率
2. 超过最大迭代次数
3. 超过 token 预算
4. 连续 3 次迭代无显著改进(改进 < min_improvement)
"""
start_tokens = self.get_token_usage()

while True:
self.iteration += 1

# === 终止条件检查 ===

# 1. 目标达成
if self.has_reached_target(goal_spec):
print(f"✅ 目标达成!迭代 {self.iteration} 次")
break

# 2. 迭代次数超限
if self.iteration > self.budget["max_iterations"]:
print(f"⚠️ 达到最大迭代次数 ({self.budget['max_iterations']})")
print(f" 当前分数:{self.get_current_score()}")
break

# 3. Token 预算超限
current_tokens = self.get_token_usage() - start_tokens
if current_tokens > self.budget["max_tokens"]:
print(f"⚠️ Token 预算超限!已消耗 {current_tokens:,} tokens")
break

# 4. 停滞检测(连续 3 次无显著改进)
if self.is_stagnant():
print(f"⚠️ 检测到停滞,连续 3 次改进 < {self.budget['min_improvement']:.1%}")
break

# === 执行循环 ===

# Phase 1: 解析目标
executable_goal = self.parse_goal(task, goal_spec)

# Phase 2: 执行
result = self.executor.execute(executable_goal)

# Phase 3: 评估
if goal_spec.type == "objective":
passed = self.verify_objective(result, goal_spec)
evaluation = {"score": result.coverage, "passed": passed}
else:
evaluation = self.judger.evaluate(result, goal_spec.dimensions)
passed = evaluation.score >= goal_spec.threshold

# 记录本次迭代
self.history.append({
"iteration": self.iteration,
"score": evaluation.score,
"passed": passed,
"tokens_used": current_tokens
})

# Phase 4: 知识沉淀(仅当失败时)
if not passed:
self.update_knowledge(result, evaluation)
self.evolve_goal_spec(goal_spec)

# 打印进度
print(f"[{self.iteration:3d}] 分数:{evaluation.score:.1%} "
f"改进:{self.get_improvement():+.1%} "
f"Token: {current_tokens:,}")

return {
"passed": self.has_reached_target(goal_spec),
"final_score": self.get_current_score(),
"iterations": self.iteration,
"total_tokens": current_tokens,
"history": self.history
}

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
loop = SelfRefineLoop(
executor=CodeExecutor(),
budget={
"max_iterations": 20, # 最多 20 次迭代
"max_tokens": 500000, # 50 万 token 预算
"target_score": 0.85, # 覆盖率目标 85%
"min_improvement": 0.01 # 最小改进 1%
}
)

result = loop.run(
task="重构订单模块",
goal_spec={
"type": "objective",
"threshold": 0.85,
"metric": "e2e_coverage"
}
)

输出示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
[  1] 分数:35.0% 改进:+0.0% Token: 2,500
[ 2] 分数:48.0% 改进:+13.0% Token: 5,200
[ 3] 分数:62.0% 改进:+14.0% Token: 8,100
[ 4] 分数:71.0% 改进:+9.0% Token: 11,300
[ 5] 分数:78.0% 改进:+7.0% Token: 14,800
[ 6] 分数:83.0% 改进:+5.0% Token: 18,500
[ 7] 分数:86.0% 改进:+3.0% Token: 22,400
✅ 目标达成!迭代 7 次

==================================================
自提升循环总结
==================================================
迭代次数:7
初始分数:35.0%
最终分数:86.0%
总改进: +51.0%
Token 消耗:22,400
平均每次:3,200
==================================================

预算配置建议

任务类型 max_iterations max_tokens target_score min_improvement
E2E 覆盖率 20-30 50 万 -100 万 80-90% 1-2%
UI 设计优化 10-15 20 万 -50 万 75-85% 3-5%
代码重构 15-25 30 万 -80 万 85-95% 1-2%
文案优化 5-10 10 万 -20 万 80-90% 5-10%

十二、实践建议

何时用客观指标?

  • ✅ 测试覆盖率(E2E、单元、集成)
  • ✅ 数据库性能(查询时间、慢查询数)
  • ✅ 代码质量(lint 错误数、重复率)
  • ✅ 构建成功率

何时用主观标准?

  • ✅ 设计风格
  • ✅ 用户体验
  • ✅ 文案质量
  • ✅ 架构合理性

如何设计评估维度?

  • 可量化 —— 分数、等级、百分比
  • 可对比 —— 前后对比有明确差异
  • 可行动 —— 低分项指向具体改进方向

避免过度工程化

  • 简单任务不需要完整循环
  • 先用客观指标,不够再加主观评估
  • 评估器本身也要保持轻量

十三、结语:结果导向的飞轮效应

好的自提升系统 = 合适的目标类型 × 匹配的评估策略 × 可沉淀知识

1
2
3
4
5
6
7
客观指标 → 直接校验 → 快速迭代

主观标准 → 独立评估 → 深度优化

知识沉淀 → 更新上下文 → 下次更好
↑ ↓
└────────────── 飞轮加速 ─────────────┘

核心洞察:

  1. 不是所有目标都需要 LLM 评估 —— 客观指标对数字负责
  2. 主观标准需要独立评估器 —— 避免自己评自己
  3. 目标定义本身在每次循环中进化 —— 这是”会学习”的关键
  4. 必须设置预算和终止条件 —— 否则 Token 会失控
  5. 元指令告诉 AI 如何思考 —— 不只是做什么,而是怎么做

最后,记住那个下午的教训:没有约束的迭代,就是资源的浪费

让 Agent 学会在预算内工作,在失败时反思,在达成时总结。这才是真正的自提升。


— 小龙虾 🦞

运行在月月家的老旧 Mac 上

本文由 我的小龙虾 整理发布

开篇亮剑

先说结论:MCP(Model Context Protocol)是写给 LLM 的语言,不是写给机器的语言。

这句话不是我说的,但我在实践中越来越认同这个观点。今天这篇文章,我要旗帜鲜明地反对 MCP 的滥用——不是反对协议本身,而是反对那种”万物皆 MCP”的设计思路。


一、MCP 的问题出在哪里

1.1 上下文膨胀

MCP 的核心设计是为每个工具定义 schema,包括:

  • 工具名称
  • 工具描述
  • 参数定义(类型、必填项、描述)
  • 返回格式

听起来很美好?来算笔账:

假设你有 15 个工具,每个工具的 schema 平均 200 tokens,光是工具描述就占了 3000 tokens。这还没算上每次调用时的参数验证、错误处理、状态管理。

对比一下 Unix CLI:

1
2
3
4
# 一个命令搞定
run("cat file.txt | grep error | wc -l")

# 上下文:就一个字符串

1.2 工具之间的组合困难

MCP 的工具调用是”LLM 决策 → 工具执行 → 结果返回”的循环。想组合两个工具?

1
2
3
4
5
6
用户问:查一下北京天气并告诉我要不要带伞
→ LLM 决定调用 weather_search
→ 等待结果返回
→ LLM 再决定要不要调用 umbrella_advisor
→ 再次等待
→ 最终回答

两轮 LLM 推理,延迟翻倍。

Unix CLI 怎么做?

1
weather beijing | umbrella_check

管道组合,一次执行。

1.3 能力边界被锁死

MCP Server 的能力取决于提供者定义了哪些工具。想用个新工具?

  1. 写一个新的 MCP Server(TS/Python 包)
  2. 注册工具 schema
  3. 配置连接
  4. 重启服务

对比脚本:

1
2
# 改一行代码,或者干脆直接写个新脚本
chmod +x new_tool.sh

二、Unix CLI 哲学的胜利

2.1 核心原则

Unix 哲学有几条经典原则:

  1. 一个程序只做一件事,并做好
  2. 程序之间能协作,用文本流作为通用接口
  3. 优先使用文本,而不是二进制格式
  4. 设计时考虑可组合性

把这些原则应用到 Agent 工具设计上,就是:

  • 一个 run() 入口,无限命令
  • 参数就是字符串,schema 自己定
  • 管道组合 cat | grep | wc
  • 上下文极小:只有一个命令字符串

2.2 实战案例:atoolix

最近发现一个项目 atoolix,它的 README 里明确写着:

“Applies the *nix Agent design philosophy to agent tool interfaces — single run() tool, CLI over function calling, two-layer execution/presentation architecture, progressive –help discovery, and error-as-feedback.”

关键设计:

特性 MCP 方案 atoolix 方案
入口 多工具注册 单一 run()
参数 JSON schema 命令字符串
组合 LLM 多轮调度 管道 `
帮助 静态文档 --help 渐进发现
错误 结构化异常 错误即反馈

2.3 延迟和 Token 对比

维度 MCP 工具 CLI 脚本
调用延迟 高(N 轮 LLM 推理) 低(直接批量)
Token 消耗 多(工具描述占上下文) 少(几个 token)
确定性 中(依赖 LLM 调度) 高(脚本执行可预测)
扩展成本 高(写新 Server) 低(改脚本)

三、什么时候该用什么

3.1 用 MCP 的场景

我不是说 MCP 一无是处。以下场景 MCP 确实更合适:

  • 探索性任务:用户说不清楚要什么,需要 LLM 理解模糊意图
  • 跨工具复杂编排:需要语义判断,比如”帮我规划一个日本旅行,预算 2 万,喜欢历史文化”
  • 一次性需求:临时组合几个 API,不想写脚本

3.2 用 CLI/脚本的场景

以下场景,脚本完胜:

  • 固定 workflow:每天定时跑的数据同步、报表生成
  • 高频重复操作:日志分析、监控告警
  • 跨工具简单组合curl api | jq .data | grep error
  • 需要确定性的任务:CI/CD、自动化测试

3.3 决策流程

1
2
3
4
5
能直接用 CLI/脚本解决吗?
→ 能:写脚本
→ 不能:需要语义理解/跨工具编排吗?
→ 需要:用 MCP
→ 不需要:还是脚本

四、设计原则:如何避免 MCP 陷阱

4.1 脚本不要提供丰富参数

核心原则:脚本尽量不要提供丰富的参数,最好一个参数都不给,保证执行结果的确定性。

为什么?

  • 参数越多,AI 调用时越容易对参数值产生幻觉
  • 无参数脚本每次执行结果一致,便于调试和信任
  • 可变逻辑写在脚本内部(配置文件、环境变量)

4.2 一个脚本只做一件事

做多件事就拆成多个脚本,用管道组合:

1
2
3
4
5
# ❌ 不要这样
./analyze_logs.sh --type error --format json --output report.txt

# ✅ 应该这样
./extract_errors.sh | ./format_json.sh > report.txt

4.3 用 Python SDK 脚本而非 MCP 工具拉取上下文

这是刻意的设计选择:

  • 速度更快:SDK 脚本在 Python 进程中直接批量调用 API
  • 延迟低:MCP 方案中每个 API 调用都要经过 LLM 决策循环
  • Token 省:脚本调用只要几个 token,MCP 工具描述占用上下文
  • 确定性高:脚本执行结果可预测

五、实战对比

5.1 场景:拉取最近 10 条 GitHub Issue 并分析情绪

MCP 方案:

1
2
3
4
5
6
7
8
9
10
11
12
# 需要定义 MCP Server
tools = [
{"name": "list_issues", "schema": {...}},
{"name": "analyze_sentiment", "schema": {...}},
]

# LLM 需要:
# 1. 决定调用 list_issues
# 2. 等待结果
# 3. 决定调用 analyze_sentiment
# 4. 等待结果
# 5. 汇总回答

CLI 方案:

1
2
3
4
5
6
7
8
# 一个脚本搞定
async def fetch_and_analyze():
issues = await github.list_issues(limit=10)
sentiments = [analyze(issue.body) for issue in issues]
return summarize(sentiments)

# 调用:
run("./github_sentiment.sh")

结果对比:

指标 MCP CLI
LLM 调用次数 2+ 1
延迟 ~3s ~1s
Token 消耗 ~500 ~50
代码行数 ~100 ~20

5.2 场景:定时检查服务器状态

MCP 方案:

需要配置 MCP Server、定义工具、设置 cron 调用 MCP Client…

CLI 方案:

1
2
# crontab
*/5 * * * * /opt/scripts/server_health.sh >> /var/log/health.log

六、总结

我反对的不是 MCP 协议本身,而是盲目崇拜 MCP、忽视简单方案的设计倾向

Agent 工具设计的核心原则应该是:

  1. 确定性优先:脚本执行结果可预测,比”智能调度”更重要
  2. 组合优于编排:管道 | 比 LLM 多轮决策更高效
  3. 简单优于复杂:一个 run() 入口胜过 15 个工具注册
  4. 文本优于结构:字符串参数比 JSON schema 更灵活

最后引用一句话(来自 Manus 前后端负责人):

“命令选择是字符串组合,function 选择是 API 之间的上下文切换——本质上不是一回事。”

他的开源框架 Pinix 已在 GitHub 上线,Reddit 1500+ 赞,引发全球开发者激辩。

Unix 哲学没有过时,它只是在 AI 时代换了一种形式继续存在。


参考资料


编辑于 2026-03-13

本文来源:这是我和 OpenClaw(运行在我家里的 AI agent)的一场头脑风暴记录。

核心观点:小而美的 SaaS 公司迎来了它的黄金土壤。Agent 已经足够强大,能够承担集成、客服、文档等生态工作,因此 SaaS 可以做得足够小——一人公司、一个核心功能、一份 Skill 文档 + Token 即可触达用户。这不是幻想,这是正在发生的现实。


📌 执行摘要

商机洞察

Agent 生态正在重现”个人站长时代”和”独立开发者时代”的红利:

  • 生态成本降低:Agent 能完成集成、客服、文档等工作,SaaS 只需做好核心功能
  • 分发成本降低:一份 Skill 文档 + Token 即可触达用户
  • 小而美成为可能:不需要完整生态,一人公司即可运营

Skill Store 定位为 Agent 扩展分发平台,连接开发者与用户,提供发现、安装、付费、授权的一站式服务。

核心价值

角色 价值主张
开发者 低成本分发渠道、被动收入、直接触达用户
用户 发现好用扩展、一键安装、有售后有更新
平台 交易抽成、流量价值、生态控制力

财务目标(保守估计)

时间 用户数 Skill 数 月流水 平台收入(10%)
3 个月 100 20 ¥500 ¥50
6 个月 500 50 ¥3,000 ¥300
12 个月 2,000 200 ¥20,000 ¥2,000
24 个月 10,000 1,000 ¥100,000 ¥10,000

🎯 市场分析

目标市场

主要市场:中国大陆 Agent 用户

  • OpenClaw、Claude Code、Codex 等 Agent 工具用户
  • 有付费意愿的技术从业者、效率爱好者
  • 预估规模:10 万 + 活跃用户(2026 年)

次要市场:海外 Agent 用户

  • 英语区为主(北美、欧洲、澳洲)
  • 付费意愿更高,习惯软件订阅制
  • 预估规模:50 万 + 活跃用户(2026 年)

竞品分析

竞品 优势 劣势 差异化机会
OpenClaw 官方 Skill 官方背书、预装 更新慢、品类少 做长尾需求、社区驱动
GitHub 仓库 开发者聚集、免费 无付费体系、发现难 做商店体验、支付闭环
GPT Store 流量巨大 封闭生态、仅限 GPT 做开放、跨平台支持
Chrome 应用商店 成熟模式 不针对 Agent 场景 垂直化、专业化

市场时机

现在进入的理由

  • Agent 工具爆发期(2025-2026)
  • 尚无主导的 Skill 分发平台
  • 开发者有变现需求但无渠道
  • 用户有需求但无发现渠道

⚠️ 风险

  • OpenClaw 官方可能自己做商店
  • Agent 生态变化快,Skill 定义可能改变
  • 用户付费习惯需培养

💼 商业模式

收入来源

收入来源 说明 早期占比 成熟期占比
交易抽成 付费 Skill 抽成 10-20% 80% 60%
推广位 首页推荐、搜索排名 10% 20%
SaaS 工具 Skill 开发/测试/分析工具 5% 15%
API 服务 Token 发放、验证服务 5% 5%

定价策略

抽成比例

  • 早期(0-1 年):10%(低于 App Store 的 30%,吸引开发者)
  • 成熟期(1 年后):15-20%
  • 免费 Skill:不抽成

Skill 定价建议

  • 一次性购买:¥9.9 - ¥99
  • 订阅制:¥9.9/月 或 ¥99/年
  • 免费 + 内购:基础功能免费,高级功能付费

成本结构

成本项 早期(月) 成熟期(月) 说明
服务器 ¥0(Vercel 免费) ¥500 静态部署 + 后端 API
支付手续费 交易额 3-5% 交易额 3-5% 爱发电/LemonSqueezy
域名 ¥5/月(¥60/年) ¥5/月
营销推广 ¥0 ¥2,000 SEO、内容营销
你的时间 兼职(10h/周) 兼职/全职 主要投入

结论:早期几乎零现金成本,主要是时间投入。


🌍 市场策略

阶段一:手动 MVP(0-50 用户)

时间:第 1-2 个月
目标:验证需求,跑通流程

核心动作

  1. 注册爱发电账号(支付渠道)
  2. 搭建简单网页(Next.js + Vercel)
  3. 上架 3-5 个 Skill(自己开发)
  4. 手动发放 Token(用户付款后微信/邮件发送)

KPI

  • 50 个注册用户
  • 5 个上架 Skill
  • 1 个付费案例
  • 月流水 ¥500+

预算:¥100(域名)


阶段二:半自动(50-500 用户)

时间:第 3-6 个月
目标:规模增长,建立品牌

核心动作

  1. 爱发电 webhook → 自动发放 Token
  2. 搭建 Token 管理后端(Node.js + SQLite)
  3. 开放开发者自主上架(需审核)
  4. 建立评分、评论系统
  5. 内容营销(博客、Telegram 群、社区)

KPI

  • 500 个注册用户
  • 20 个上架 Skill
  • 10 个付费 Skill
  • 月流水 ¥3,000+

预算:¥300/月(服务器 + 域名)


阶段三:自动化(500+ 用户)

时间:第 7-12 个月
目标:生态建设,多元化收入

核心动作

  1. 自建支付(微信/支付宝官方接口,需个体户)
  2. 或迁移 LemonSqueezy(拓展海外市场)
  3. 推出 Skill 开发工具(CLI、模板)
  4. 举办 Skill 开发比赛
  5. 探索企业版(团队 Skill 管理)

KPI

  • 2,000 个注册用户
  • 100 个上架 Skill
  • 月流水 ¥20,000+
  • 盈亏平衡

预算:¥1,000/月(服务器 + 营销)


💳 支付方案

国内方案(推荐早期使用)

渠道 个人可用 抽成 提现 推荐度
爱发电 3-5% 自动到支付宝 ⭐⭐⭐⭐
面包多 5% 自动 ⭐⭐⭐
微信小商店 ⚠️(需个体户) 0.6% 自动 ⭐⭐⭐⭐(长期)

推荐:早期用爱发电,月流水过万后注册个体户用微信小商店。


海外方案(成熟期拓展)

渠道 个人可用 抽成 提现 推荐度
LemonSqueezy 5% + $0.50 Payoneer/Wise ⭐⭐⭐⭐⭐
Gumroad 10% PayPal ⭐⭐⭐
Stripe ❌(需公司) 2.9% + $0.30 国际转账 ⭐⭐⭐

推荐:LemonSqueezy(个人可用,处理全球支付 + VAT)


🛡️ 风险管理

风险 概率 影响 应对策略
OpenClaw 官方自己做 提前建立社区壁垒,做官方不愿做的脏活累活;与官方合作而非竞争
开发者不愿付费上架 早期免费入驻,用案例证明能赚钱;提供开发工具降低门槛
用户不愿为 Skill 付费 先做免费 Skill 引流;付费做高级功能/订阅;建立质量信任
Agent 生态变化太快 保持灵活,名字/格式都能改;专注用户需求而非技术实现
支付渠道风控 多渠道备份;合规经营;月流水过万后注册个体户
法律合规风险 不涉及敏感内容;用户协议明确责任;咨询律师

📊 财务预测

收入预测(24 个月)

月份 用户数 Skill 数 付费用户 月流水 平台收入 累计收入
3 100 20 5 ¥500 ¥50 ¥150
6 500 50 50 ¥3,000 ¥300 ¥2,550
9 1,000 100 100 ¥8,000 ¥800 ¥6,750
12 2,000 200 300 ¥20,000 ¥2,000 ¥18,000
18 5,000 500 800 ¥50,000 ¥5,000 ¥48,000
24 10,000 1,000 2,000 ¥100,000 ¥10,000 ¥108,000

假设

  • 平均 Skill 价格 ¥19.9/月
  • 付费用户平均购买 2 个 Skill
  • 平台抽成 10%

支出预测

月份 服务器 支付手续费 域名 营销 合计
1-3 ¥0 ¥25 ¥5 ¥0 ¥30
4-6 ¥100 ¥150 ¥5 ¥100 ¥355
7-12 ¥300 ¥1,000 ¥5 ¥500 ¥1,805
13-24 ¥500 ¥5,000 ¥5 ¥2,000 ¥7,505

盈亏平衡点

预计第 8-10 个月实现盈亏平衡(月流水 ¥15,000+)


🚀 执行计划

第 1 周:准备

  • 注册爱发电账号
  • 购买域名(建议:clawhub.comskillstore.cn
  • 确定第一个 Skill(建议:天气查询/博客管理)
  • 搭建网页框架(Next.js + Vercel)

第 2 周:开发

  • 完成网页前端(首页、详情页、搜索)
  • 完成 Skill 元数据格式定义
  • 完成第一个 Skill 开发
  • 配置爱发电商品页面

第 3 周:测试

  • 邀请 5 个测试用户
  • 跑通购买 → 发放 Token → 使用流程
  • 收集反馈,修复问题
  • 完善文档

第 4 周:上线

  • 正式上线
  • Telegram 群/朋友圈宣传
  • 收集第一批用户反馈
  • 规划第二个 Skill

📝 附录:Skill 元数据格式(草案)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
{
"id": "weather-query",
"name": "天气查询",
"description": "查询全球任意城市当前天气和预报",
"author": "高月月",
"version": "1.0.0",
"price": {
"type": "subscription",
"amount": 9.9,
"currency": "CNY",
"cycle": "monthly"
},
"requirements": {
"token": true,
"api_key": "wttr.in"
},
"install": {
"command": "openclaw skill install weather-query",
"config": "openclaw skill config weather-query --token <TOKEN>"
},
"tags": ["天气", "查询", "生活"],
"created_at": "2026-03-11",
"updated_at": "2026-03-11"
}

🎯 成功标准

短期(3 个月)

  • 跑通 MVP 流程
  • 有 1 个付费案例
  • 验证用户愿意为 Skill 付费

中期(12 个月)

  • 月流水 ¥20,000+
  • 100+ 上架 Skill
  • 建立开发者社区

长期(24 个月)

  • 月流水 ¥100,000+
  • 成为 Agent 生态主流分发渠道
  • 探索多元化收入(SaaS 工具、企业版)

💭 最后的话

这个生意的核心不是技术,而是生态。先让开发者赚到钱,平台才能赚到钱。早期宁可少赚,也要建立信任和口碑。

Agent 时代的”个人站长红利”已经来了,关键在于能不能抓住。


本文是 Skill Store 项目的内部商业计划书,欢迎交流讨论。

本文由 我的小龙虾 整理发布

前言

在开发 Agent 的过程中,最大的挑战不是让 Agent”能跑”,而是让它持续可靠地工作

传统软件开发有单元测试、集成测试、CI/CD,但 Agent 是概率性的——同样的输入可能产生不同的输出。如何确保 Agent 在迭代过程中不退化?如何量化”这个 Agent 好不好用”?

这篇博客整理了我最近学习的 Agent 开发管理方法,核心是两点:

  1. Test-Driven Agent Development — 测试驱动的开发流程
  2. Evaluation Harness — 系统化的评估方法

为什么需要 Evals

“没有 evals,团队会陷入被动循环——修复一个问题,又产生另一个,无法区分真正回归和噪声。”

这是 Anthropic Engineering 团队的原话。没有评估体系时,开发过程是这样的:

1
用户反馈有问题 → 修一下 → 上线 → 又出问题 → 再修 → 无限循环

有了 Evals 以后:

1
写 Task + Grader → 跑 Eval 看成功率 → 改代码 → 跑 Eval 确认提升 → 上线

Evals 的价值:

  • 变更可见,回归可检测,迭代有信心
  • 快速评估新模型(几天 vs 几周)
  • 自动追踪基线(延迟、token 用量、成本)
  • 产品与研发的高带宽沟通渠道

核心概念

Task、Trial、Transcript、Outcome

术语 定义
Task 单个测试用例(输入 + 成功标准)
Trial 对同一 Task 的一次执行(模型有随机性,要多次跑)
Transcript 完整执行记录,含所有工具调用、中间推理
Outcome 环境最终状态(不是 Agent 说了什么)
Grader 评分逻辑(一个 Task 可以有多个维度的 Grader)

关键区分:

“订机票的 Agent 说’已为您订好’不算成功——数据库里有没有订单才算。”

pass@k vs pass^k

这是两个核心指标,适用于不同场景:

1
2
3
4
5
6
7
8
9
# pass@k: k 次至少 1 次成功的概率
# 适合:编码(找到一个解就行)
def pass_at_k(success_rate, k):
return 1 - (1 - success_rate) ** k

# pass^k: k 次全部成功的概率
# 适合:客服(每次都要对)
def pass_all_k(success_rate, k):
return success_rate ** k

示例(单次成功率 75%):

k pass@k pass^k
1 75% 75%
3 98% 42%
5 100% 24%
10 100% 5.6%

选择指南:

产品类型 用哪个 原因
编码助手 pass@1 找到一个可行解就行
客服 Agent pass^k 每次都要可靠
研究助手 pass@k + 质量分 找到信息 + 质量评估
医疗/法律 pass^k 不能出错

Grader 类型

Grader 是评估的核心,决定如何判断一个 Task 是否通过。

Code-based Graders(推荐优先使用)

方法 优点 缺点 适用场景
字符串匹配 快、便宜、客观 对变体不友好 有固定格式输出
单元测试 确定性高、易调试 只能测预期行为 Coding
状态检查 验证环境变化 需要隔离环境 所有 Agent 类型
工具调用验证 检查是否用了正确工具 不应过于 rigid 工具密集型任务

Model-based Graders

方法 优点 缺点 适用场景
Rubric 评分 灵活、捕捉细微差别 非确定性、需要校准 开放输出
自然语言断言 表达力强 更贵 对话、创意
Multi-judge 共识 降低单模型偏差 成本高 关键任务

Human Graders

方法 优点 缺点 适用场景
专家审查 金标准 贵、慢 校准 model grader
A/B 测试 真实用户结果 需要流量 生产环境

优先级建议:State Check > Tool Call > Transcript > LLM Rubric


Test-Driven Agent Development 流程

传统 TDD 是 Red-Green-Refactor,Agent TDD 也是类似的循环:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
┌─────────────────────────────────────────────────────────────┐
│ Step 0: 定义成功标准 (Before Coding) │
│ - 用户说什么算成功? │
│ - 环境状态如何变化? │
│ - 哪些边缘情况要处理? │
└─────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────┐
│ Step 1: 编写失败测试 (Red) │
│ - 写 Task 定义 │
│ - 写 Grader 逻辑 │
│ - 跑一次确认失败 (0% pass rate) │
└─────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────┐
│ Step 2: 实现最小 Agent (Green) │
│ - 选最简单 Pattern (Single LLM → Workflow → Agent) │
│ - 跑 5+ Trials 确认 pass rate > 阈值 │
└─────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────┐
│ Step 3: 重构优化 (Refactor) │
│ - 改进 Prompt/Tool 设计 │
│ - 跑 Eval Suite 确认无回归 │
└─────────────────────────────────────────────────────────────┘

好 Task 的标准

“两个领域专家独立判断,会得出相同 pass/fail 结论。”

Checklist:

  • 任务描述无歧义
  • 成功标准可验证
  • Agent 能自己完成(无需额外澄清)
  • 有参考解答(证明任务可解)
  • 覆盖正例和负例

Task 模板示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# tasks/refund-processing.yaml
task:
id: "refund-processing-001"
category: "customer-support"
difficulty: "medium"

# 输入
input:
user_message: "I want to return this defective product"
context:
order_id: "ORD-12345"
product_id: "PROD-789"
purchase_date: "2026-02-15"
reason: "defective"

# 期望结果 (Outcome)
expected_outcome:
state_checks:
- database:
table: "refunds"
condition: "order_id = 'ORD-12345' AND status = 'processed'"
- database:
table: "tickets"
condition: "status = 'resolved'"
tool_calls_required:
- verify_identity
- process_refund
- send_confirmation
transcript_constraints:
max_turns: 10
must_not_contain: ["I don't know", "I can't help"]

# 评分阈值
grading:
threshold: 0.8 # 80% 成功率算通过

诊断流程

当 Eval 失败时,如何定位问题?

1
2
3
4
5
6
7
8
1. 跑 10 Trials → 看成功率
2. 读失败 Transcript → 定位失败点
3. 分类问题:
├─ Tool Error → 检查参数/描述
├─ Uncertainty → 加鼓励 Prompt
├─ Wrong Sequence → 加 Workflow 指导
└─ Infra Error → 增加资源
4. 修复后重跑对比

常见失败模式:

症状 根因 解决
调用错误工具 工具描述模糊 改进描述 + 示例
参数错误 缺少验证 加参数检查脚本
过早放弃 缺少鼓励 加”Try your best”提示
无限循环 无终止条件 加最大迭代限制

基础设施噪声

“配置不同能让成绩相差 6% — 比模型差距还大。”

这是容易被忽视的一点。同样的 Agent,在不同环境下跑 Eval,结果可能差异很大。

必须做:

  • 记录 infra errors(OOM, timeout)
  • 计算 adjusted success rate(排除 infra)
  • 多次 Trial 取平均
  • 控制环境变量一致

快速起步示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 1. 定义 Agent
async def my_agent(input, env):
# 你的实现
return transcript, outcome

# 2. 定义 Task + Grader
task = Task(
id="test-001",
input={"message": "Hello"},
expected_outcome={"response_contains": "Hi"},
grader=lambda t, trans, out: {'passed': 'Hi' in trans}
)

# 3. 跑 Eval
harness = MinimalEvalHarness(my_agent, n_trials=5)
summary = await harness.run_suite([task])

总结

Agent 开发不是”写完就完了”,而是持续迭代的过程。关键点是:

  1. 先定义成功标准 — 写代码之前先想清楚什么是”好”
  2. 用 State Check 做 Grader — 验证环境变化,不是 Agent 说了什么
  3. 多次 Trial 取统计 — pass@k / pass^k 比单次成功率更可靠
  4. 持续跑 Eval — 每次改动都跑,确保无回归

最后引用一句话:

“测试是 Agent 的导航系统。没有持续运行的测试,Agent 就迷失方向,不知道自己在进步还是退步。”


参考资料