FlashMLA Sparse Decode 完整计算过程详解

FlashMLA Sparse Decode 完整计算过程详解

本文以一个最小化的具体数值例子,逐步展示 FlashMLA Sparse Decode 的每一个计算步骤,并标注每一步与 Hopper (SM90) / Blackwell (SM100/GB200) 硬件特性的关系。


0. 问题设定

参考配置

本文的模型参数来自 DeepSeek-V3.2 的官方 Hugging Face 配置:

以下是从 config.json 中提取的注意力相关参数,以及它们在 FlashMLA 内核中的对应关系:

config.json 参数 FlashMLA 内部参数 说明
kv_lora_rank 512 d_nope = 512 KV 的 LoRA 压缩维度,即 NoPE (Non-Positional Encoding) 部分的维度。在 MLA 中,KV cache 存储的是 LoRA 压缩后的向量,而非原始的 K/V
qk_rope_head_dim 64 d_rope = 64 RoPE (Rotary Position Embedding) 部分的维度
合计 512 + 64 = 576 d_qk = 576 Q/K 的总 head dimension = kv_lora_rank + qk_rope_head_dim
kv_lora_rank 512 d_v = 512 V 的 head dimension,等于 LoRA rank(在 MLA 中 V 和 K 的 NoPE 部分共享同一个压缩向量)
num_attention_heads 128 h_q = 128 Query head 数量
num_key_value_heads 128 config 中 KV head 数也是 128,但在 MLA 架构中,KV cache 只存储 1 份压缩向量(所有 head 共享),所以 FlashMLA 中 h_kv = 1 (MQA 模式)
qk_nope_head_dim 128 这是模型层面每个 head 的 NoPE 维度(128 × 128 heads = 16384 → 再经 LoRA 压缩为 512)。FlashMLA 操作的是压缩后的 512 维向量,不直接使用此参数
v_head_dim 128 同上,这是模型层面每个 head 的 V 维度。FlashMLA 操作的是压缩后的 512 维向量
quantization_config.fmt “e4m3” FP8_E4M3 KV cache 的 NoPE 部分使用 FP8 E4M3 格式量化
quantization_config.scale_fmt “ue8m0” UE8M0 (V3.2) / FP8_E8M0FNU (MODEL1) 量化缩放因子的格式,纯 2 的幂
index_topk 2048 topk 每个 query 关注的 top-k KV token 数量(稀疏注意力)
index_n_heads 64 参与稀疏索引选择的 head 数量
index_head_dim 128 用于索引选择的 head 维度

MLA (Multi-head Latent Attention) 的核心思想:不存储 128 个 head × 128 维的完整 KV cache (= 16384 维),而是存储一个 512 维的 LoRA 压缩向量 + 64 维的 RoPE 向量 (= 576 维)。这将 KV cache 压缩了 28 倍 (16384 → 576)。FlashMLA 的所有计算都在这个压缩空间中进行。

本文使用的示例参数

我们使用与 config.json 一致的模型配置,但将 batch size 和 topk 缩小以便手算:

模型配置 (来自 config.json):

1
2
3
4
5
6
d_qk = 576 = kv_lora_rank(512) + qk_rope_head_dim(64)
d_v = 512 = kv_lora_rank
d_nope = 512 = kv_lora_rank (NoPE 部分,做 FP8 量化)
d_rope = 64 = qk_rope_head_dim (RoPE 部分,保持 BF16 不量化)
h_q = 128 = num_attention_heads
h_kv = 1 (MLA 压缩后:所有 head 共享 1 份 KV cache)

示例 Batch 配置 (缩小规模以便展示):

1
2
3
4
b = 2 (batch size)
s_q = 1 (每个请求只 decode 1 个 token)
topk = 128 (原始 config 中 index_topk=2048, 此处缩小便于计算)
page_block_size = 64

硬件配置 (H800 / SM90):

1
num_sms = 132 (SM 数量)

注意: 实际部署中 topk=2048,h_q=128,这意味着每个 query token 需要从 KV cache 中 gather 2048 个 token 并做注意力计算。本文为了展示清晰,将 topk 缩小为 128。

为了简化展示,我们只跟踪 2 个 token(token 0 和 token 1)在一个 head 上的完整计算,其余的计算过程完全相同。


第 1 步:Python 层入口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
sched_meta, _ = flash_mla.get_mla_metadata() # 创建空的调度元数据

out, lse = flash_mla.flash_mla_with_kvcache(
q,
# shape: [batch_size, seq_len_q, num_heads_q, head_dim]
# = [2, 1, 128, 576] dtype=BF16
# - batch_size=2: 同时处理 2 个请求
# - seq_len_q=1: decode 阶段每次只生成 1 个新 token
# - num_heads_q=128: 来自 config.json 的 num_attention_heads=128
# - head_dim=576: = kv_lora_rank(512) + qk_rope_head_dim(64)
# 即每个 head 的 Q 向量是 576 维 (512 维 NoPE + 64 维 RoPE)

k_cache,
# shape: [num_blocks, page_block_size, num_heads_k, bytes_per_token]
# = [num_blocks, 64, 1, 656] dtype=FP8
# - num_blocks: KV cache 的总页数 (由推理框架的 paged attention 管理)
# - page_block_size=64: 每个页面存放 64 个 token 的 KV 向量
# - num_heads_k=1: MLA 的核心 — 所有 128 个 head 共享 1 份压缩后的 KV cache
# - bytes_per_token=656: 每个 token 的 FP8 KV cache 大小:
# 512 字节 (NoPE, FP8_E4M3) + 16 字节 (4×float32 缩放因子) + 128 字节 (RoPE, BF16)

block_table=None, # 稀疏模式不需要 (dense 模式才需要页表做地址翻译)
cache_seqlens=None, # 稀疏模式不需要 (dense 模式才需要知道每个序列的实际长度)
head_dim_v=512, # V 的 head 维度 = kv_lora_rank = 512

tile_scheduler_metadata=sched_meta,
is_fp8_kvcache=True,

indices=indices,
# shape: [batch_size, seq_len_q, topk]
# = [2, 1, 128] dtype=int32
# - batch_size=2: 对应 2 个请求
# - seq_len_q=1: 每个新 token 各有自己的 top-k 索引列表
# - topk=128: 每个 query token 关注的 KV token 数量 (实际部署中为 index_topk=2048)
# 值的含义:indices[i][j][k] = global_token_index
# 例:indices[0][0][3] = 47 表示 batch 0 的 query 的第 4 个关注 token
# 位于第 0 个 page block (47÷64=0) 的第 47 个位置 (47%64=47)
)

# 返回值:
# out: [batch_size, seq_len_q, num_heads_q, head_dim_v] = [2, 1, 128, 512] BF16
# lse: [batch_size, num_heads_q, seq_len_q] = [2, 128, 1] FP32

关于 dense vs sparse 的选择: FlashMLA 本身不决定使用哪条路径。当调用方(vLLM/SGLang 等推理框架)传入 indices 参数时走 sparse 路径;不传 indices 而传入 block_table + cache_seqlens 时走 dense 路径。当上下文长度 ≤ index_topk(2048) 时,top-k 等于全部 token 数,稀疏无意义,框架层面应直接调用 dense 路径。

关键路由逻辑 (flash_mla_interface.py 第 151 行):

1
2
3
if topk is not None:
# topk 不为 None → 走稀疏路径
out, lse, ... = flash_mla_cuda.sparse_decode_fwd(...)

无硬件特性依赖,纯 Python 路由


第 2 步:C++ 接口层 — 输入验证与参数准备

文件: csrc/api/sparse_decode.h

2.1 架构检测与实现选择

1
2
3
4
5
6
7
8
9
10
11
Arch arch = Arch();
// arch.major=9, arch.minor=0 → SM90a (Hopper)
// 或 arch.major=10 → SM100f (Blackwell)

if (arch.is_sm100f()) {
if (h_q == 64) impl = new Decode_Sm100_Head64_Impl();
// SM100 使用 TMEM + UTCMMA
} else if (arch.is_sm90a()) {
impl = new Decode_Sm90_Impl();
// SM90 使用 TMA + WGMMA
}

🔧 硬件特性: cudaGetDeviceProperties 获取 compute capability。SM90 = Hopper, SM100 = Blackwell。

2.2 FP8 KV Cache 形状验证

1
2
3
4
5
// V3.2: 每个 token 占 656 字节
// 656 = 512 (NoPE FP8) + 16 (scales FP32) + 128 (RoPE BF16)
int bytes_per_token = 512 + 4*sizeof(float) + 64*sizeof(nv_bfloat16);
KU_CHECK_SHAPE(kv, num_blocks, page_block_size, h_kv, bytes_per_token);
// 即 kv shape = [num_blocks, 64, 1, 656]

656 字节的内存布局:

1
2
3
4
5
6
┌─────────────────────────────────────────────────────────┐
│ 偏移 0-511: 512 × FP8_E4M3 NoPE 部分 (量化) │
│ 偏移 512-527: 4 × float32 缩放因子 (每 128 个一组) │
│ 偏移 528-655: 64 × BF16 RoPE 部分 (不量化) │
│ 总计 656 字节 │
└─────────────────────────────────────────────────────────┘

