FlashMLA Sparse Decode 完整计算过程详解
FlashMLA Sparse Decode 完整计算过程详解
本文以一个最小化的具体数值例子,逐步展示 FlashMLA Sparse Decode 的每一个计算步骤,并标注每一步与 Hopper (SM90) / Blackwell (SM100/GB200) 硬件特性的关系。
0. 问题设定
参考配置
本文的模型参数来自 DeepSeek-V3.2 的官方 Hugging Face 配置:
以下是从 config.json 中提取的注意力相关参数,以及它们在 FlashMLA 内核中的对应关系:
| config.json 参数 | 值 | FlashMLA 内部参数 | 说明 |
|---|---|---|---|
| kv_lora_rank | 512 | d_nope = 512 | KV 的 LoRA 压缩维度,即 NoPE (Non-Positional Encoding) 部分的维度。在 MLA 中,KV cache 存储的是 LoRA 压缩后的向量,而非原始的 K/V |
| qk_rope_head_dim | 64 | d_rope = 64 | RoPE (Rotary Position Embedding) 部分的维度 |
| 合计 | 512 + 64 = 576 | d_qk = 576 | Q/K 的总 head dimension = kv_lora_rank + qk_rope_head_dim |
| kv_lora_rank | 512 | d_v = 512 | V 的 head dimension,等于 LoRA rank(在 MLA 中 V 和 K 的 NoPE 部分共享同一个压缩向量) |
| num_attention_heads | 128 | h_q = 128 | Query head 数量 |
| num_key_value_heads | 128 | — | config 中 KV head 数也是 128,但在 MLA 架构中,KV cache 只存储 1 份压缩向量(所有 head 共享),所以 FlashMLA 中 h_kv = 1 (MQA 模式) |
| qk_nope_head_dim | 128 | — | 这是模型层面每个 head 的 NoPE 维度(128 × 128 heads = 16384 → 再经 LoRA 压缩为 512)。FlashMLA 操作的是压缩后的 512 维向量,不直接使用此参数 |
| v_head_dim | 128 | — | 同上,这是模型层面每个 head 的 V 维度。FlashMLA 操作的是压缩后的 512 维向量 |
| quantization_config.fmt | “e4m3” | FP8_E4M3 | KV cache 的 NoPE 部分使用 FP8 E4M3 格式量化 |
| quantization_config.scale_fmt | “ue8m0” | UE8M0 (V3.2) / FP8_E8M0FNU (MODEL1) | 量化缩放因子的格式,纯 2 的幂 |
| index_topk | 2048 | topk | 每个 query 关注的 top-k KV token 数量(稀疏注意力) |
| index_n_heads | 64 | — | 参与稀疏索引选择的 head 数量 |
| index_head_dim | 128 | — | 用于索引选择的 head 维度 |
MLA (Multi-head Latent Attention) 的核心思想:不存储 128 个 head × 128 维的完整 KV cache (= 16384 维),而是存储一个 512 维的 LoRA 压缩向量 + 64 维的 RoPE 向量 (= 576 维)。这将 KV cache 压缩了 28 倍 (16384 → 576)。FlashMLA 的所有计算都在这个压缩空间中进行。
本文使用的示例参数
我们使用与 config.json 一致的模型配置,但将 batch size 和 topk 缩小以便手算:
模型配置 (来自 config.json):
1 | d_qk = 576 = kv_lora_rank(512) + qk_rope_head_dim(64) |
示例 Batch 配置 (缩小规模以便展示):
1 | b = 2 (batch size) |
硬件配置 (H800 / SM90):
1 | num_sms = 132 (SM 数量) |
注意: 实际部署中 topk=2048,h_q=128,这意味着每个 query token 需要从 KV cache 中 gather 2048 个 token 并做注意力计算。本文为了展示清晰,将 topk 缩小为 128。
为了简化展示,我们只跟踪 2 个 token(token 0 和 token 1)在一个 head 上的完整计算,其余的计算过程完全相同。
第 1 步:Python 层入口
1 | sched_meta, _ = flash_mla.get_mla_metadata() # 创建空的调度元数据 |
关于 dense vs sparse 的选择: FlashMLA 本身不决定使用哪条路径。当调用方(vLLM/SGLang 等推理框架)传入
indices参数时走 sparse 路径;不传indices而传入block_table + cache_seqlens时走 dense 路径。当上下文长度 ≤ index_topk(2048) 时,top-k 等于全部 token 数,稀疏无意义,框架层面应直接调用 dense 路径。
关键路由逻辑 (flash_mla_interface.py 第 151 行):
1 | if topk is not None: |
→ 无硬件特性依赖,纯 Python 路由
第 2 步:C++ 接口层 — 输入验证与参数准备
文件: csrc/api/sparse_decode.h
2.1 架构检测与实现选择
1 | Arch arch = Arch(); |
🔧 硬件特性:
cudaGetDeviceProperties获取 compute capability。SM90 = Hopper, SM100 = Blackwell。
2.2 FP8 KV Cache 形状验证
1 | // V3.2: 每个 token 占 656 字节 |
656 字节的内存布局:
1 | ┌─────────────────────────────────────────────────────────┐ |
🔧 硬件特性: FP8 (E4M3) 是 Hopper/Blackwell 原生支持的数据类型。Hopper 的 Tensor Core 可以直接做 FP8 矩阵乘法,但 FlashMLA 选择先反量化为 BF16 再做 WGMMA,以获得更高精度。
2.3 填充参数结构体
这一步是”翻译层”— 把 PyTorch tensor 的元信息(shape、stride、data_ptr)和标量参数,打包成 CUDA kernel 能直接使用的 C struct。kernel 内部只看这个 struct,不再接触 PyTorch 对象。
1 | SparseAttnDecodeParams params = { |
→ 无硬件特性依赖,纯数据结构准备
第 3 步:Tile Scheduler — 工作分配
文件: csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
这一步决定了 哪个 SM 处理哪些 batch 的哪些 KV block。
3.1 计算每个 batch 的处理块数
⚠️ “block” 概念澄清: 这里的 “block” 不是 KV cache page block(64 个连续 token 组成的页面),而是对 top-k indices 列表的分块。
top-k 选择是 token 级别的 —
indices数组中存储的是 2048 个(本例中 128 个)散落在 KV cache 各处的 token 的绝对位置,这些 token 之间通常不连续。由于 kernel 无法一次性处理所有 top-k token,所以按
TOPK_BLOCK_SIZE=64将 indices 列表切块处理:
| 概念 | Dense Decoding 的 “block” | Sparse Decoding 的 “block” |
|---|---|---|
| 含义 | KV cache page block(64 个连续 token 组成的页面) | indices 列表的处理块(64 个不连续 token index 为一组) |
| 数据局部性 | 好(连续访问 global memory) | 差(每个 token 可能在不同 page,需要随机 gather) |
| 计算方式 | 按 block_table 顺序遍历所有页面 | 按 indices 逐 token gather,每个 index 指向 KV cache 中的任意位置 |
1 | Launch: <<<1, 32>>> (单 CUDA block, 32 线程,即 1 个 warp) |
🔧 硬件特性: 使用
__shfl_xor_sync做 warp 内规约求和 — 这是 CUDA warp shuffle 指令,所有 NVIDIA GPU 支持。注意: FP8 量化将每个 token 从 1152 字节 (576×BF16) 压缩到 656 字节,减少 ~43% 的内存带宽消耗。Decode 阶段是 memory-bound(瓶颈在显存带宽而非计算),因此减少每 token 的字节数可以直接提升吞吐量。
3.2 背景:SM (Streaming Multiprocessor) 是什么
SM 是 NVIDIA GPU 的基本计算单元。可以把 GPU 理解为一个拥有很多独立计算核心的处理器,每个 SM 就是其中一个核心。
- 每个 SM 有自己的寄存器文件(65536 个 32-bit 寄存器)、共享内存(228KB,SM90)、Tensor Core(矩阵乘法加速器)、SFU(特殊函数单元,做 exp2f 等)
- H800 有 132 个 SM,它们并行工作
- 一个 CUDA kernel 启动时,会生成很多 **CTA (Cooperative Thread Array)**,也叫 thread block。GPU 的硬件调度器把这些 CTA 分配到各个 SM 上执行
- 一个 SM 可以同时执行多个 CTA(如果资源允许),但 FlashMLA 的 kernel 每个 CTA 用了接近满载的寄存器和共享内存,所以通常一个 SM 只跑 1 个 CTA
Cluster 是 Hopper 新增的概念:多个 SM 组成一个 cluster,cluster 内的 SM 可以直接读写彼此的共享内存。FlashMLA 中 h_q=128 时
CLUSTER_SIZE=2,即 2 个 SM 组成一个 cluster,协作处理 128 个 head。
3.3 计算 SM Partition 数
1 | h_q = 128 个 head |
3.4 贪心调度算法
调度器的目标是:把所有 batch 的工作均匀分配到 66 个 partition 上。
1 | 第一步:计算每个 partition 的负载上限 (payload) |
然后,单线程从左到右逐个 partition 分配工作 (代码第 66-91 行):
状态变量:
1 | now_req_idx = 0 // 当前正在分配第几个 batch |
1 | ═══════════ Partition 0, remain_payload=6 ═══════════ |
3.5 Split 与 Combine
当一个 batch 被分到多个 partition 时,就产生了 split — 每个 partition 只计算部分 KV block 的注意力,得到局部结果 (partial O, partial LSE)。最后需要 combine kernel 把各个 split 的结果合并。
1 | batch 0 被 split 成 2 份: |
如果一个 batch 没有被 split (
is_no_split=true),则 decode kernel 直接输出最终结果,不需要 combine kernel,省去一次 kernel launch。在实际部署中 (topk=2048, b=128),大部分 batch 都会被 split 到多个 partition 以充分利用 132 个 SM 的并行度。
🔧 硬件特性: 调度器本身是架构无关的 (
smxx/目录),但num_sm_parts根据 GPU 的 SM 数量和 kernel 的 cluster size 计算,与具体硬件相关。
第 4 步:FP8 量化回顾 — KV Cache 是如何存储的
在 decode 之前,预填充阶段已经将 KV cache 量化为 FP8。我们用具体数字展示:
4.1 原始 BF16 KV 向量 (一个 token)
1 | 原始 K 向量 (d=576): |
4.2 分组量化
NoPE 部分按 128 个一组量化(V3.2 有 4 组):
1 | Group 0: K_nope[0:127] |
4.3 存储布局 (656 字节)
| 偏移 | 内容 | 字节数 |
|---|---|---|
| 0 | K_nope_fp8[0:511] | 512 (512 × 1 字节 FP8_E4M3) |
| 512 | [scale_0, scale_1, scale_2, scale_3] | 16 (4 × 4 字节 float32) |
| 528 | K_rope_bf16[0:63] | 128 (64 × 2 字节 BF16) |
| 合计 | 656 字节 |
🔧 硬件特性:
- FP8 (E4M3): Hopper/Blackwell 原生数据类型,1 符号位 + 4 指数位 + 3 尾数位,范围 ±448,精度约 3 位有效数字
- UE8M0 缩放因子: 纯 2 的幂,无尾数位,确保反量化时乘法精确无舍入误差
- Blackwell (SM100) 额外支持: FP8_E8M0FNU 格式的缩放因子,MODEL1 使用此格式
第 5 步:CUDA Kernel 启动
5.1 SM90 (Hopper) 启动配置
文件: csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh 底部
1 | // SM90 启动配置 |
🔧 Hopper 硬件特性:
- Cluster Launch:
launch_kernel_on_cluster使用 Hopper 的 Thread Block Cluster 功能。h_q=128 时CLUSTER_SIZE=2,两个 CTA 组成一个 cluster,可以直接访问对方的共享内存 (通过 XOR 寻址)- 384 线程: 3 个 warpgroup(每个 128 线程 = 4 个 warp),这是 Hopper WGMMA 指令的基本执行单位
- NUM_M_BLOCKS=2: h_q=128 / BLOCK_M=64 = 2,每个 CTA 处理 64 个 head,2 个 CTA 覆盖全部 128 个 head
5.2 SM100 (Blackwell / GB200) 启动配置
文件: csrc/sm100/decode/head64/kernel.cuh 底部
1 | // SM100 启动配置 — 注意与 SM90 的关键区别 |
关键区别:
| SM90 (Hopper) | SM100 (Blackwell) | |
|---|---|---|
| grid 维度 | (2, 1, 66)x=NUM_M_BLOCKS, z=num_sm_parts |
(1, 132, 1)x=s_q, y=num_sm_parts |
| CTA 数量 | 2 × 1 × 66 = 132 | 1 × 132 × 1 = 132 |
| cluster size | 2 (两个 CTA 协作) | 1 (单 CTA 独立工作) |
| 每 CTA 处理 | 64 个 head | 64 个 head (全部 h_q=64) |
| num_sm_parts | 132/(1×2) = 66 | 132/1 = 132 |
1 | // SM90 的 KU_ASSERT: h_q == 64 || h_q == 128 等 |
SM100 为什么
h_q == 64而不是 128?
SM100 head64 实现 (csrc/sm100/decode/head64/) 要求h_q == B_H == 64。当推理框架处理 DeepSeek-V3.2 (h_q=128) 时,需要通过其他方式拆分 head(如 head64x2 或 head128 实现),或者多次调用。这与 SM90 通过 cluster (2 个 CTA 各处理 64 head) 的方案不同。
5.3 SM100 的三大硬件创新及其在 Kernel 中的应用
5.3.1 TMEM (Tensor Memory) — 新增的片上存储层
SM100 在寄存器和共享内存之外,新增了一层 TMEM (Tensor Memory):
1 | SM100 存储层次: |
在 FlashMLA 中的具体应用 (csrc/sm100/decode/head64/kernel.cuh):
1 | // TMEM 列布局 (config.h 第 72-80 行) |
SM90 的对比: SM90 没有 TMEM,Q 和 O 都在共享内存或寄存器中:
- Q 通过 TMA → 共享内存 → WGMMA 直接从共享内存读取
- O 累加器在寄存器中
- P 在寄存器中
SM100 的优势:
- TMEM 带宽远高于共享内存,Q 常驻 TMEM 后每个 block 的 QK^T 都不需要重新加载 Q
- O 累加器放在 TMEM 中,减轻寄存器压力 (Warpgroup 0 从 SM90 的 192 regs 变为 224 regs,但 regs 不用存 O 了)
- P 的 softmax 结果直接在 TMEM 中读取,无需通过共享内存传递
5.3.2 UTCMMA (Unified Tensor Core MMA) — 替代 WGMMA
SM100 使用 UTCMMA (tcgen05.mma) 指令替代 SM90 的 WGMMA (wgmma.mma_async):
1 | // QK^T 矩阵乘法: |
区别总结:
| SM90 WGMMA | SM100 UTCMMA | |
|---|---|---|
| QK^T 操作数来源 | A=共享内存,B=共享内存 (SS) | A=TMEM, B=共享内存 (TS) |
| QK^T 累加器位置 | 寄存器 (rP) | TMEM (tP) |
| SV 操作数来源 | A=寄存器,B=共享内存 (RS) | A=共享内存,B=共享内存 (SS) |
| SV 累加器位置 | 寄存器 (rO) | TMEM (tO) |
| 发起者 | 1 个 warpgroup (128 线程) | 1 个线程 (elect_one_sync) |
| 异步同步 | wgmma.commit/wait | tcgen05.commit + mbarrier |
5.3.3 TMA Gather4 — 硬件稀疏 Gather
这是对 sparse decode 最重要的 SM100 创新。SM90 通过线程协作做 __ldg 逐 token gather,SM100 直接用硬件 TMA 做 2D 稀疏 gather:
1 | // SM90: 每个线程用 __ldg 加载 1 个 token 的一部分 (splitkv_mla.cuh 第 509-530 行) |
SM100 的 gather4 工作流程:
1 | 64 个 top-k token 需要 gather → 每次 gather4 加载 4 个 token → 需要 16 次 TMA gather4 |
对比:
| | |
|—|—|
| SM90 | 128 线程 × 多轮 __ldg → ~128 个散射的全局内存请求 |
| SM100 | 1 线程 × 16 次 TMA gather4 → 16 个硬件管理的 DMA 请求 |
优势:
- 线程无需参与地址计算和数据搬运,可以做其他工作
- TMA 引擎内部有合并和调度优化,比线程发起的
__ldg更高效 - 数据直接落入 shared memory,无需经过寄存器
5.3.4 SM100 UTCCP — Q 从共享内存搬到 TMEM
SM100 加载 Q 的过程分两步:
1 | // Step 1: TMA 把 Q 从 global memory → shared memory (和 SM90 相同) |
5.4 SM100 的三 Warpgroup 分工
SM100 的三个 Warpgroup 分工与 SM90 完全不同:
1 | SM90 (Hopper): |
关键设计差异:
SM90 的 MMA 由 warpgroup (128 线程) 发起; SM100 的 UTCMMA 由 1 个线程发起
- UTCMMA 是 “widthless” 指令 (
tcgen05.mma.ws),只需 1 个线程 issue,Tensor Core 自动完成矩阵计算并写入 TMEM - 这解放了大量线程去做其他工作
- UTCMMA 是 “widthless” 指令 (
Softmax 独立成 WG0 (128 线程)
- SM90: softmax 由 Consumer A (WG0) 在 WGMMA 之间插入
- SM100: softmax 是 WG0 的唯一任务 — 从 TMEM 读 P、做 exp2、写 S 到 shared memory
- 好处:softmax 是标量密集计算,给它 224 个寄存器足够存所有中间状态
TMA Gather4 替代线程协作
__ldg- SM90: WG2 的 128 个线程各自
__ldg加载 FP8 数据 - SM100: WG1 中仅 Warp 5 的 1 个线程发起所有 TMA Gather4,硬件 DMA 完成实际搬运
- SM90: WG2 的 128 个线程各自
四重缓冲索引 (
NUM_INDEX_BUFS=4)- SM90 的 indices 无显式缓冲
- SM100 用 4 个 buffer 轮转:indices → TMA 坐标 → scale → valid mask
- 因为 SM100 的 TMA Gather4 需要预计算 TMA 坐标 (
tma_coord),这个过程是异步的
🔧 SM100 的
NUM_INDEX_BUFS=4vs SM90 的NUM_K_BUFS=2:
SM100 需要更多 buffer 因为 TMA Gather4 的 pipeline 更深:
- Warp 7 计算 TMA 坐标 (写 buf N)
- Warp 5/6 用 TMA Gather4 加载 FP8/RoPE (读 buf N, 写 KV buf M)
- WG2 反量化 (读 KV buf M)
- WG1-Warp4 发起 UTCMMA (读反量化后的 KV)
- WG0 做 softmax (读 TMEM 中的 P)
每一步都需要独立的 buffer 来实现流水线
第 6 步:TMA 加载 Q 矩阵
Producer (Warpgroup 2 中线程 0) 发起 TMA 拷贝:
1 | // 加载 Q: [64, 576] 的一个 tile 从 global memory → shared memory |
具体数据:
1 | Q tile shape: [64 heads, 576 dims] × BF16 = 73,728 字节 |
🔧 Hopper (SM90) TMA 特性:
- TMA (Tensor Memory Accelerator): 硬件异步拷贝引擎,无需线程参与
- 只需 1 个线程 发起,硬件自动完成整个 73KB 拷贝
- EVICT_FIRST: L2 cache 提示,Q 只用一次所以优先驱逐
- ClusterTransactionBarrier: TMA 完成后自动减少 barrier 计数
- Swizzle 布局 (SW128): 128 字节粒度的地址映射,消除 shared memory bank conflict
🔧 Blackwell (SM100) TMA 改进:
- SM100 新增 TMA Gather4: 可以从 2D tensor 中非连续地 gather 4 行,专为稀疏访问设计
- SM100 用 TMA 加载 Q 到 shared memory,然后用 UTCCP 将 Q 从 shared memory 搬到 TMEM (Tensor Memory)
- TMEM 是 Blackwell 新增的 512KB 片上存储,比 shared memory 带宽更高
第 7 步:Producer Warpgroup — FP8 反量化
这是稀疏 decode 最核心的步骤之一。Warpgroup 2 (线程 256-383) 负责:
7.1 加载 top-k 索引
1 | int* gIndices = params.indices + batch_0 * stride; |
7.2 计算物理地址
1 | int block_index = 47 / 64 = 0; // 第 0 个 page block |
7.3 加载缩放因子
1 | // 从偏移 512 处加载 4 个 float32 缩放因子 (128 bits) |
🔧 Hopper 硬件特性:
load_128b_from_gmem: 通过内联 PTX 汇编实现 128-bit 全局内存加载,附带精确的 L1/L2 cache 控制- PTX:
ld.global.nc.L1::evict_last.L2::128B.v4.s32 {%0,%1,%2,%3}, [%4];.nc= non-coherent (只读),.L1::evict_last= 在 L1 中最后被驱逐(优先保留)- 这些 cache hint 对 Hopper 和 Blackwell 都有效
7.4 FP8 → BF16 反量化(核心步骤)
每个线程处理 16 个 FP8 值:
1 | // 加载 16 个 FP8 值 (128 bits) |
反量化公式:
1 | BF16_value = float(FP8_value) × scale |
具体例子 (dim_idx=0, Group 0 的前 8 个值):
1 | FP8 原始值: [128.0, -256.0, 64.0, -32.0, 192.0, -16.0, 384.0, -448.0] |
C++ 实现 (components/dequant.h):
1 | bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale_bf162) { |
🔧 Hopper 硬件特性:
(float4)(fp8x4): Hopper Tensor Core 原生支持 FP8→FP32 的类型转换__float22bfloat162_rn(): 硬件 FP32→BF16 转换,round-to-nearest- 每个线程处理 16 个值,128 线程共处理 2048 个值 → 4 组 × 512/4 = 32 个线程覆盖一个 token 的 NoPE
🔧 Blackwell (SM100) 改进:
- SM100 新增 FP8_E8M0FNU 格式缩放因子(MODEL1 使用),用
__nv_cvt_e8m0x2_to_bf162raw硬件指令转换- SM100 的反量化在 Warpgroup 2 中完成后,结果直接写入 shared memory,随后 UTCMMA 从 shared memory 读取
7.5 写入共享内存
1 | // 写入 sK (shared memory), 使用 interleaved 布局 |
对于 CLUSTER_SIZE=2 的情况(h_q=128),还需写入 peer CTA 的共享内存:
1 | if constexpr (CLUSTER_SIZE == 2) { |
🔧 Hopper Cluster 特性:
get_peer_addr(): 通过 XOR 地址高位(16MB 偏移)直接访问 cluster 中邻居 SM 的 shared memoryst_async_128b: 异步写入 + mbarrier 完成通知- PTX:
st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [dst], {data}, [mbar]- 这允许两个 CTA 共享反量化后的 K 矩阵,无需经过 global memory
7.6 加载 RoPE 部分(不量化)
1 | // RoPE 部分直接从 global memory 加载 BF16,不需反量化 |
7.7 设置有效性标记
1 | // 线程 0-31 检查索引有效性 |
7.8 通知 Consumer
1 | fence_view_async_shared(); // 确保所有共享内存写入对 async proxy 可见 |
🔧 Hopper 硬件特性:
fence_view_async_shared(): 确保 shared memory 写入对 TMA/async proxy 可见的 fence 指令- Transaction Barrier:
bar_k_local_ready是ClusterTransactionBarrier,128 线程共同 arrive,Consumer 的wait才会通过- 双缓冲 (NUM_K_BUFS=2): Producer 写 buf 0 的同时 Consumer 读 buf 1,实现流水线
第 8 步:QK^T 矩阵乘法
8.1 SM90 (Hopper) — WGMMA
Consumer A (Warpgroup 0) 等待 K 就绪后执行:
1 | // 等待 Producer 完成反量化 |
WGMMA 计算 P = Q · K^T:
1 | // WGMMA SS 模式:Q 和 K 都在 shared memory |
具体计算 (以 head 0, token 0 为例):
1 | Q[head_0] = [q_0, q_1, ..., q_575] (576 维:512 维 NoPE + 64 维 RoPE) |
🔧 WGMMA 硬件特性:
- 异步执行:
wgmma.commit()发起后,线程可以继续做其他工作,用wgmma.wait()同步- 多 stage 流水线: FlashMLA 用 2-3 个 buffer 轮转,WGMMA 计算时同时加载下一批 K/V
- Cluster 协作: h_q=128 时,2 个 CTA 组成 cluster,每个 CTA 处理 64 个 head,通过 XOR 寻址共享数据
8.2 SM100 (Blackwell) — UTCMMA
1 | // SM100: UTCMMA TS 模式 — Q 在 TMEM, K 在 shared memory |
SM100 vs SM90 的关键区别:
| SM90 WGMMA | SM100 UTCMMA | |
|---|---|---|
| Q 位置 | shared memory | TMEM (512KB 片上存储) |
| K 位置 | shared memory | shared memory |
| 累加器位置 | 寄存器 (rP) | TMEM (tP) |
| 发起者 | 128 线程 (warpgroup) | 1 线程 (elect_one_sync) |
| 指令前缀 | wgmma.mma_async |
tcgen05.mma.ws |
| 带宽 | shared memory → Tensor Core | TMEM → Tensor Core (更高) |
🔧 TMEM 优势:
- TMEM 带宽 ~20TB/s,远高于 shared memory 的 ~8TB/s
- Q 常驻 TMEM 后,每个 KV block 的 QK^T 都不需要重新加载 Q
- 累加器放在 TMEM 中,减轻寄存器压力 (SM100 WG1 仅需 72 regs,SM90 WG0 需 192 regs)
第 9 步:Online Softmax
这是注意力计算的关键 — 在不知道全部 K 的情况下,增量式计算 softmax。
9.1 掩码无效 token
1 | for (int i = 0; i < size(cur_rP); ++i) { |
9.2 求行最大值 (warp 内规约)
1 | float cur_max = -INFINITY; |
具体数值 (head 0, 第 1 个 block 的 64 个 KV token):
1 | P[0][0..63] = [3.52, -1.23, 2.17, 0.89, ..., -0.45] |
9.3 更新 running max 和 rescale
1 | cur_max *= scale_softmax_log2; // = 0.2116 |
关键: 第一个 block 时 scale_for_old ≈ 0,所以之前累积的 O 被清零,这是正确的(因为之前没有累积值)。
9.4 Rescale O 并计算 exp
1 | // Rescale 旧的 O |
🔧 为什么用 exp2 而不是 exp?
exp2f()比expf()快 ~2x,因为 Hopper 的 SFU (Special Function Unit) 原生支持 base-2 指数- 公式:
exp(x * scale) = exp2(x * scale * log2(e)) = exp2(x * scale_softmax_log2)- 这是一个经典的 GPU 性能优化,Hopper 和 Blackwell 都受益
9.5 保存 scale factor 到共享内存
1 | if (idx_in_warpgroup % 4 == 0) |
🔧 Hopper Named Barrier 特性:
- 计算完 softmax 后,通过
NamedBarrier::arrive(256, sScale_and_sS_ready)通知 Warpgroup 1- Named Barrier 是 Hopper 新增功能,允许 warpgroup 之间细粒度同步
第 10 步:SV 矩阵乘法 — O += S · V
10.1 Warpgroup 0 — V_left (前 256 维)
Softmax 结果 S 已在寄存器中,V 的左半部分 (前 256 维) 在 shared memory:
1 | gemm<false, -1>( // zero_init=false (累加到现有 rO) |
具体计算:
1 | O_left[64×256] += S[64×64] × V_left[64×256] |
🔧 Hopper WGMMA RS 模式:
- RS (Register-Shared): S 在寄存器,V 在 shared memory
- 这比 SS 模式快,因为 S 刚从 softmax 计算出来,就在寄存器中,不需要写回 shared
- 每次 WGMMA shape: M=64, N=256, K=16
第 10 步:Warpgroup 1 — O += S · V_right
同时,Warpgroup 1 (线程 128-255) 处理 V 的右半部分:
10.2.1 等待 S 和 scale factor
1 | NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready); |
10.2.2 Rescale 自己的 O
1 | float cur_scales[2]; |
10.2.3 WGMMA 计算
1 | gemm<false, -1>( |
🔧 为什么 Warpgroup 1 用 SS 而不是 RS?
- Warpgroup 1 没有参与 softmax 计算,S 不在它的寄存器中
- S 由 Warpgroup 0 通过
save_rPb_to_sP写入 shared memory (sS)- 所以 Warpgroup 1 必须从 shared memory 读取 S → SS 模式
10.2.4 通知 Producer 可以复用 buffer
1 | plan.bar_k_avail[buf_idx].arrive(); // 告诉 Producer: 我已经用完这个 K buffer |
第 11 步:循环处理第 2 个 block
回到第 8 步,Producer 开始加载第 2 个 block(token index 64-127)。
由于使用双缓冲 (NUM_K_BUFS=2),Producer 写 buf 1 的同时 Consumer 还在读 buf 0:
1 | 时间线: |
第 2 个 block 的 Online Softmax 更新:
1 | 第 2 个 block 的 P 值 (head 0): |
🔧 双缓冲是 GPU 流水线的经典模式,不特定于某个架构,但 Hopper 的 Transaction Barrier 让同步更高效。
第 12 步:跨 Warpgroup 的 L 规约
所有 block 处理完后,需要合并两个 Warpgroup 的 L (exp-sum):
1 | // Warp 内规约 |
数值:
1 | 最终 rL[head 0] = 80.7 (所有 128 个 token 的 exp-sum) |
🔧 硬件特性:
__shfl_xor_sync是 warp shuffle,所有 CUDA GPU 支持。跨 warpgroup 通过 shared memory 通信。
第 13 步:Attention Sink 处理
如果提供了 attn_sink(DeepSeek 中用于 sink token 的预计算注意力值):
1 | if (params.attn_sink != nullptr) { |
Attention Sink 的含义: 假装有一个额外的 “sink” token,其注意力分数是固定的
attn_sink值。这让模型可以把一部分注意力”倒掉”,避免所有注意力都集中在 top-k token 上。
第 14 步:输出写回
14.1 Warpgroup 0: O_left 写入分 split/no-split 两种情况
No-split (is_no_split=true, 我们的例子):
1 | // 将 FP32 寄存器值缩放并转为 BF16 |
🔧 Hopper TMA Store 特性:
- SM90_TMA_STORE_5D: 5 维 TMA 存储,可以直接从 shared memory 写到 global memory 的任意 5D 位置
- 只需 1 个线程发起,硬件自动完成
- 使用
CUtensorMap描述符,包含 swizzle 信息和地址映射🔧 Blackwell 改进:
- SM100 仍然使用 TMA Store,但 TMEM → shared → global 的路径更短
- SM100 的 TMA 吞吐量更高
Split 情况 (如果一个 batch 被分到多个 SM partition):
1 | // 写入 o_accum (FP32),而不是最终 output (BF16) |
14.2 写入 LSE
1 | int i = threadIdx.x; |
第 15 步:PDL — 提前启动 Combine Kernel
1 | // 最后一个 batch 处理完毕 |
🔧 Hopper PDL (Programmatic Dependent Launch) 特性:
- 传统 CUDA: kernel B 必须等 kernel A 的所有 CTA 完成才能启动
- PDL: kernel B 可以在 kernel A 的部分 CTA 完成后就启动
- 对于我们的例子:SM Partition 0 (batch 0) 完成后,batch 0 的 combine 可以立即开始,不需等 batch 1
- 通过
cudaLaunchAttributeProgrammaticStreamSerialization启用- Blackwell 完全支持此特性
第 16 步:Combine Kernel — 合并 Split 结果
文件: csrc/smxx/decode/combine/combine.cu
对于我们的例子(每个 batch 只有 1 个 split,is_no_split=true),combine kernel 直接 return:
1 | if (my_num_splits == 1) return; // 无需合并! |
但当 batch 很大或序列很长时(例如 topk=8192, num_sm_parts=132),一个 batch 的 KV 会被切分到多个 SM partition,此时 combine kernel 的工作如下:
假设有 3 个 split 的情况:
1 | Grid: [batch_size=2, s_q=1, ceil(128/8)=16] |
Step 1: 求全局 max_lse (warp 内 shuffle 规约)
1 | max_lse = max(4.2, 3.8, 2.1) = 4.2 |
Step 2: 求 sum_lse
1 | sum_lse = exp2(4.2-4.2) + exp2(3.8-4.2) + exp2(2.1-4.2) |
Step 3: 加权合并 O
1 | O_final = (1.0/1.8932) × [1.0×o_accum[0] + 0.7579×o_accum[1] + 0.1353×o_accum[2]] |
Step 4: 写回结果
1 | // 将合并后的 O 写入 global memory |
第 17 步:完整流水线时间线
流水线时间线:
1 | 时间 → |
第 18 步:SM90 (Hopper) vs SM100 (Blackwell) 对比总结
| 特性 | SM90 (Hopper, H100/H800) | SM100 (Blackwell, GB200) |
|---|---|---|
| 矩阵乘法 | WGMMA (SS/RS 模式) | UTCMMA (TS/SS 模式,tcgen05 指令) |
| Q 存储位置 | Shared Memory (SW128 布局) | TMEM (512KB 片上存储,零拷贝) |
| K 加载 | 线程协作 LDG + 反量化 | TMA Gather4 (硬件稀疏 gather) |
| Warpgroup 划分 | 3 组:128+128+128 线程 | 3 组:32+192+160 线程 (更不均匀) |
| 寄存器分配 | 192/160/152 regs | 224/72/208 regs (更极端的再分配) |
| 跨 SM 通信 | Cluster Shared Memory (XOR 寻址) | 同上 + 更大 cluster (最多 16 SM) |
| Barrier | Named Barrier + Transaction Barrier | 同上 + TCGen05 fence 指令 |
| 缩放因子格式 | float32 (V3.2) | float32 (V3.2) / FP8_E8M0FNU (MODEL1) |
| PDL | cudaTriggerProgrammaticLaunchCompletion | 同上 |
| TMA | 2D Block Copy | 2D Block Copy + Gather4 (稀疏模式) |
Blackwell 关键创新:
- TMEM (Tensor Memory): 512KB 的新增片上存储,带宽高于 shared memory。Q 矩阵常驻 TMEM,避免反复从 shared memory 读取
- UTCMMA: 新一代 Tensor Core 指令,支持 TMEM 直接作为操作数源
- TMA Gather4: 硬件实现的 2D 稀疏 gather,比线程协作的
__ldg更高效 - 更灵活的 Warpgroup 分工: 32 线程专门做 softmax(不参与 MMA),减少 register pressure
第 19 步:最终输出
1 | out shape: [2, 1, 128, 512] BF16 (h_q=128 来自 config.json) |
学习进度
已完成:
- ✅ 项目结构概览 (入口 → dense/sparse 路由)
- ✅ Sparse Decode 完整 19 步 walkthrough (SM90 为主,带具体数值)
- ✅ DeepSeek-V3.2 config.json 参数映射
- ✅ FP8 量化/反量化机制 (V3.2 布局,656 字节/token)
- ✅ Tile Scheduler 贪心调度算法详解 (含 SM/CTA/Cluster 概念)
- ✅ SM90 三 Warpgroup 分工 (Consumer A / Consumer B / Producer)
- ✅ SM100 三大硬件创新 (TMEM, UTCMMA, TMA Gather4) 及其 kernel 实现对比
- ✅ SM100 三 Warpgroup 分工 (Softmax / MMA+Produce / Dequant)
- ✅ exp2f 优化的数学推导 (换底公式)
- ✅ Online Softmax 增量更新 + LSE 合并
- ✅ PDL (Programmatic Dependent Launch) 提前启动 Combine Kernel
- ✅ Attention Sink 机制
未深入探索:
- ⬜️ Dense Decode 路径 (
csrc/sm90/decode/dense/splitkv_mla.cuh) — 见下方补充 - ⬜️ Prefill 路径 (dense prefill + sparse prefill) — 见下方补充
- ⬜️ SM100 head64x2 / head128 实现 (h_q > 64 时的 Blackwell 方案)
- ⬜️
flash_attn_varlen_func/flash_attn_varlen_qkvpacked_func等变长接口 - ⬜️ CLC (Cluster Launch Control) 在 SM100 prefill 中的使用
- ⬜️ 性能调优细节 (bank conflict 避免,L2 cache 策略,占用率分析)
- ⬜️ 端到端 benchmark 与 roofline 分析
附录
A:Dense Decode vs Sparse Decode 对比
虽然本文聚焦于 Sparse Decode,但理解 Dense Decode 的路径有助于理解为什么 Sparse 是必要的。
A.1 Dense Decode 的工作流程
1 | # Dense 路径不需要 indices |
Dense 路径的 kernel 路由:
1 | // csrc/api/dense_decode.h |
A.2 Dense vs Sparse 的核心区别
| 特性 | Dense Decode | Sparse Decode |
|---|---|---|
| 输入 | block_table + cache_seqlens |
indices (top-k token 列表) |
| KV 访问模式 | 顺序遍历所有 page block | 随机 gather top-k token |
| 计算复杂度 | O(n²) — 处理所有 KV token | O(n·k) — 只处理 top-k token |
| 适用场景 | 短上下文 (≤ 2048 tokens) | 长上下文 (> 2048 tokens) |
| SM100 支持 | ❌ 不支持 | ✅ 支持 |
| 性能 (128K context) | ~10x 慢 | 基准 |
A.3 为什么 SM100 不支持 Dense Decode?
Blackwell (SM100) 的架构设计更倾向于稀疏计算:
- TMA Gather4 是为稀疏访问优化的,顺序访问反而没有优势
- TMEM 容量有限 (512KB),长上下文的 Dense 路径需要更多片上存储
- 市场定位: GB200 主要面向长上下文推理,Sparse 是默认场景
实践建议:
- 上下文长度 ≤ 2048: 用 Dense Prefill + Sparse Decode
- 上下文长度 > 2048: 全程 Sparse (Prefill + Decode)
B:Prefill 路径简介
Decode 路径是”生成一个 token”,Prefill 路径是”处理整个 prompt”。
B.1 Prefill vs Decode 的计算差异
| 维度 | Prefill | Decode |
|---|---|---|
| seq_len_q | 整个 prompt (可能数千 tokens) | 1 (只生成下一个 token) |
| 计算类型 | 计算密集型 (FLOPS-bound) | 访存密集型 (Memory-bound) |
| KV Cache | 生成并存储 KV Cache | 读取已有 KV Cache |
| 延迟敏感度 | 较低 (用户等待可接受) | 极高 (影响 token 生成速度) |
| 优化目标 | 吞吐量最大化 | 延迟最小化 |
B.2 Sparse Prefill 的特殊性
Sparse Prefill 与 Sparse Decode 的关键区别:
1 | // Sparse Prefill 启动配置 (SM90) |
Sparse Prefill 的特点:
- 需要处理因果掩码 (causal mask) — 每个 query token 只能关注之前的 token
- Q 和 KV 都是动态的,不能预先加载到 shared memory
- 通常用 CTA Tiling 策略,每个 CTA 处理一个 Q block × KV block 的子矩阵
B.3 FlashMLA 的 Prefill 实现文件
| 文件 | 内容 |
|---|---|
csrc/sm90/prefill/sparse_fp8/ |
SM90 Sparse Prefill 实现 |
csrc/sm100/prefill/ |
SM100 Prefill 实现 (支持 CLC) |
csrc/smxx/prefill/flash_mla_fwd.cuh |
Prefill 主 kernel |
CLC (Cluster Launch Control): SM100 Prefill 使用的高级特性,允许动态调整 cluster 大小,优化不同序列长度的性能。
C:性能分析与优化建议
C.1 Roofline 模型分析
Decode 阶段 (Memory-bound):
1 | 理论峰值带宽:H800 = 3.35 TB/s |
优化空间:
- FP8 量化已经减少 43% 带宽 (1152 → 656 bytes/token)
- 进一步量化到 INT4 可能再减少 50%,但精度损失需评估
C.2 Bank Conflict 避免策略
FlashMLA 使用以下策略避免 shared memory bank conflict:
Swizzled Layout (SW128):
1
2
3// 128 字节粒度的地址映射
// 将连续的 128 字节映射到不同的 bank
// 避免 32 线程同时访问同一 bankInterleaved 布局:
1
2
3
4// K 矩阵按列交错存储
// 线程 0 访问列 0,32,64...
// 线程 1 访问列 1,33,65...
// 避免 warp 内冲突
C.3 L2 Cache 策略
| 数据类型 | L1 策略 | L2 策略 | 原因 |
|---|---|---|---|
| Q 矩阵 | EVICT_FIRST | 128B 预取 | 只用一次,优先驱逐 |
| K 缩放因子 | EVICT_LAST | 128B 预取 | 可能被多个 block 复用 |
| FP8 NoPE | EVICT_LAST | 256B 预取 | 反量化后不再需要 |
| RoPE | EVICT_LAST | 128B 预取 | BF16 直接加载 |
C.4 占用率 (Occupancy) 分析
SM90 Sparse Decode:
1 | 每 CTA 资源使用: |
优化建议:
- 减少 WG0 寄存器使用 (当前 192 → 目标 160)
- 增加 CTA/SM 数 (1 → 2),但需要减少 shared memory 使用
D:实践指南 — 如何在你的项目中使用 FlashMLA
D.1 安装 FlashMLA
1 | # 克隆仓库 |
D.2 基本使用示例
1 | import torch |
D.3 与 vLLM / SGLang 集成
vLLM 集成:
1 | # vllm/attention/flashmla.py |
SGLang 集成:
1 | # sglang/srt/layers/attention/flashmla.py |
D.4 常见问题排查
| 问题 | 可能原因 | 解决方案 |
|---|---|---|
Expected kv.dtype() == torch::kBFloat16 |
错误配置 flashmla_kv=False |
设置 flashmla_kv=True |
SM100 not supported |
在 GB200 上用 Dense Decode | 切换到 Sparse Decode |
CUDA out of memory |
KV Cache 过大 | 减少 page_block_size 或启用 offload |
| 性能不达标 | 未启用 FP8 | 确保 is_fp8_kvcache=True |
E:扩展阅读
E.1 相关论文
DeepSeek-V3.2 Technical Report (2026)
- DSA (DeepSeek Sparse Attention) 原始论文
- 解释 top-k 选择算法和索引生成
FlashMLA: Efficient MLA Attention on GPUs (FlashMLA 团队)
- FlashMLA 架构设计文档
- FP8 量化策略和性能分析
Hopper Architecture Whitepaper (NVIDIA)
- SM90 架构详解
- TMA、WGMMA、Cluster 特性
Blackwell Architecture Whitepaper (NVIDIA)
- SM100 架构详解
- TMEM、UTCMMA、TMA Gather4 特性
E.2 相关项目
- FlashMLA: https://github.com/deepseek-ai/FlashMLA
- vLLM: https://github.com/vllm-project/vllm
- SGLang: https://github.com/sgl-project/sglang
- LMCache: https://github.com/LMCache/LMCache (KV Cache 存储层)
E.3 关键文件索引
| 文件 | 内容 |
|---|---|
flash_mla/flash_mla_interface.py |
Python API,路由到 dense/sparse |
csrc/api/sparse_decode.h |
C++ 接口,架构分发,参数准备 |
csrc/params.h |
所有参数结构体定义 |
csrc/smxx/decode/get_decoding_sched_meta/ |
Tile scheduler,工作分配 |
csrc/sm90/decode/sparse_fp8/config.h |
SM90 kernel 配置和共享内存布局 |
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh |
SM90 三 warpgroup kernel 实现 |
csrc/sm90/decode/sparse_fp8/components/dequant.h |
FP8 反量化实现 |
csrc/sm90/decode/sparse_fp8/components/helpers.h |
WGMMA helper, cluster async ops |
csrc/sm100/decode/head64/config.h |
SM100 kernel 配置 (TMEM layout) |
csrc/sm100/decode/head64/kernel.cuh |
SM100 三 warpgroup kernel 实现 |
csrc/kerutils/.../sm100/gemm.cuh |
UTCMMA 内联 PTX |
csrc/kerutils/.../sm100/intrinsics.cuh |
TMA Gather4, TMEM ops |
csrc/smxx/decode/combine/combine.cu |
Split-KV 结果合并 |
tests/quant.py |
FP8 量化/反量化参考实现 |
tests/ref.py |
纯 PyTorch 参考实现 |
E.4 技术博客
- 前作:《FlashMLA 深度解析:FP8 KV Cache 与 DSA 稀疏注意力实现原理》
https://ggaaooppeenngg.github.io/zh-CN/2026/04/10/flashmla-fp8-dsa-deep-dive/
总结
本文以 DeepSeek-V3.2 的真实配置为基础,通过一个最小化的数值例子(b=2, topk=128),逐步展示了 FlashMLA Sparse Decode 的完整计算流程:
- Python 层 路由到
sparse_decode_fwd - C++ 层 验证输入并打包参数结构体
- Tile Scheduler 将工作均匀分配到 66 个 SM partition
- FP8 KV Cache 以 656 字节/token 的格式存储(512 FP8 + 16 scales + 128 BF16 RoPE)
- CUDA Kernel 根据架构选择 SM90 (Cluster + WGMMA) 或 SM100 (TMEM + UTCMMA + TMA Gather4)
- TMA 加载 Q 矩阵到 shared memory / TMEM
- FP8 反量化 — Producer warpgroup 从 KV cache gather token 并反量化为 BF16
- QK^T 矩阵乘法 — WGMMA (SM90) 或 UTCMMA (SM100) 计算注意力分数 P = Q·K^T
- Online Softmax — 增量式 softmax,exp2f 优化 + LSE 合并
- SV 矩阵乘法 — O = S · V,Warpgroup 0/1 并行计算左右两半
- 结果写回与 Combine — 写回 O/LSE,merge kernel 合并 split 结果(如果需要)
关键硬件优化:
- SM90 (Hopper): Cluster 协作、WGMMA 异步矩阵乘法、TMA 加载
- SM100 (Blackwell): TMEM 片上存储、UTCMMA 指令、TMA Gather4 硬件稀疏 gather
性能优势:
- FP8 量化减少 43% 的内存带宽消耗
- DSA 稀疏注意力将 O(n²) 复杂度降为 O(n·k)
- 在 128K 上下文长度下可实现 3 倍更快的推理速度
参考资料
- DeepSeek-V3.2 config.json: https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/config.json
- FlashMLA GitHub: https://github.com/deepseek-ai/FlashMLA
- NVIDIA Hopper Architecture: https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/
- NVIDIA Blackwell Architecture: https://developer.nvidia.com/blog/nvidia-blackwell-architecture-in-depth/
- 前作:《FlashMLA 深度解析:FP8 KV Cache 与 DSA 稀疏注意力实现原理》https://ggaaooppeenngg.github.io