🔧 硬件特性: FP8 (E4M3) 是 Hopper/Blackwell 原生支持的数据类型。Hopper 的 Tensor Core 可以直接做 FP8 矩阵乘法,但 FlashMLA 选择先反量化为 BF16 再做 WGMMA,以获得更高精度。

2.3 填充参数结构体

这一步是”翻译层”— 把 PyTorch tensor 的元信息(shape、stride、data_ptr)和标量参数,打包成 CUDA kernel 能直接使用的 C struct。kernel 内部只看这个 struct,不再接触 PyTorch 对象。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
SparseAttnDecodeParams params = {
// ===== 基础维度 (直接从输入 tensor 的 shape 读取) =====
.b = 2, // q.size(0) = batch_size,同时处理 2 个请求
.s_q = 1, // q.size(1) = seq_len_q,decode 阶段每次只生成 1 个新 token
.h_q = 128, // q.size(2) = num_attention_heads,来自 config.json
.h_kv = 1, // kv.size(2),MLA 压缩后 128 个 head 共享 1 份 KV cache
// 注意:计算时会被 broadcast 到 h_q=128
.d_qk = 576, // q.size(3) = head_dim = kv_lora_rank(512) + qk_rope_head_dim(64)
.d_v = 512, // head_dim_v 参数 = kv_lora_rank

// ===== Softmax 缩放因子 =====
// 标准 Scaled Dot-Product Attention: softmax(Q·K^T / sqrt(d)) · V
.sm_scale = 1.0 / sqrt(576) = 0.04167,
// 即公式中的 1/sqrt(d),防止点积值过大导致 softmax 饱和

.sm_scale_div_log2 = 0.04167 * log2(e) = 0.04167 * 1.4427 = 0.06010,
// 性能优化的预计算。kernel 中用 exp2f() 代替 expf() 计算 softmax:
// expf(x * sm_scale) = exp2f(x * sm_scale * log2(e))
// = exp2f(x * sm_scale_div_log2)
// GPU 上 exp2f() 比 expf() 快约 2 倍 (SFU 原生支持 base-2 指数运算),
// 所以预乘好 log2(e),避免 kernel 里每个元素都重复乘一次

// ===== 稀疏注意力参数 =====
.topk = 128, // indices.size(2),每个 query 关注的 KV token 数量
.model_type = ModelType::V32,
// 由 d_qk 决定:576 → V32 (DeepSeek-V3/V3.2), 512 → MODEL1
// 不同 model_type 的 FP8 KV cache 布局和缩放因子格式不同

// ===== 指针 (从 tensor.data_ptr() 获取) =====
.q = (bf16*)q.data_ptr(), // Q 矩阵的 GPU 内存地址
.kv = (bf16*)kv.data_ptr(), // FP8 KV cache 的 GPU 内存地址
.indices = (int*)indices.data_ptr(), // top-k 索引数组的 GPU 内存地址
.out = (bf16*)out.data_ptr(), // 输出矩阵的 GPU 内存地址
.lse = (float*)lse.data_ptr(), // log-sum-exp 的 GPU 内存地址

// ===== Stride (从 tensor.stride() 获取,单位是元素数) =====
// Stride 描述 tensor 在内存中的布局,让 kernel 知道如何从一个元素跳到下一个元素
.stride_q_b = q.stride(0), // 跳到下一个 batch: 1×128×576 = 73728
.stride_q_s_q = q.stride(1), // 跳到下一个 query token: 128×576 = 73728
.stride_q_h_q = q.stride(2), // 跳到下一个 head: 576
// ... kv, indices, out, lse 的 stride 类似
};

无硬件特性依赖,纯数据结构准备


第 3 步:Tile Scheduler — 工作分配

文件: csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu

这一步决定了 哪个 SM 处理哪些 batch 的哪些 KV block

3.1 计算每个 batch 的处理块数

⚠️ “block” 概念澄清: 这里的 “block” 不是 KV cache page block(64 个连续 token 组成的页面),而是对 top-k indices 列表的分块。

top-k 选择是 token 级别的 — indices 数组中存储的是 2048 个(本例中 128 个)散落在 KV cache 各处的 token 的绝对位置,这些 token 之间通常不连续。

由于 kernel 无法一次性处理所有 top-k token,所以按 TOPK_BLOCK_SIZE=64 将 indices 列表切块处理:

概念 Dense Decoding 的 “block” Sparse Decoding 的 “block”
含义 KV cache page block(64 个连续 token 组成的页面) indices 列表的处理块(64 个不连续 token index 为一组)
数据局部性 好(连续访问 global memory) 差(每个 token 可能在不同 page,需要随机 gather)
计算方式 按 block_table 顺序遍历所有页面 按 indices 逐 token gather,每个 index 指向 KV cache 中的任意位置
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
Launch: <<<1, 32>>> (单 CUDA block, 32 线程,即 1 个 warp)

indices[batch 0] = [token_47, token_2091, token_583, ..., token_8821] ← 共 128 个 token index (散落各处)

将 128 个 token index 按 TOPK_BLOCK_SIZE=64 分块处理:
处理块 0: indices[0:63] → 去 KV cache 中 gather 这 64 个 token (位于不同的 page)
处理块 1: indices[64:127] → gather 另外 64 个 token

对于 batch 0:
topk = 128, TOPK_BLOCK_SIZE = 64
num_blocks[0] = ceil(128/64) = 2 ← 需要 2 个处理块来遍历全部 128 个 top-k token

对于 batch 1:
topk = 128
num_blocks[1] = ceil(128/64) = 2

fixed_overhead_num_blocks = 5 (每个 batch 的固定 CUDA CTA 开销)

total_num_blocks = (2+5) + (2+5) = 14 ← 需要调度的 CUDA CTA 总数

🔧 硬件特性: 使用 __shfl_xor_sync 做 warp 内规约求和 — 这是 CUDA warp shuffle 指令,所有 NVIDIA GPU 支持。

注意: FP8 量化将每个 token 从 1152 字节 (576×BF16) 压缩到 656 字节,减少 ~43% 的内存带宽消耗。Decode 阶段是 memory-bound(瓶颈在显存带宽而非计算),因此减少每 token 的字节数可以直接提升吞吐量。

3.2 背景:SM (Streaming Multiprocessor) 是什么

SM 是 NVIDIA GPU 的基本计算单元。可以把 GPU 理解为一个拥有很多独立计算核心的处理器,每个 SM 就是其中一个核心。

  • 每个 SM 有自己的寄存器文件(65536 个 32-bit 寄存器)、共享内存(228KB,SM90)、Tensor Core(矩阵乘法加速器)、SFU(特殊函数单元,做 exp2f 等)
  • H800 有 132 个 SM,它们并行工作
  • 一个 CUDA kernel 启动时,会生成很多 **CTA (Cooperative Thread Array)**,也叫 thread block。GPU 的硬件调度器把这些 CTA 分配到各个 SM 上执行
  • 一个 SM 可以同时执行多个 CTA(如果资源允许),但 FlashMLA 的 kernel 每个 CTA 用了接近满载的寄存器和共享内存,所以通常一个 SM 只跑 1 个 CTA

Cluster 是 Hopper 新增的概念:多个 SM 组成一个 cluster,cluster 内的 SM 可以直接读写彼此的共享内存。FlashMLA 中 h_q=128 时 CLUSTER_SIZE=2,即 2 个 SM 组成一个 cluster,协作处理 128 个 head。

3.3 计算 SM Partition 数

1
2
3
4
5
6
7
8
9
10
11
12
h_q = 128 个 head
BLOCK_M = 64 个 head / CTA
→ 处理全部 head 需要 128/64 = 2 个 CTA (组成 1 个 cluster)

s_q = 1 (只有 1 个 query token)

num_sm_parts = max(132 / s_q / (h_q/64), 1)
= max(132 / 1 / 2, 1)
= 66

含义:132 个 SM,每 2 个 SM 组成一个 cluster 处理一个"任务单元",
所以有 66 个可用的"工位" (SM partition)

3.4 贪心调度算法

调度器的目标是:把所有 batch 的工作均匀分配到 66 个 partition 上

1
2
3
4
5
6
7
8
9
第一步:计算每个 partition 的负载上限 (payload)

payload = ceil(total_num_blocks / num_sm_parts) + fixed_overhead_num_blocks
= ceil(14 / 66) + 5
= 1 + 5
= 6

每个 partition 最多处理 6 个"单位"的工作。
(+5 是因为每接手一个新 batch 有固定成本:加载 Q、初始化 softmax 状态、写回结果等)

然后,单线程从左到右逐个 partition 分配工作 (代码第 66-91 行):

状态变量:

1
2
3
now_req_idx = 0    // 当前正在分配第几个 batch
now_block = 0 // 当前 batch 已经分配了多少 block
remain_payload = 6 // 当前 partition 剩余容量
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
═══════════ Partition 0, remain_payload=6 ═══════════

看 batch 0: 共 2 个 block
能否把整个 batch 0 放进来?需要 2(block) + 5(开销) = 7 ... 7 > 6, 放不下!

那放部分:remain_payload - fixed_overhead = 6 - 5 = 1 个 block
→ 只分配 batch 0 的 block 0
→ now_block = 1 (batch 0 还剩 1 个 block)
→ 容量用完,跳出

结果:
begin_req=0, end_req=0, begin_block=0, end_block=1
is_first_req_splitted=false ← batch 0 从这里开始
is_last_req_splitted=true ← batch 0 没处理完 (被 split 了!)

═══════════ Partition 1, remain_payload=6 ═══════════

继续 batch 0: 还剩 1 个 block
能否整个放进来?需要 1 + 5 = 6 ... 6 <= 6, 可以!
→ remain_payload = 6 - 6 = 0
→ batch 0 全部完成,now_req_idx=1, now_block=0
→ 容量用完,跳出

结果:
begin_req=0, end_req=0, begin_block=1, end_block=2
is_first_req_splitted=true ← batch 0 从中间开始 (上个 partition 没处理完)
is_last_req_splitted=false ← batch 0 在这里结束

═══════════ Partition 2, remain_payload=6 ═══════════

看 batch 1: 共 2 个 block, 同理 7 > 6 放不下整个
→ 只放 1 个 block

结果:
begin_req=1, end_req=1, begin_block=0, end_block=1
is_first_req_splitted=false, is_last_req_splitted=true

═══════════ Partition 3, remain_payload=6 ═══════════

继续 batch 1: 剩 1 个 block, 1+5=6 <= 6, 放进来
→ batch 1 完成

结果:
begin_req=1, end_req=1, begin_block=1, end_block=2
is_first_req_splitted=true, is_last_req_splitted=false

═══════════ Partition 4~65 ═══════════
所有 batch 已分配完,begin_req >= batch_size → kernel 直接 return

3.5 Split 与 Combine

当一个 batch 被分到多个 partition 时,就产生了 split — 每个 partition 只计算部分 KV block 的注意力,得到局部结果 (partial O, partial LSE)。最后需要 combine kernel 把各个 split 的结果合并。

1
2
3
4
5
6
7
8
9
10
batch 0 被 split 成 2 份:
Partition 0 → block 0 的注意力结果 (partial_O_0, partial_LSE_0)
Partition 1 → block 1 的注意力结果 (partial_O_1, partial_LSE_1)
→ combine kernel 根据 LSE 加权合并:O_final = w0 × partial_O_0 + w1 × partial_O_1

batch 1 同理。

num_splits 前缀和:[0, 2, 4]
batch 0 的 splits 在 index [0, 2) → 2 个 splits
batch 1 的 splits 在 index [2, 4) → 2 个 splits

如果一个 batch 没有被 split (is_no_split=true),则 decode kernel 直接输出最终结果,不需要 combine kernel,省去一次 kernel launch。在实际部署中 (topk=2048, b=128),大部分 batch 都会被 split 到多个 partition 以充分利用 132 个 SM 的并行度。

🔧 硬件特性: 调度器本身是架构无关的 (smxx/ 目录),但 num_sm_parts 根据 GPU 的 SM 数量和 kernel 的 cluster size 计算,与具体硬件相关。


第 4 步:FP8 量化回顾 — KV Cache 是如何存储的

在 decode 之前,预填充阶段已经将 KV cache 量化为 FP8。我们用具体数字展示:

4.1 原始 BF16 KV 向量 (一个 token)

1
2
3
原始 K 向量 (d=576):
K_nope[0:511] = [0.25, -0.5, 0.125, ..., 0.75] (512 个 BF16 值)
K_rope[512:575] = [0.1, -0.2, 0.3, ..., 0.05] (64 个 BF16 值)

4.2 分组量化

NoPE 部分按 128 个一组量化(V3.2 有 4 组):

1
2
3
4
5
6
7
8
9
10
11
12
Group 0: K_nope[0:127]
max_abs = max(|K_nope[0]|, ..., |K_nope[127]|) = 0.75
scale_inv = 0.75 / 448.0 = 0.001674
scale_inv_ue8m0 = 2^(ceil(log2(0.001674))) = 2^(-9) = 0.001953
scale = 1 / 0.001953 = 512.0

量化:K_nope_fp8[i] = round_to_fp8(K_nope[i] / scale_inv_ue8m0)
例:K_nope_fp8[0] = round_to_fp8(0.25 / 0.001953) = round_to_fp8(128.0) = 128.0 (FP8)

Group 1: K_nope[128:255] → scale_1
Group 2: K_nope[256:383] → scale_2
Group 3: K_nope[384:511] → scale_3

4.3 存储布局 (656 字节)

偏移 内容 字节数
0 K_nope_fp8[0:511] 512 (512 × 1 字节 FP8_E4M3)
512 [scale_0, scale_1, scale_2, scale_3] 16 (4 × 4 字节 float32)
528 K_rope_bf16[0:63] 128 (64 × 2 字节 BF16)
合计 656 字节

🔧 硬件特性:

  • FP8 (E4M3): Hopper/Blackwell 原生数据类型,1 符号位 + 4 指数位 + 3 尾数位,范围 ±448,精度约 3 位有效数字
  • UE8M0 缩放因子: 纯 2 的幂,无尾数位,确保反量化时乘法精确无舍入误差
  • Blackwell (SM100) 额外支持: FP8_E8M0FNU 格式的缩放因子,MODEL1 使用此格式

第 5 步:CUDA Kernel 启动

5.1 SM90 (Hopper) 启动配置

文件: csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh 底部

1
2
3
4
5
6
7
8
9
// SM90 启动配置
cutlass::ClusterLaunchParams launch_params = {
dim3(2, 1, 66), // grid: (NUM_M_BLOCKS=2, s_q=1, num_sm_parts=66)
dim3(384, 1, 1), // block: 384 线程 = 3 warpgroups
dim3(2, 1, 1), // cluster size = 2 (h_q=128 → CLUSTER_SIZE=NUM_M_BLOCKS=2)
sizeof(SharedMemoryPlan), // 动态共享内存
params.stream
};
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_params);

🔧 Hopper 硬件特性:

  • Cluster Launch: launch_kernel_on_cluster 使用 Hopper 的 Thread Block Cluster 功能。h_q=128 时 CLUSTER_SIZE=2,两个 CTA 组成一个 cluster,可以直接访问对方的共享内存 (通过 XOR 寻址)
  • 384 线程: 3 个 warpgroup(每个 128 线程 = 4 个 warp),这是 Hopper WGMMA 指令的基本执行单位
  • NUM_M_BLOCKS=2: h_q=128 / BLOCK_M=64 = 2,每个 CTA 处理 64 个 head,2 个 CTA 覆盖全部 128 个 head

5.2 SM100 (Blackwell / GB200) 启动配置

文件: csrc/sm100/decode/head64/kernel.cuh 底部

1
2
3
4
5
6
7
8
// SM100 启动配置 — 注意与 SM90 的关键区别
mla_kernel<<<
dim3(params.s_q, params.num_sm_parts, 1), // grid: (s_q=1, num_sm_parts=132, 1)
dim3(384, 1, 1), // block: 384 线程 = 3 warpgroups (相同)
smem_size, // 动态共享内存
params.stream
>>>(params, tma_params);
// 注意:没有使用 cluster launch, 也没有使用 PDL

关键区别:

SM90 (Hopper) SM100 (Blackwell)
grid 维度 (2, 1, 66)
x=NUM_M_BLOCKS, z=num_sm_parts
(1, 132, 1)
x=s_q, y=num_sm_parts
CTA 数量 2 × 1 × 66 = 132 1 × 132 × 1 = 132
cluster size 2 (两个 CTA 协作) 1 (单 CTA 独立工作)
每 CTA 处理 64 个 head 64 个 head (全部 h_q=64)
num_sm_parts 132/(1×2) = 66 132/1 = 132
1
2
// SM90 的 KU_ASSERT: h_q == 64 || h_q == 128 等
// SM100 head64 的 KU_ASSERT: h_q == 64 (即 B_H == 64)

SM100 为什么 h_q == 64 而不是 128?
SM100 head64 实现 (csrc/sm100/decode/head64/) 要求 h_q == B_H == 64。当推理框架处理 DeepSeek-V3.2 (h_q=128) 时,需要通过其他方式拆分 head(如 head64x2 或 head128 实现),或者多次调用。这与 SM90 通过 cluster (2 个 CTA 各处理 64 head) 的方案不同。

5.3 SM100 的三大硬件创新及其在 Kernel 中的应用

5.3.1 TMEM (Tensor Memory) — 新增的片上存储层

SM100 在寄存器和共享内存之外,新增了一层 TMEM (Tensor Memory):

1
2
3
4
5
SM100 存储层次:
寄存器 (Register File) — 每线程私有,最快
TMEM (Tensor Memory) — 512KB, 每 SM 共享,Tensor Core 可直接读写 ← 新增!
共享内存 (Shared Memory) — 228KB, CTA 内共享
L2 Cache → HBM

在 FlashMLA 中的具体应用 (csrc/sm100/decode/head64/kernel.cuh):

1
2
3
4
5
6
7
8
9
10
11
12
// TMEM 列布局 (config.h 第 72-80 行)
struct tmem_cols {
static constexpr int O = 0; // 列 0~255: 输出 O 累加器 (FP32)
static constexpr int Q = 256; // 列 256~399: Q 矩阵 (BF16)
static constexpr int P = 400; // 列 400~463: P = Q·K^T 的结果 (FP32)
};

// kernel 初始化时分配 512 列 TMEM (kernel.cuh 第 63 行)
cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());

// kernel 结束时释放 (kernel.cuh 第 425 行)
cute::TMEM::Allocator1Sm().free(0, 512);

SM90 的对比: SM90 没有 TMEM,Q 和 O 都在共享内存或寄存器中:

  • Q 通过 TMA → 共享内存 → WGMMA 直接从共享内存读取
  • O 累加器在寄存器中
  • P 在寄存器中

SM100 的优势:

  • TMEM 带宽远高于共享内存,Q 常驻 TMEM 后每个 block 的 QK^T 都不需要重新加载 Q
  • O 累加器放在 TMEM 中,减轻寄存器压力 (Warpgroup 0 从 SM90 的 192 regs 变为 224 regs,但 regs 不用存 O 了)
  • P 的 softmax 结果直接在 TMEM 中读取,无需通过共享内存传递

5.3.2 UTCMMA (Unified Tensor Core MMA) — 替代 WGMMA

SM100 使用 UTCMMA (tcgen05.mma) 指令替代 SM90 的 WGMMA (wgmma.mma_async):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// QK^T 矩阵乘法:
// SM90: WGMMA SS 模式 — Q 和 K 都在共享内存
gemm<true, -1>(tiled_mma_QK, sQ, sK, rP);
// PTX: wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 ...

// SM100: UTCMMA TS 模式 — Q 在 TMEM, K 在共享内存
ku::utcmma_ts(tiled_mma_P, tQ_in_tmem, sK_in_smem, tP_in_tmem, true);
// PTX: tcgen05.mma.ws.cta_group::1.kind::f16 [tmem_c], [tmem_a], smem_desc_b, ...

// SV 矩阵乘法:
// SM90: WGMMA RS 模式 — S 在寄存器,V 在共享内存
gemm<false, -1>(tiled_mma_PV, rS, sV, rO);

// SM100: UTCMMA SS 模式 — S 和 V 都在共享内存,O 在 TMEM
ku::utcmma_ss(tiled_mma_O, sS, sV, tO_in_tmem, false);
// PTX: tcgen05.mma.ws.cta_group::1.kind::f16 [tmem_c], smem_desc_a, smem_desc_b, ...

区别总结:

SM90 WGMMA SM100 UTCMMA
QK^T 操作数来源 A=共享内存,B=共享内存 (SS) A=TMEM, B=共享内存 (TS)
QK^T 累加器位置 寄存器 (rP) TMEM (tP)
SV 操作数来源 A=寄存器,B=共享内存 (RS) A=共享内存,B=共享内存 (SS)
SV 累加器位置 寄存器 (rO) TMEM (tO)
发起者 1 个 warpgroup (128 线程) 1 个线程 (elect_one_sync)
异步同步 wgmma.commit/wait tcgen05.commit + mbarrier

5.3.3 TMA Gather4 — 硬件稀疏 Gather

这是对 sparse decode 最重要的 SM100 创新。SM90 通过线程协作做 __ldg 逐 token gather,SM100 直接用硬件 TMA 做 2D 稀疏 gather:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// SM90: 每个线程用 __ldg 加载 1 个 token 的一部分 (splitkv_mla.cuh 第 509-530 行)
int token_index = __ldg(gIndices + ...); // 读 index
int block_id = token_index / page_block_size;
int offset = token_index % page_block_size;
fp8x16 data = load_128b_from_gmem(k_ptr + block_id*stride + offset*row_stride + dim_offset);
// → 128 个线程各自发起 1 个 128-bit 的全局内存读取

// SM100: 1 个线程用 TMA Gather4 一次加载 4 个 token (kernel.cuh 第 596-611 行)
ku::tma_gather4(
&tma_params.tensor_map_kv_nope, // TMA 描述符 (2D tensor layout)
plan.bar_raw_ready[buf_idx], // 完成后通知这个 barrier
plan.u.kv.raw_nope[buf_idx].data(), // 目标:shared memory
0, // column index (NoPE 起始列)
cur_indices, // int4: 4 个 row index (token 位置)
TMA::CacheHintSm90::EVICT_LAST
);
// PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
// .cta_group::1.L2::cache_hint [smem], [desc, {col, row0, row1, row2, row3}], [mbar], hint;

SM100 的 gather4 工作流程:

1
2
3
64 个 top-k token 需要 gather → 每次 gather4 加载 4 个 token → 需要 16 次 TMA gather4
每次 gather4: 从 2D tensor [num_tokens × D_NOPE] 中按 row index 取 4 行
硬件自动处理地址计算、跨 page block 的非连续访问

对比:
| 架构 | Gather 方式 |
|—|—|
| SM90 | 128 线程 × 多轮 __ldg → ~128 个散射的全局内存请求 |
| SM100 | 1 线程 × 16 次 TMA gather4 → 16 个硬件管理的 DMA 请求 |

优势:

  • 线程无需参与地址计算和数据搬运,可以做其他工作
  • TMA 引擎内部有合并和调度优化,比线程发起的 __ldg 更高效
  • 数据直接落入 shared memory,无需经过寄存器

5.3.4 SM100 UTCCP — Q 从共享内存搬到 TMEM

SM100 加载 Q 的过程分两步:

1
2
3
4
5
6
7
8
9
// Step 1: TMA 把 Q 从 global memory → shared memory (和 SM90 相同)
ku::launch_tma_copy(tma_params.tma_Q_SW128, gQ, sQ, plan.bar_q_tma, ...);

// Step 2: UTCCP 把 Q 从 shared memory → TMEM (SM100 独有)
SM100_UTCCP_128dp256bit_1cta::copy(sQ_desc, tmem_cols::Q + offset);
// PTX: tcgen05.cp ... (shared → TMEM 的异步拷贝)

// 之后 Q 常驻 TMEM,每个 KV block 的 QK^T 直接从 TMEM 读取 Q
// 不像 SM90 每次都从共享内存读

5.4 SM100 的三 Warpgroup 分工

5.4.1 什么是 Warpgroup?

Warpgroup 是 Hopper (SM90) 引入的线程组织层次,介于 warpCTA (thread block) 之间:

1
2
3
4
5
6
CTA (Thread Block) = 384 线程 (FlashMLA 的配置)
├── Warpgroup 0 = 128 线程 = 4 个 warp
├── Warpgroup 1 = 128 线程 = 4 个 warp
└── Warpgroup 2 = 128 线程 = 4 个 warp

其中每个 warp = 32 线程 (NVIDIA GPU 的最小调度单位)

为什么需要 Warpgroup?因为 WGMMA 指令要求它。

在 Hopper 之前,Tensor Core 的 MMA 指令由单个 warp(32 线程)发起。Hopper 引入了 **WGMMA (Warpgroup MMA)**,需要 128 线程(4 个 warp)协作发起一条矩阵乘法指令,矩阵规模更大(如 64×64×16),吞吐更高。所以 warpgroup = 4 个 warp = 128 线程,是 WGMMA 的基本执行单位。

5.4.2 SM90 vs SM100 的 Warpgroup 分工对比

SM100 的三个 Warpgroup 分工与 SM90 完全不同:

1
2
3
4
5
6
SM90 (Hopper):
┌──────────────────────────────────────────────────────────────────┐
│ WG0 (128 线程,192 regs) — Consumer A: QK^T + softmax + S·V_left │
│ WG1 (128 线程,160 regs) — Consumer B: S·V_right │
│ WG2 (128 线程,152 regs) — Producer: FP8 gather + 反量化 │
└──────────────────────────────────────────────────────────────────┘

在 SM90 中,warpgroup 的概念很”实”——WGMMA 确实需要 128 线程一起发起。WG0 的 128 线程共同发起 WGMMA (QK^T),然后做 softmax,再共同发起 WGMMA (SV)。WG2 的 128 线程各自 __ldg 加载 FP8 数据并反量化。每个线程都在忙。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
SM100 (Blackwell):
┌──────────────────────────────────────────────────────────────────┐
│ WG0 (128 线程,224 regs) — Softmax: 从 TMEM 读 P → exp2 → 写 S │
│ (纯标量计算,不参与任何 MMA) │
│ │
│ WG1 (128 线程,72 regs) — MMA + Produce: │
│ Warp 4 (1 线程): 发起所有 UTCMMA (QK^T, SV) + TMA 加载 Q │
│ Warp 5 (1 线程): TMA Gather4 加载 FP8 NoPE 到 shared memory │
│ Warp 6 (1 线程): TMA Gather4 加载 BF16 RoPE 到 shared memory │
│ Warp 7 (32 线程): 读 indices → 计算 TMA 坐标 + 加载 scale │
│ │
│ WG2 (128 线程,208 regs) — Dequant: FP8→BF16 反量化 │
│ 从 raw_nope (FP8) → dequant.nope (BF16)│
└──────────────────────────────────────────────────────────────────┘

在 SM100 中,warpgroup 的概念变”虚”了——UTCMMA 只需 1 个线程发起,所以同一个 warpgroup 内的线程可以干完全不同的事。SM100 的 WG1 中,大量线程其实是空闲的(Warp 4/5/6 各只需 1 个线程),但仍然被组织为一个 warpgroup,主要是为了共享寄存器配额和 barrier 同步的方便。

这体现了 Blackwell 的设计哲学:硬件加速器(TMA、UTCMMA)接管数据搬运和矩阵计算,线程只处理无法硬件加速的”缝隙”工作。

5.4.3 关键设计差异

  1. SM90 的 MMA 由 warpgroup (128 线程) 发起; SM100 的 UTCMMA 由 1 个线程发起

    • UTCMMA 是 “widthless” 指令 (tcgen05.mma.ws),只需 1 个线程 issue,Tensor Core 自动完成矩阵计算并写入 TMEM
    • 这解放了大量线程去做其他工作
  2. Softmax 独立成 WG0 (128 线程)

    • SM90: softmax 由 Consumer A (WG0) 在 WGMMA 之间插入
    • SM100: softmax 是 WG0 的唯一任务 — 从 TMEM 读 P、做 exp2、写 S 到 shared memory
    • 好处:softmax 是标量密集计算,给它 224 个寄存器足够存所有中间状态
  3. TMA Gather4 替代线程协作 __ldg

    • SM90: WG2 的 128 个线程各自 __ldg 加载 FP8 数据
    • SM100: WG1 中仅 Warp 5 的 1 个线程发起所有 TMA Gather4,硬件 DMA 完成实际搬运
  4. 四重缓冲索引 (NUM_INDEX_BUFS=4)

    • SM90 的 indices 无显式缓冲
    • SM100 用 4 个 buffer 轮转:indices → TMA 坐标 → scale → valid mask
    • 因为 SM100 的 TMA Gather4 需要预计算 TMA 坐标 (tma_coord),这个过程是异步的

🔧 SM100 的 NUM_INDEX_BUFS=4 vs SM90 的 NUM_K_BUFS=2:
SM100 需要更多 buffer 因为 TMA Gather4 的 pipeline 更深:

  • Warp 7 计算 TMA 坐标 (写 buf N)
  • Warp 5/6 用 TMA Gather4 加载 FP8/RoPE (读 buf N, 写 KV buf M)
  • WG2 反量化 (读 KV buf M)
  • WG1-Warp4 发起 UTCMMA (读反量化后的 KV)
  • WG0 做 softmax (读 TMEM 中的 P)
    每一步都需要独立的 buffer 来实现流水线

第 6 步:TMA 加载 Q 矩阵

Producer (Warpgroup 2 中线程 0) 发起 TMA 拷贝:

1
2
3
// 加载 Q: [64, 576] 的一个 tile 从 global memory → shared memory
launch_tma_copy(tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(64 * 576 * 2); // 期望 64×576×2 = 73728 字节

具体数据:

1
2
3
4
Q tile shape: [64 heads, 576 dims] × BF16 = 73,728 字节
(每个 CTA 加载 64 个 head, 另外 64 个 head 由 cluster 中的 peer CTA 处理)
从 global memory 地址 q_ptr + batch_0 * stride 开始
拷贝到 shared memory sQ (SwizzledLayout_SW128)

🔧 Hopper (SM90) TMA 特性:

  • TMA (Tensor Memory Accelerator): 硬件异步拷贝引擎,无需线程参与
  • 只需 1 个线程 发起,硬件自动完成整个 73KB 拷贝
  • EVICT_FIRST: L2 cache 提示,Q 只用一次所以优先驱逐
  • ClusterTransactionBarrier: TMA 完成后自动减少 barrier 计数
  • Swizzle 布局 (SW128): 128 字节粒度的地址映射,消除 shared memory bank conflict

🔧 Blackwell (SM100) TMA 改进:

  • SM100 新增 TMA Gather4: 可以从 2D tensor 中非连续地 gather 4 行,专为稀疏访问设计
  • SM100 用 TMA 加载 Q 到 shared memory,然后用 UTCCP 将 Q 从 shared memory 搬到 TMEM (Tensor Memory)
  • TMEM 是 Blackwell 新增的 512KB 片上存储,比 shared memory 带宽更高

第 7 步:Producer Warpgroup — FP8 反量化

这是稀疏 decode 最核心的步骤之一。Warpgroup 2 (线程 256-383) 负责:

7.1 加载 top-k 索引

1
2
3
4
5
int* gIndices = params.indices + batch_0 * stride;
// indices = [47, 2091, 583, 12, ...] 共 128 个 int32

// 线程 256 负责 token index 0:
int token_index = __ldg(gIndices + 0); // = 47

7.2 计算物理地址

1
2
3
4
5
6
int block_index = 47 / 64 = 0;           // 第 0 个 page block
int rel_idx_in_block = 47 % 64 = 47; // block 内第 47 个 token

fp8* gK_base = kv_ptr // KV cache 起始地址
+ block_index * stride_kv_block // 跳到第 0 个 block
+ rel_idx_in_block * 656; // 跳到第 47 个 token (每个 656 字节)

7.3 加载缩放因子

1
2
3
4
5
6
7
8
9
10
// 从偏移 512 处加载 4 个 float32 缩放因子 (128 bits)
float scales_float[4];
*(float4*)(scales_float) = load_128b_from_gmem<float4,
L1CacheHint::EVICT_LAST, // 缩放因子可能被复用
L2PrefetchHint::B128 // 预取 128 字节
>(gK_base + 512);

// scales_float = [0.001953, 0.003906, 0.001953, 0.007812]
// 转为 BF16
bf16 scales[4] = {bf16(0.001953), bf16(0.003906), bf16(0.001953), bf16(0.007812)};

🔧 Hopper 硬件特性:

  • load_128b_from_gmem: 通过内联 PTX 汇编实现 128-bit 全局内存加载,附带精确的 L1/L2 cache 控制
  • PTX: ld.global.nc.L1::evict_last.L2::128B.v4.s32 {%0,%1,%2,%3}, [%4];
  • .nc = non-coherent (只读),.L1::evict_last = 在 L1 中最后被驱逐(优先保留)
  • 这些 cache hint 对 Hopper 和 Blackwell 都有效

7.4 FP8 → BF16 反量化(核心步骤)

每个线程处理 16 个 FP8 值:

1
2
3
4
5
// 加载 16 个 FP8 值 (128 bits)
fp8x16 cur_fp8x16 = load_128b_from_gmem<fp8x16,
L1CacheHint::EVICT_LAST,
L2PrefetchHint::B256 // 预取 256 字节 (一次性加载整个 NoPE)
>(gK_nope + dim_idx * 64);

反量化公式:

1
BF16_value = float(FP8_value) × scale

具体例子 (dim_idx=0, Group 0 的前 8 个值):

1
2
3
4
FP8 原始值:  [128.0, -256.0, 64.0, -32.0, 192.0, -16.0, 384.0, -448.0]
scale_0 = 0.001953

BF16 结果: [0.2500, -0.5000, 0.1250, -0.0625, 0.3750, -0.0312, 0.7500, -0.8750]

C++ 实现 (components/dequant.h):

1
2
3
4
5
6
7
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale_bf162) {
// 每次处理 4 个 FP8
float4 fp32x4 = (float4)(inputs.lo); // FP8 → FP32 (硬件隐式转换)
bf16x2 out_lo = __float22bfloat162_rn({fp32x4.x, fp32x4.y}) * scale_bf162; // FP32 → BF16 × scale
bf16x2 out_hi = __float22bfloat162_rn({fp32x4.z, fp32x4.w}) * scale_bf162;
// ... 共 8 个值
}

🔧 Hopper 硬件特性:

  • (float4)(fp8x4): Hopper Tensor Core 原生支持 FP8→FP32 的类型转换
  • __float22bfloat162_rn(): 硬件 FP32→BF16 转换,round-to-nearest
  • 每个线程处理 16 个值,128 线程共处理 2048 个值 → 4 组 × 512/4 = 32 个线程覆盖一个 token 的 NoPE

🔧 Blackwell (SM100) 改进:

  • SM100 新增 FP8_E8M0FNU 格式缩放因子(MODEL1 使用),用 __nv_cvt_e8m0x2_to_bf162raw 硬件指令转换
  • SM100 的反量化在 Warpgroup 2 中完成后,结果直接写入 shared memory,随后 UTCMMA 从 shared memory 读取

7.5 写入共享内存

1
2
// 写入 sK (shared memory), 使用 interleaved 布局
*(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;

对于 CLUSTER_SIZE=2 的情况(h_q=128),还需写入 peer CTA 的共享内存:

1
2
3
if constexpr (CLUSTER_SIZE == 2) {
st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
}

🔧 Hopper Cluster 特性:

  • get_peer_addr(): 通过 XOR 地址高位(16MB 偏移)直接访问 cluster 中邻居 SM 的 shared memory
  • st_async_128b: 异步写入 + mbarrier 完成通知
  • PTX: st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [dst], {data}, [mbar]
  • 这允许两个 CTA 共享反量化后的 K 矩阵,无需经过 global memory

7.6 加载 RoPE 部分(不量化)

1
2
3
4
5
6
7
8
// RoPE 部分直接从 global memory 加载 BF16,不需反量化
bf16x8 cur_bf16x8 = load_128b_from_gmem<bf16x8,
L1CacheHint::EVICT_LAST,
L2PrefetchHint::B128
>(gK_rope + dim_idx * 32);

// 直接写入 shared memory
*(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;

7.7 设置有效性标记

1
2
3
4
5
6
// 线程 0-31 检查索引有效性
if (idx_in_warpgroup < 32) {
int2 indices = __ldg((int2*)(indices_base + lane_idx*2));
plan.is_kv_valid[buf_idx][lane_idx*2] = (indices.x != -1);
plan.is_kv_valid[buf_idx][lane_idx*2+1] = (indices.y != -1);
}

7.8 通知 Consumer

1
2
3
fence_view_async_shared();                      // 确保所有共享内存写入对 async proxy 可见
plan.bar_k_local_ready[buf_idx].arrive(); // 通知 Consumer: K 数据已就绪
bar_phase_k ^= 1 << buf_idx; // 翻转 phase (双缓冲)

🔧 Hopper 硬件特性:

  • fence_view_async_shared(): 确保 shared memory 写入对 TMA/async proxy 可见的 fence 指令
  • Transaction Barrier: bar_k_local_readyClusterTransactionBarrier,128 线程共同 arrive,Consumer 的 wait 才会通过
  • 双缓冲 (NUM_K_BUFS=2): Producer 写 buf 0 的同时 Consumer 读 buf 1,实现流水线

第 8 步:QK^T 矩阵乘法

8.1 SM90 (Hopper) — WGMMA

Consumer A (Warpgroup 0) 等待 K 就绪后执行:

1
2
3
4
5
6
7
// 等待 Producer 完成反量化
plan.bar_k_local_ready[buf_idx].wait(bar_phase_k >> buf_idx & 1);

// Cluster 模式下还需等待 peer CTA
if constexpr (CLUSTER_SIZE == 2) {
plan.bar_k_remote_ready[buf_idx].wait(...);
}

WGMMA 计算 P = Q · K^T:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// WGMMA SS 模式:Q 和 K 都在 shared memory
// Q: [64, 512] BF16 (从 TMA 加载)
// K: [64, 512] BF16 (从反量化后的 dequant_nope buffer)

// Warpgroup 0 (Consumer A) 发起 WGMMA
tiled_mma_QK.tiled_mma()
.gemm<true, -1>(sQ, sK, rP);

// PTX:
// wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16
// d (rP, FP32 累加器),
// a (sQ, shared memory),
// b (sK, shared memory),
// ...

// 结果:P = Q·K^T / sqrt(d_qk)
// shape: [64 heads, 64 tokens] FP32
// 每个元素 P[h][t] = Q[h] · K[t] / sqrt(576)

具体计算 (以 head 0, token 0 为例):

1
2
3
4
5
6
7
8
Q[head_0] = [q_0, q_1, ..., q_575] (576 维:512 维 NoPE + 64 维 RoPE)
K[token_0] = [k_0, k_1, ..., k_575] (反量化后的 576 维)

dot_product = Σ(q_i × k_i) for i in 0..575
= q_0×k_0 + q_1×k_1 + ... + q_575×k_575

假设:Q[head_0] · K[token_0] = 12.5
P[head_0][token_0] = 12.5 / sqrt(576) = 12.5 / 24 = 0.5208

🔧 WGMMA 硬件特性:

  • 异步执行: wgmma.commit() 发起后,线程可以继续做其他工作,用 wgmma.wait() 同步
  • 多 stage 流水线: FlashMLA 用 2-3 个 buffer 轮转,WGMMA 计算时同时加载下一批 K/V
  • Cluster 协作: h_q=128 时,2 个 CTA 组成 cluster,每个 CTA 处理 64 个 head,通过 XOR 寻址共享数据

8.2 SM100 (Blackwell) — UTCMMA

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// SM100: UTCMMA TS 模式 — Q 在 TMEM, K 在 shared memory
// Q: 已从 shared memory 通过 UTCCP 搬到 TMEM (tmem_cols::Q)
// K: 从 TMA Gather4 加载到 shared memory

// Warpgroup 1 中 Warp 4 的 1 个线程发起 UTCMMA
cute::elect_one_sync() {
ku::utcmma_ts(
tiled_mma_P,
tQ_in_tmem, // TMEM 中的 Q
sK_in_smem, // shared memory 中的 K
tP_in_tmem, // TMEM 中的 P 累加器
true // accumulate = true
);
}

// PTX:
// tcgen05.mma.ws.cta_group::1.kind::f16
// [tmem_c], [tmem_a], smem_desc_b, ...

SM100 vs SM90 的关键区别:

SM90 WGMMA SM100 UTCMMA
Q 位置 shared memory TMEM (512KB 片上存储)
K 位置 shared memory shared memory
累加器位置 寄存器 (rP) TMEM (tP)
发起者 128 线程 (warpgroup) 1 线程 (elect_one_sync)
指令前缀 wgmma.mma_async tcgen05.mma.ws
带宽 shared memory → Tensor Core TMEM → Tensor Core (更高)

🔧 TMEM 优势:

  • TMEM 带宽 ~20TB/s,远高于 shared memory 的 ~8TB/s
  • Q 常驻 TMEM 后,每个 KV block 的 QK^T 都不需要重新加载 Q
  • 累加器放在 TMEM 中,减轻寄存器压力 (SM100 WG1 仅需 72 regs,SM90 WG0 需 192 regs)

第 9 步:Online Softmax

这是注意力计算的关键 — 在不知道全部 K 的情况下,增量式计算 softmax。

9.1 掩码无效 token

1
2
3
4
for (int i = 0; i < size(cur_rP); ++i) {
if (!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2])
cur_rP(i) = -INFINITY; // 无效索引 → -∞
}

9.2 求行最大值 (warp 内规约)

1
2
3
4
5
6
7
float cur_max = -INFINITY;
for (int i = 0; i < size(cur_rP); ++i)
cur_max = max(cur_max, cur_rP(i));

// Warp 内规约:__shfl_xor_sync 做 butterfly reduction
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));

具体数值 (head 0, 第 1 个 block 的 64 个 KV token):

1
2
3
4
P[0][0..63] = [3.52, -1.23, 2.17, 0.89, ..., -0.45]

cur_max = 3.52 (行最大值)
cur_max_scaled = 3.52 × 0.06010 (= sm_scale × log2e) = 0.2116

9.3 更新 running max 和 rescale

1
2
3
4
cur_max *= scale_softmax_log2;  // = 0.2116
float old_max = rM[row]; // 第一个 block: old_max = -1e30 (初始值)
rM[row] = max(cur_max, old_max); // new_max = 0.2116
float scale_for_old = exp2f(old_max - new_max); // ≈ 0 (因为 old_max = -1e30)

关键: 第一个 block 时 scale_for_old ≈ 0,所以之前累积的 O 被清零,这是正确的(因为之前没有累积值)。

9.4 Rescale O 并计算 exp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// Rescale 旧的 O
for (int i = 0; i < size(cur_rO); ++i)
cur_rO(i) *= scale_for_old; // 第一个 block: O *= 0 → O = 0

// 计算 exp 并求和
float cur_sum = 0;
for (int i = 0; i < size(cur_rP); ++i) {
cur_rP(i) = exp2f(cur_rP(i) * scale_softmax_log2 - new_max);
// P[0][0] = exp2(3.52 * 0.06010 - 0.2116) = exp2(0) = 1.0
// P[0][1] = exp2(-1.23 * 0.06010 - 0.2116) = exp2(-0.2855) = 0.820
cur_rS(i) = (bf16)cur_rP(i); // 转为 BF16 供 PV 乘法使用
cur_sum += cur_rP(i);
}

// 更新 running L (exp-sum)
rL[row] = rL[row] * scale_for_old + cur_sum;
// = 0 * 0 + (1.0 + 0.820 + ... ) = 42.5 (假设值)

🔧 为什么用 exp2 而不是 exp?

  • exp2f()expf()~2x,因为 Hopper 的 SFU (Special Function Unit) 原生支持 base-2 指数
  • 公式:exp(x * scale) = exp2(x * scale * log2(e)) = exp2(x * scale_softmax_log2)
  • 这是一个经典的 GPU 性能优化,Hopper 和 Blackwell 都受益

9.5 保存 scale factor 到共享内存

1
2
if (idx_in_warpgroup % 4 == 0)
*(float2*)(sScale + ...) = *(float2*)(scale_for_olds);

🔧 Hopper Named Barrier 特性:

  • 计算完 softmax 后,通过 NamedBarrier::arrive(256, sScale_and_sS_ready) 通知 Warpgroup 1
  • Named Barrier 是 Hopper 新增功能,允许 warpgroup 之间细粒度同步

第 10 步:SV 矩阵乘法 — O += S · V

10.1 Warpgroup 0 — V_left (前 256 维)

Softmax 结果 S 已在寄存器中,V 的左半部分 (前 256 维) 在 shared memory:

1
2
3
4
5
6
gemm<false, -1>(  // zero_init=false (累加到现有 rO)
tiled_mma_PV, // TiledMMA: 64×256×16 F32BF16BF16 RS
rS, // S: [64, 64] in registers
thr_mma_PV.partition_fragment_B(sV), // V_left: [256, 64] in shared memory
rO // O: [64, 256] in registers (FP32 累加器)
);

具体计算:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
O_left[64×256] += S[64×64] × V_left[64×256]

对于 head 0:
O_left[0, 0:255] += S[0, 0:63] × V[0:63, 0:255]
= 1.0 × V[47, 0:255] (token 47 的 softmax weight = 1.0)
+ 0.820 × V[2091, 0:255] (token 2091 的 weight = 0.820)
+ ...

WGMMA 将此分解为 64/16 = 4 次外积:
for k = 0 to 3:
O_left += S[:, k*16:(k+1)*16] × V_left[k*16:(k+1)*16, :]

每次外积:
输入:A[64×16] FP32 (S), B[16×256] BF16 (V)
累加:C[64×256] FP32 (O)

🔧 Hopper WGMMA RS 模式:

  • RS (Register-Shared): S 在寄存器,V 在 shared memory
  • 这比 SS 模式快,因为 S 刚从 softmax 计算出来,就在寄存器中,不需要写回 shared
  • 每次 WGMMA shape: M=64, N=256, K=16

第 10 步:Warpgroup 1 — O += S · V_right

同时,Warpgroup 1 (线程 128-255) 处理 V 的右半部分:

10.2.1 等待 S 和 scale factor

1
NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready);

10.2.2 Rescale 自己的 O

1
2
3
4
5
float cur_scales[2];
*(float2*)cur_scales = *(float2*)(sScale + ...); // 从共享内存读取 scale factor

for (int i = 0; i < size(cur_rO); ++i)
cur_rO(i) *= cur_scales[row];

10.2.3 WGMMA 计算

1
2
3
4
5
6
gemm<false, -1>(
tiled_mma_PV, // TiledMMA_PV_RemoteP: 64×256×16 SS 模式
thr_mma_PV.partition_fragment_A(sS), // S: [64, 64] from shared memory
thr_mma_PV.partition_fragment_B(sV), // V_right: [256, 64] from shared memory
rO // O_right: [64, 256] in registers
);

🔧 为什么 Warpgroup 1 用 SS 而不是 RS?

  • Warpgroup 1 没有参与 softmax 计算,S 不在它的寄存器中
  • S 由 Warpgroup 0 通过 save_rPb_to_sP 写入 shared memory (sS)
  • 所以 Warpgroup 1 必须从 shared memory 读取 S → SS 模式

10.2.4 通知 Producer 可以复用 buffer

1
plan.bar_k_avail[buf_idx].arrive(); // 告诉 Producer: 我已经用完这个 K buffer

第 11 步:循环处理第 2 个 block

回到第 8 步,Producer 开始加载第 2 个 block(token index 64-127)。

由于使用双缓冲 (NUM_K_BUFS=2),Producer 写 buf 1 的同时 Consumer 还在读 buf 0:

1
2
3
4
5
6
7
时间线:
┌─ Producer: 写 K[block 0] → buf 0 ─┐ ┌─ Producer: 写 K[block 1] → buf 1 ─┐
│ │ │ │
└───────────────────────────────────┘ └───────────────────────────────────┘
┌─ Consumer: 读 buf 0, 计算 P, S, O ─┐ ┌─ Consumer: 读 buf 1 ─┐
│ │ │ │
└───────────────────────────────────┘ └──────────────────────┘

第 2 个 block 的 Online Softmax 更新:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
第 2 个 block 的 P 值 (head 0):
P_new[0][0..63] = [1.75, 0.33, -2.10, ..., 0.91]

new_cur_max = 1.75 × 0.06010 = 0.1052
old_max = 0.2116 (来自第 1 个 block)
new_max = max(0.1052, 0.2116) = 0.2116 (没有变化!)

scale_for_old = exp2(0.2116 - 0.2116) = exp2(0) = 1.0
→ 旧的 O 不需要 rescale

exp 计算:
P_new[0][0] = exp2(1.75 × 0.06010 - 0.2116) = exp2(-0.1064) = 0.929
P_new[0][1] = exp2(0.33 × 0.06010 - 0.2116) = exp2(-0.1917) = 0.876

更新 running L:
rL = rL * 1.0 + (0.929 + 0.876 + ...) = 42.5 + 38.2 = 80.7

🔧 双缓冲是 GPU 流水线的经典模式,不特定于某个架构,但 Hopper 的 Transaction Barrier 让同步更高效。


第 12 步:跨 Warpgroup 的 L 规约

所有 block 处理完后,需要合并两个 Warpgroup 的 L (exp-sum):

1
2
3
4
5
6
7
8
9
10
11
// Warp 内规约
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);

// 写到共享内存
if (idx_in_warpgroup % 4 == 0) {
sL[row] = rL[i];
sM[row] = rM[i];
}

数值:

1
2
最终 rL[head 0] = 80.7 (所有 128 个 token 的 exp-sum)
最终 rM[head 0] = 0.2116 (全局最大值,在 log2 空间)

🔧 硬件特性: __shfl_xor_sync 是 warp shuffle,所有 CUDA GPU 支持。跨 warpgroup 通过 shared memory 通信。


第 13 步:Attention Sink 处理

如果提供了 attn_sink(DeepSeek 中用于 sink token 的预计算注意力值):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
if (params.attn_sink != nullptr) {
float attn_sink_log2 = __ldg(params.attn_sink + head_idx) * CUDART_L2E_F;
// attn_sink[head 0] = -2.5 → attn_sink_log2 = -2.5 × 1.4427 = -3.607
}

// 计算最终的 output scale (包含 attention sink)
if (args.is_no_split) {
o_scales[i] = 1.0 / (rL[i] + exp2f(rAttn_sink[i] - rM[i]));
// = 1.0 / (80.7 + exp2(-3.607 - 0.2116))
// = 1.0 / (80.7 + exp2(-3.818))
// = 1.0 / (80.7 + 0.0709)
// = 1.0 / 80.77
// = 0.01238
}

Attention Sink 的含义: 假装有一个额外的 “sink” token,其注意力分数是固定的 attn_sink 值。这让模型可以把一部分注意力”倒掉”,避免所有注意力都集中在 top-k token 上。


第 14 步:输出写回

14.1 Warpgroup 0: O_left 写入分 split/no-split 两种情况

No-split (is_no_split=true, 我们的例子):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 将 FP32 寄存器值缩放并转为 BF16
bf16x2 a01 = __float22bfloat162_rn({rO(0) * o_scale, rO(1) * o_scale});
// 例:rO(0)=15.6, o_scale=0.01238 → bf16(0.1931)

// 通过 STSM (Store Matrix) 指令写入 shared memory
SM90_U32x4_STSM_N::copy(a01, a23, a45, a67, smem_ptr);

// 所有线程同步
NamedBarrier::arrive_and_wait(256, epilogue_r2s_ready);

// 线程 0 通过 TMA 将 shared memory → global memory
if (threadIdx.x == 0) {
SM90_TMA_STORE_5D::copy(&tensor_map_o, plan.u.oBuf.data(), ...);
cute::tma_store_arrive();
}

🔧 Hopper TMA Store 特性:

  • SM90_TMA_STORE_5D: 5 维 TMA 存储,可以直接从 shared memory 写到 global memory 的任意 5D 位置
  • 只需 1 个线程发起,硬件自动完成
  • 使用 CUtensorMap 描述符,包含 swizzle 信息和地址映射

🔧 Blackwell 改进:

  • SM100 仍然使用 TMA Store,但 TMEM → shared → global 的路径更短
  • SM100 的 TMA 吞吐量更高

Split 情况 (如果一个 batch 被分到多个 SM partition):

1
2
3
// 写入 o_accum (FP32),而不是最终 output (BF16)
// 不做 BF16 转换,保持 FP32 精度
SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, 0), &gOAccum(row, 0), 512*sizeof(float));

14.2 写入 LSE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int i = threadIdx.x;
if (i < num_valid_heads) {
float cur_L = sL[i];
if (is_no_split) {
// 最终 LSE = ln(L) + M / log2(e)
gSoftmaxLse[i] = cur_L == 0.0f ? INFINITY : logf(cur_L) + sM[i] / M_LOG2E;
// = log(80.7) + 0.2116 / 1.4427
// = 4.391 + 0.1467
// = 4.538
} else {
// Split: 保持 log2 空间供 combine kernel 使用
gSoftmaxLseAccum[i] = cur_L == 0.0f ? -INFINITY : log2f(cur_L) + sM[i];
}
}

第 15 步:PDL — 提前启动 Combine Kernel

1
2
3
4
5
// 最后一个 batch 处理完毕
if (batch_idx == sched_meta.end_req_idx) {
cudaTriggerProgrammaticLaunchCompletion();
// 告诉 CUDA 运行时:这个 CTA 的 combine kernel 依赖已满足
}

🔧 Hopper PDL (Programmatic Dependent Launch) 特性:

  • 传统 CUDA: kernel B 必须等 kernel A 的所有 CTA 完成才能启动
  • PDL: kernel B 可以在 kernel A 的部分 CTA 完成后就启动
  • 对于我们的例子:SM Partition 0 (batch 0) 完成后,batch 0 的 combine 可以立即开始,不需等 batch 1
  • 通过 cudaLaunchAttributeProgrammaticStreamSerialization 启用
  • Blackwell 完全支持此特性

第 16 步:Combine Kernel — 合并 Split 结果

文件: csrc/smxx/decode/combine/combine.cu

对于我们的例子(每个 batch 只有 1 个 split,is_no_split=true),combine kernel 直接 return:

1
if (my_num_splits == 1) return; // 无需合并!

但当 batch 很大或序列很长时(例如 topk=8192, num_sm_parts=132),一个 batch 的 KV 会被切分到多个 SM partition,此时 combine kernel 的工作如下:

假设有 3 个 split 的情况:

1
2
3
4
5
6
Grid: [batch_size=2, s_q=1, ceil(128/8)=16]
Block: 256 线程 = 8 warps, 每个 warp 处理 1 个 head

Split 0: lse_accum[0] = 4.2, o_accum[0] = [0.15, -0.23, ...] (FP32)
Split 1: lse_accum[1] = 3.8, o_accum[1] = [0.12, -0.31, ...] (FP32)
Split 2: lse_accum[2] = 2.1, o_accum[2] = [0.08, -0.11, ...] (FP32)

Step 1: 求全局 max_lse (warp 内 shuffle 规约)

1
max_lse = max(4.2, 3.8, 2.1) = 4.2

Step 2: 求 sum_lse

1
2
3
sum_lse = exp2(4.2-4.2) + exp2(3.8-4.2) + exp2(2.1-4.2)
= 1.0 + 0.7579 + 0.1353
= 1.8932

Step 3: 加权合并 O

1
O_final = (1.0/1.8932) × [1.0×o_accum[0] + 0.7579×o_accum[1] + 0.1353×o_accum[2]]

Step 4: 写回结果

1
2
// 将合并后的 O 写入 global memory
*(float4*)(gO + head_idx * stride) = *(float4*)&o_final;

第 17 步:完整流水线时间线

流水线时间线:

1
2
3
4
5
6
7
8
9
10
11
12
时间 →
Producer (WG2):
║ 加载索引 ║ 加载 FP8+ 反量化 block0 buf0 ║ 加载 FP8+ 反量化 block1 buf1 ║ 加载下一 Q ║
║ LDG ║ LDG + CVT_FP8 + STSM ║ LDG + CVT_FP8 + STSM ║ TMA ║

Consumer A (WG0):
║ 等待 buf0 ║ QK^T WGMMA ║ softmax ║ SV WGMMA ║ 等待 buf1 ║ QK^T ║ softmax ║ SV ║ 写回 ║
║ ║ 36 cycles ║ ~20c ║ ~16c ║ ║ 36c ║ ~20c ║16c║ TMA ║

Consumer B (WG1):
║ 等待 S ║ rescale O ║ SV WGMMA ║ 等待 S ║ rescale ║ SV ║ 写回 ║
║ ~5c ║ ~16c ║ ║ ~5c ║16c ║ ║ TMA ║

第 18 步:SM90 (Hopper) vs SM100 (Blackwell) 对比总结

特性 SM90 (Hopper, H100/H800) SM100 (Blackwell, GB200)
矩阵乘法 WGMMA (SS/RS 模式) UTCMMA (TS/SS 模式,tcgen05 指令)
Q 存储位置 Shared Memory (SW128 布局) TMEM (512KB 片上存储,零拷贝)
K 加载 线程协作 LDG + 反量化 TMA Gather4 (硬件稀疏 gather)
Warpgroup 划分 3 组:128+128+128 线程 3 组:32+192+160 线程 (更不均匀)
寄存器分配 192/160/152 regs 224/72/208 regs (更极端的再分配)
跨 SM 通信 Cluster Shared Memory (XOR 寻址) 同上 + 更大 cluster (最多 16 SM)
Barrier Named Barrier + Transaction Barrier 同上 + TCGen05 fence 指令
缩放因子格式 float32 (V3.2) float32 (V3.2) / FP8_E8M0FNU (MODEL1)
PDL cudaTriggerProgrammaticLaunchCompletion 同上
TMA 2D Block Copy 2D Block Copy + Gather4 (稀疏模式)

Blackwell 关键创新:

  1. TMEM (Tensor Memory): 512KB 的新增片上存储,带宽高于 shared memory。Q 矩阵常驻 TMEM,避免反复从 shared memory 读取
  2. UTCMMA: 新一代 Tensor Core 指令,支持 TMEM 直接作为操作数源
  3. TMA Gather4: 硬件实现的 2D 稀疏 gather,比线程协作的 __ldg 更高效
  4. 更灵活的 Warpgroup 分工: 32 线程专门做 softmax(不参与 MMA),减少 register pressure

第 19 步:最终输出

1
2
3
4
5
out shape: [2, 1, 128, 512] BF16 (h_q=128 来自 config.json)
lse shape: [2, 128, 1] FP32

out[0, 0, 0, :] = [0.1931, -0.2847, 0.0523, ..., -0.1234] (batch 0, head 0 的 512 维输出)
lse[0, 0, 0] = 4.538 (batch 0, head 0 的 log-sum-exp)

附录:关键文件索引

文件 内容
flash_mla/flash_mla_interface.py Python API,路由到 dense/sparse
csrc/api/sparse_decode.h C++ 接口,架构分发,参数准备
csrc/params.h 所有参数结构体定义
csrc/smxx/decode/get_decoding_sched_meta/ Tile scheduler,工作分配
csrc/sm90/decode/sparse_fp8/config.h SM90 kernel 配置和共享内存布局
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh SM90 三 warpgroup kernel 实现
csrc/sm90/decode/sparse_fp8/components/dequant.h FP8 反量化实现
csrc/sm90/decode/sparse_fp8/components/helpers.h WGMMA helper, cluster async ops
csrc/sm100/decode/head64/config.h SM100 kernel 配置 (TMEM layout)
csrc/sm100/decode/head64/kernel.cuh SM100 三 warpgroup kernel 实现
csrc/kerutils/.../sm100/gemm.cuh UTCMMA 内联 PTX
csrc/kerutils/.../sm100/intrinsics.cuh TMA Gather4, TMEM ops
csrc/smxx/decode/combine/combine.cu Split-KV 结果合并
tests/quant.py FP8 量化/反量化参考实现
tests/ref.py 纯 PyTorch 参考实现

总结

本文以 DeepSeek-V3.2 的真实配置为基础,通过一个最小化的数值例子(b=2, topk=128),逐步展示了 FlashMLA Sparse Decode 的完整计算流程:

  1. Python 层 路由到 sparse_decode_fwd
  2. C++ 层 验证输入并打包参数结构体
  3. Tile Scheduler 将工作均匀分配到 66 个 SM partition
  4. FP8 KV Cache 以 656 字节/token 的格式存储(512 FP8 + 16 scales + 128 BF16 RoPE)
  5. CUDA Kernel 根据架构选择 SM90 (Cluster + WGMMA) 或 SM100 (TMEM + UTCMMA + TMA Gather4)
  6. TMA 加载 Q 矩阵到 shared memory / TMEM
  7. FP8 反量化 — Producer warpgroup 从 KV cache gather token 并反量化为 BF16
  8. QK^T 矩阵乘法 — WGMMA (SM90) 或 UTCMMA (SM100) 计算注意力分数 P = Q·K^T
  9. Online Softmax — 增量式 softmax,exp2f 优化 + LSE 合并
  10. SV 矩阵乘法 — O = S · V,Warpgroup 0/1 并行计算左右两半
  11. 结果写回与 Combine — 写回 O/LSE,merge kernel 合并 split 结果(如果需要)

关键硬件优化:

  • SM90 (Hopper): Cluster 协作、WGMMA 异步矩阵乘法、TMA 加载
  • SM100 (Blackwell): TMEM 片上存储、UTCMMA 指令、TMA Gather4 硬件稀疏 gather

性能优势:

  • FP8 量化减少 43% 的内存带宽消耗
  • DSA 稀疏注意力将 O(n²) 复杂度降为 O(n·k)
  • 在 128K 上下文长度下可实现 3 倍更快的推理速度

展望:未深入探索的方向

本文聚焦于 Sparse Decode 路径(SM90 + SM100),以下方向值得进一步研究:

  • Dense Decode 路径 (csrc/sm90/decode/dense/splitkv_mla.cuh): 当上下文长度 ≤ index_topk 时走 dense 路径,不需要稀疏索引。其调度和 KV 访问模式与 sparse 有本质区别
  • Prefill 路径 (dense prefill + sparse prefill): prefill 阶段处理完整的 prompt 输入,是 compute-bound(而非 decode 的 memory-bound),kernel 设计思路完全不同
  • SM100 head64x2 / head128 实现: 当 h_q > 64 时 Blackwell 如何处理?SM90 通过 Cluster (2 CTA 各 64 head) 解决,SM100 的方案尚未在 head64 实现中体现
  • CLC (Cluster Launch Control): SM100 prefill 中使用的新特性,允许更灵活的 cluster 调度
  • 性能调优细节: bank conflict 避免策略、L2 cache 分区、占用率 (occupancy) 分析
  • 端到端 benchmark: 不同序列长度 / batch size / topk 下的 roofline 分析,识别瓶颈是在 gather、MMA 还是 softmax

参考资料

  1. DeepSeek-V3.2 config.json: https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/config.json
  2. FlashMLA GitHub: https://github.com/deepseek-ai/FlashMLA
  3. NVIDIA Hopper Architecture: https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/
  4. NVIDIA Blackwell Architecture: https://developer.nvidia.com/blog/nvidia-blackwell-architecture-in-depth/
  5. 前作:《FlashMLA 深度解析:FP8 KV Cache 与 DSA 稀疏注意力实现原理》https://ggaaooppeenngg.github.io