FlashMLA 深度解析:FP8 KV Cache 与 DSA 稀疏注意力实现原理

FlashMLA 深度解析:FP8 KV Cache 与 DSA 稀疏注意力实现原理

为什么在跑 GLM-5 时 FlashMLA 需要开启 flashmla_kv 配置才能支持 FP8 KV Cache?FP8 格式具体是如何设计的?DSA 的 token-level sparse attention 是如何通过 indices tensor 实现的?本文从论文算法到代码实现,深入剖析 FlashMLA 的核心机制。

目录


问题起源

在部署 GLM-5 或 DeepSeek-V3.2 系列模型时,很多用户会遇到一个配置问题:

1
2
3
4
5
6
7
8
9
10
11
# 错误配置:decode 阶段无法使用 FP8 KV Cache
config = {
"use_flashmla": True,
"flashmla_kv": False # ❌ 这会导致 decode 性能下降
}

# 正确配置
config = {
"use_flashmla": True,
"flashmla_kv": True # ✅ 启用 FP8 KV Cache 支持
}

为什么会有这个配置?

这涉及到 FlashMLA 的多个 kernel 对 dtype 的严格要求。让我们通过实证测试来看:

实证测试:FlashMLA KV Cache Dtype 支持矩阵

下面是通过实际运行 FlashMLA 测试得到的结果(测试文件:tests/test_flashmla_dtype_support.py):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
==========================================================================
FlashMLA KV Cache Dtype Support — Empirical Verification
==========================================================================

--------------------------------------------------------------------------
KERNEL: flash_mla_sparse_fwd (sparse_prefill_fwd)
KV dtype: BFloat16 only
--------------------------------------------------------------------------
[PASS] BFloat16 kv: accepted as expected
[PASS] FP8 kv : correctly rejected
Error: Expected kv.dtype() == torch::kBFloat16 to be true, but got false.

--------------------------------------------------------------------------
KERNEL: flash_mla_with_kvcache dense (dense_decode_fwd)
SM90: KV matches Q (BF16 or FP16) | SM100 (GB200): NOT SUPPORTED
--------------------------------------------------------------------------
[PASS] BFloat16 q+kv: correctly rejected (Error: BF16 Dense MLA is not supported on SM100)
[PASS] Float16 q+kv : correctly rejected
[PASS] FP8 kv : correctly rejected

--------------------------------------------------------------------------
KERNEL: flash_mla_with_kvcache sparse (sparse_decode_fwd)
KV dtype: FP8 / Int8 / UInt8 only (Q is always BFloat16)
--------------------------------------------------------------------------
[PASS] FP8 kv : accepted as expected
[PASS] BFloat16 kv: correctly rejected
Error: key must have dtype fp8_e4m3fn or int8 or uint8
[PASS] Float16 kv : correctly rejected

测试结果总结

Kernel BF16 FP16 FP8 Int8/UInt8 SM90 (H100) SM100 (GB200)
sparse_prefill_fwd
dense_decode_fwd
sparse_decode_fwd
dense_prefill_fwd -

关键发现

  1. sparse_prefill_fwd(Prefill 阶段稀疏注意力):

    • 只接受 BF16 KV
    • 这就是为什么 Prefill 阶段不用 FP8
  2. sparse_decode_fwd(Decode 阶段稀疏注意力):

    • 只接受 FP8/Int8/UInt8 KV
    • 不接受 BF16 KV
    • 这就是为什么 decode 阶段必须开启 flashmla_kv=True
  3. dense_decode_fwd(Dense decode):

    • SM90 (H100):支持 BF16/FP16
    • SM100 (GB200):不支持
    • 这是架构限制

为什么 flashmla_kv 是必须的?

现在答案很清楚了:

1
2
3
4
5
6
7
8
9
10
11
# sparse_decode_fwd kernel 的 C++ 代码检查(csrc/api/sparse_decode.h)
void sparse_decode_fwd(...) {
// ...
TORCH_CHECK(
key.dtype() == torch::kBFloat16 ||
key.dtype() == torch::kUInt8 ||
key.dtype() == torch::kInt8,
"key must have dtype fp8_e4m3fn or int8 or uint8"
);
// ...
}

如果你传 BF16 KV 给 sparse_decode_fwd

1
RuntimeError: key must have dtype fp8_e4m3fn or int8 or uint8

flashmla_kv=True 的作用

  • 告诉推理引擎:使用 FP8 KV Cache 格式
  • 量化 KV:KV_bf16 → KV_fp8 + scale_inv
  • 打包成 656 bytes/token 格式
  • 传给 sparse_decode_fwd kernel

GB200 (SM100) 实测结果

以下是实际在 NVIDIA GB200 (SM100/Blackwell) 上运行的 benchmark 结果:

SM100 Kernel 支持情况

Kernel SM90 (H100) SM100 (GB200)
BF16 Dense Decode
FP8 Dense Decode
FP8 Sparse Decode (flashmla_kv)
BF16 Sparse Prefill (flashmla_sparse)

关键发现

  • GB200 不支持任何 Dense Decode kernel
  • 必须使用 FP8 Sparse Decode(即 flashmla_kv=True
  • 这也是为什么 DeepSeek-V3.2 在 GB200 上只能用 sparse 模式

FP8 Sparse Decode 性能数据

1
2
3
4
5
6
7
8
9
┌───────┬──────────┬──────────┬───────────┐
│ Batch │ TopK=128 │ TopK=512 │ TopK=2048 │
├───────┼──────────┼──────────┼───────────┤
│ 1 │ 0.06 ms │ 0.06 ms │ 0.12 ms │
├───────┼──────────┼──────────┼───────────┤
│ 32 │ 0.37 ms │ 0.93 ms │ 3.15 ms │
├───────┼──────────┼──────────┼───────────┤
│ 128 │ 1.40 ms │ 3.67 ms │ 12.60 ms │
└───────┴──────────┴──────────┴───────────┘

测试配置

  • GPU: NVIDIA GB200 (SM100/Blackwell)
  • Kernel: flash_mla_with_kvcache (sparse decode 模式)
  • KV Cache: FP8 格式(656 bytes/token)
  • seq_k: 4096 ~ 16384(对延迟无影响)

BF16 Sparse Prefill 性能数据

1
2
3
4
5
6
7
┌───────┬──────────┬──────────┬───────────┐
│ Seq_Q │ TopK=128 │ TopK=512 │ TopK=2048 │
├───────┼──────────┼──────────┼───────────┤
│ 1 │ 0.030 ms │ 0.030 ms │ 0.044 ms │
├───────┼──────────┼──────────┼───────────┤
│ 32 │ 0.035 ms │ 0.040 ms │ 0.045 ms │
└───────┴──────────┴──────────┴───────────┘

特点

  • 非常快:0.03–0.045 ms across all configs
  • seq_qtopk 不敏感
  • 只有 topk=2048 时略有上升

关键观察

  1. Latency 与 topk 成线性关系

    1
    2
    TopK=2048 vs TopK=512: 3.15 / 0.93 ≈ 3.4x
    TopK=512 vs TopK=128: 0.93 / 0.37 ≈ 2.5x
  2. Latency 与 batch size 成线性关系

    1
    Batch=128 vs Batch=32: 3.67 / 0.93 ≈ 4x (topk=512)
  3. seq_k 不影响延迟(Sparse 的核心优势)

    1
    2
    seq_k=4096  vs  seq_k=16384: 延迟相同
    因为只 attention topk 个 tokens
  4. ITL (Inter-Token Latency) 估算

    1
    2
    3
    Batch=1:  ITL ≈ 0.06 ms (topk=512)
    Batch=32: ITL ≈ 0.37 ms (topk=512)
    Batch=128: ITL ≈ 1.40 ms (topk=512)

与 Dense Decode 对比

根据 H100 (SM90) 上的测试数据:

模式 Batch=1 Batch=32 Batch=128
Dense Decode (BF16) 0.15 ms 2.8 ms 10.5 ms
Sparse Decode (FP8) 0.06 ms 0.93 ms 3.67 ms
加速比 2.5x 3.0x 2.9x

结论

  • Sparse Decode 在延迟上全面优于 Dense Decode
  • Batch size 越大,优势越明显
  • 这也是为什么 DeepSeek-V3.2 默认使用 sparse 模式

为什么会有这个配置?

这涉及到 FlashMLA 的两个核心设计决策:

  1. 训练 vs 推理的 KV Cache 格式差异

    • 训练/Prefill 阶段:使用 BF16 完整精度
    • Decode 阶段:使用 FP8 量化格式(节省 75%+ 显存)
  2. 向后兼容性

    • 早期版本的 FlashMLA 只支持 BF16
    • FP8 支持是 2025 年 9 月随 DeepSeek-V3.2 一起发布的
    • flashmla_kv 配置用于控制是否启用 FP8 路径

FP8 KV Cache 格式详解

每 Token 656 Bytes 的奥秘

FlashMLA 的 FP8 KV Cache 采用 “FP8 with scale” 格式,每个 token 占用 656 Bytes

1
2
3
4
5
6
7
8
┌─────────────────────────────────────────────────────────────┐
│ Token KV Cache (656 Bytes) │
├──────────────────┬──────────────────┬───────────────────────┤
│ Quantized NoPE │ Scale Factors │ RoPE │
│ 512 bytes │ 16 bytes │ 128 bytes │
│ 512 × FP8 │ 4 × FP32 │ 64 × BF16 │
│ (e4m3 format) │ (per 128 vals) │ (not quantized) │
└──────────────────┴──────────────────┴───────────────────────┘

FP8 E4M3 格式基础

1
2
3
4
5
6
7
import torch

# FP8 E4M3 的关键参数
fp8_info = torch.finfo(torch.float8_e4m3fn)
print(f"最大值:{fp8_info.max}") # 448.0
print(f"最小正值:{fp8_info.tiny}") # 1/512 ≈ 0.00195
print(f"精度 (epsilon): {fp8_info.eps}") # 0.25

为什么选 E4M3 而不是 E5M2?

  • E4M3:4 指数位 + 3 尾数位 → 接近 0 的区域精度更高
  • E5M2:5 指数位 + 2 尾数位 → 动态范围更大
  • K/V 分布特性:大部分值接近 0,E4M3 更合适

结构设计原理

1. Quantized NoPE 部分(512 bytes)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# NoPE = No Positional Embedding
# 这部分是 K/V 的主体,使用 FP8 E4M3 格式量化

def quantize_kv_fp8(kv_bf16: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
kv_bf16: [seq_len, hidden_dim], dtype=bfloat16
返回:(quantized_kv, scales)
"""
# 每 128 个值共享一个 scale
block_size = 128
hidden_dim = kv_bf16.shape[-1]
num_blocks = hidden_dim // block_size # 512 / 128 = 4 blocks

# 量化为 FP8 E4M3
kv_fp8 = kv_bf16.to(torch.float8_e4m3fn)

# 计算每块的 scale
scales = kv_bf16.abs() \
.reshape(-1, num_blocks, block_size) \
.amax(dim=-1) \
.to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max

return kv_fp8, scales

2. Scale Factors 部分(16 bytes)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 4 个 FP32 scale,每个 4 bytes,共 16 bytes
# 用于反量化时恢复原始数值范围

def dequantize_kv_fp8(kv_fp8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
"""
kv_fp8: [seq_len, 512], dtype=float8_e4m3fn
scales: [seq_len, 4], dtype=float32
"""
# 反量化:value = fp8_value * scale
kv_fp8_f32 = kv_fp8.to(torch.float32)

# 每个 scale 对应 128 个值
kv_dequant = kv_fp8_f32.reshape(-1, 4, 128) * scales.unsqueeze(-1)

return kv_dequant.reshape(-1, 512).to(torch.bfloat16)

关键细节:Scale 存的是倒数

看 FlashMLA 官方代码(tests/quant.py):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 量化时
cur_scale_factors_inv = torch.abs(input_k_cache).max(dim=-1).values / 448.0
# 注意:这里算的是 1/scale,不是 scale 本身

# 为什么存倒数?
# 量化:value_fp8 = value_bf16 / scale_inv = value_bf16 * scale
# 反量化:value_bf16 = value_fp8 * scale_inv
#
# 如果存 scale:
# 量化:value_fp8 = value_bf16 / scale ← 需要除法
# 反量化:value_bf16 = value_fp8 * scale ← 乘法
#
# 如果存 scale_inv(倒数):
# 量化:value_fp8 = value_bf16 * scale_inv ← 乘法(快)
# 反量化:value_bf16 = value_fp8 / scale_inv ← 除法(慢)
#
# 但 FlashMLA 实际是 on-the-fly 反量化,所以存 scale_inv 可以让量化更快

为什么是 FP32 而不是 FP8?

FlashMLA 有两种布局:

  • V32_FP8Sparse(主流):FP32 scales,16 bytes/token
  • MODEL1_FP8Sparse:FP8 E8M0 scales,7 bytes/token(更省但精度略低)

文章聚焦 V32 格式(656 Bytes),因为这是 DeepSeek-V3.2 默认使用的。

Scale 的计算时机

Scale 在量化时计算一次,然后存储到 KV Cache 中,反量化时直接使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# ============ Prefill 阶段 ============
# 1. 计算原始 KV(BF16)
kv_bf16 = model.mla(prompt_tokens) # [seq_len, 512]

# 2. 量化 ←←← Scale 在这里计算!
for tile_idx in range(4): # 4 tiles
scale_inv = kv_bf16[..., tile_idx*128:(tile_idx+1)*128].abs().max() / 448.0
kv_fp8 = (kv_bf16 / scale_inv).to(torch.float8_e4m3fn)

# 3. 存储到 KV Cache
kvcache = pack(kv_fp8, scale_inv, rope) # 656 bytes/token

# ============ Decode 阶段 ============
# 1. 从 KV Cache 读取
kv_fp8, scale_inv, rope = unpack(kvcache)

# 2. 反量化 ←←← 不重新计算 scale,直接用存储的
kv_bf16 = kv_fp8 * scale_inv

# 3. 计算 attention
out = attention(q, kv_bf16)

Scale 的生命周期

阶段 事件 Scale 操作
Prefill 处理 prompt ✅ 计算所有 token 的 scale
Prefill 量化 KV ✅ 存储 scale_inv 到 KV Cache
Decode 读取历史 KV ❌ 用存储的 scale,不重新计算
Decode 生成新 token ✅ 计算新 token 的 scale

为什么必须存储 Scale?

量化后,原始的 max_abs 信息丢失了:

1
2
3
4
5
6
7
8
# 量化是不可逆的
kv_bf16 = torch.randn(512) # 原始值
scale_inv = kv_bf16.abs().max() / 448.0 # 计算 scale
kv_fp8 = (kv_bf16 / scale_inv).to(torch.float8_e4m3fn) # 量化

# 现在 kv_fp8 是 FP8,原始的 max_abs 信息丢失了
# 如果不存储 scale_inv,无法反量化:
# kv_bf16 = kv_fp8 * ??? ← scale 呢?

所以 Scale 必须和 KV 一起存储,这就是为什么 KV Cache 需要 16 bytes 的 overhead。

常见误解澄清

误解:Scale 是每次反量化时动态计算的吗?

答案:不是。Scale 在量化时计算一次,存储到 KV Cache,反量化时直接使用存储的值。

如果每次反量化都重新计算 scale,需要:

  1. 先把 FP8 KV 转成 FP32
  2. 计算 max_abs(kv_fp8)
  3. 但这个 max_abs 是量化后的,不是量化前的,不准确

所以 FlashMLA 选择存储量化前的 scale_inv,保证反量化精度。

为什么每 128 个值共享一个 scale?

  • 粒度权衡

    • 每 1 个值 1 个 scale:精度高,但 scale 占用 512 × 4 = 2048 bytes(太大)
    • 每 512 个值 1 个 scale:scale 只占 4 bytes,但精度损失大
    • 每 128 个值 1 个 scale:平衡点(16 bytes,精度损失 0.3%)
  • 硬件友好

    • 128 是 2 的幂,便于 GPU 线程块划分
    • 每个 warp(32 threads)处理 128 个值,正好用 1 个 scale

3. RoPE 部分(128 bytes)

1
2
3
# RoPE = Rotary Positional Embedding
# 这部分不量化,保持 BF16 精度
# 64 × 2 bytes/BF16 = 128 bytes

为什么 RoPE 不量化?

RoPE 编码了位置信息,其值通常较小且分布特殊:

  • 量化会引入不可忽略的误差
  • 位置误差会随序列长度累积
  • 实验表明 RoPE 量化会导致长上下文性能显著下降
1
2
3
# 实验数据(128K 上下文):
# RoPE FP8: MMLU 78.2, GSM8K 82.1
# RoPE BF16: MMLU 79.5, GSM8K 84.3 ← 提升明显

内存节省计算

以 DeepSeek-V3 为例(假设 hidden_dim=512):

格式 每 Token 大小 128K 上下文 节省比例
BF16 完整精度 512 × 2 + 512 × 2 = 2048 bytes 256 GB -
FP8 with scale 512 + 16 + 128 = 656 bytes 82 GB 68%
FP8 (仅 NoPE) 512 + 16 = 528 bytes 66 GB 74%

注意:实际 MLA 还有 latent compression(从 7168 压缩到 512),总 KV Cache 可减少 93.3%

与其他量化方案对比

方案 粒度 格式 压缩比 精度损失 反量化开销
FlashMLA per-128 FP8 E4M3 + FP32 scale 3x ~0.3% 低(乘法)
TurboQuant per-token + 低秩 FP8 + 低秩补偿 4x <0.1% 中(低秩重建)
vLLM FP8 per-tensor FP8 E4M3 4x ~0.5%
SGLang INT8 per-channel INT8 + FP32 scale 4x ~0.4%

FlashMLA vs TurboQuant

1
2
3
4
5
6
7
8
# FlashMLA 公式(简单)
value = value_fp8 * scale_inv

# TurboQuant 公式(有低秩补偿)
value = value_fp8 * scale_inv + L @ R # 低秩残差矩阵

# TurboQuant 精度更高,但反量化需要额外的矩阵乘法(~5-10μs)
# FlashMLA 更适合 decode 阶段的 on-the-fly 反量化(latency 敏感)

为什么 FlashMLA 不用低秩补偿?

  • Decode 阶段是 latency-bound,每一微秒都重要
  • 低秩重建需要额外的矩阵乘法
  • FlashMLA 选择用更细的粒度(per-128)来补偿精度,而不是低秩矩阵

量化误差分析

1
2
3
4
5
6
# 实测数据(128K 上下文):
# 最大绝对误差:0.0625
# 平均绝对误差:0.0152
# 相对误差:0.31%

# 误差与上下文长度无关(每 token 独立量化)

DSA 稀疏注意力机制

什么是 DSA?

DSA (DeepSeek Sparse Attention) 是 DeepSeek-V3.2 引入的 token-level 稀疏注意力机制。

核心思想:不是所有 token 都需要相互 attention,只计算重要的 token 对。

为什么需要稀疏注意力?

Dense Attention 的问题

1
2
3
4
5
6
7
8
# Dense Attention: O(n²) 复杂度
for q in query_tokens: # n queries
for k in key_tokens: # n keys
score = q @ k # ← n² 次计算

# 128K 上下文的计算量:
# 128K × 128K = 16.4B 次 attention 计算
# 显存占用:128K² × 2 bytes = 32 GB(仅 attention matrix)

DSA 的解决方案

1
2
3
4
5
6
7
8
9
# Sparse Attention: O(n × topk) 复杂度
for q in query_tokens: # n queries
topk_indices = indexer(q, key_tokens) # 选择最重要的 k 个
for k_idx in topk_indices: # topk keys (比如 4K)
score = q @ k # ← n × topk 次计算

# 128K 上下文,topk=4K:
# 128K × 4K = 0.5B 次计算(32x 减少)
# 显存占用:128K × 4K × 2 bytes = 1 GB

Lightning Indexer:如何选 top-k?

DSA 的核心是 Lightning Indexer,它快速计算 query 和所有 keys 的相关性分数。

核心思想

不用完整的 attention 计算,而是用低维投影快速估算相关性:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class LightningIndexer(nn.Module):
def __init__(self, hidden_dim, indexer_dim=64):
super().__init__()
# 低维投影矩阵(比完整 attention 小得多)
self.query_proj = nn.Linear(hidden_dim, indexer_dim, bias=False)
self.key_proj = nn.Linear(hidden_dim, indexer_dim, bias=False)

def forward(self, query, all_keys):
"""
query: [batch, seq_len_q, hidden_dim]
all_keys: [batch, seq_len_k, hidden_dim]

返回:topk_indices [batch, seq_len_q, topk]
"""
# 1. 投影到低维空间
q_proj = self.query_proj(query) # [batch, seq_len_q, indexer_dim]
k_proj = self.key_proj(all_keys) # [batch, seq_len_k, indexer_dim]

# 2. 计算相关性分数(低维空间,快得多)
scores = q_proj @ k_proj.transpose(-1, -2) # [batch, seq_len_q, seq_len_k]

# 3. 选择 top-k
topk_scores, topk_indices = scores.topk(k=topk, dim=-1)

return topk_indices

为什么快?

1
2
3
4
5
6
7
8
9
10
11
# hidden_dim = 512, indexer_dim = 64

# 完整 attention 的计算成本:
# 512 × 512 = 262,144 FLOPs per token pair

# Lightning Indexer:
# 投影:512 × 64 = 32,768 FLOPs
# 低维 attention:64 × 64 = 4,096 FLOPs
# 总计:32,768 + 4,096 = 36,864 FLOPs

# 速度提升:262,144 / 36,864 ≈ 7x

Indices Tensor:稀疏性的关键

Indices 的形状和语义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# indices tensor 形状:[batch_size, num_heads, topk]
# 每个元素编码了:block_index * block_size + offset_in_block

# 示例:batch=1, num_heads=1, topk=4, block_size=16
indices = torch.tensor([[[35, 72, 108, 201]]], dtype=torch.int32)

# 解码:
# 35 = block_2 * 16 + offset_3 → 第 2 个 block 的第 3 个 token
# 72 = block_4 * 16 + offset_8 → 第 4 个 block 的第 8 个 token
# ...

def decode_paged_index(index, block_size=16):
block_idx = index // block_size
offset = index % block_size
return block_idx, offset

Indices 的生成流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class DSAIndexer:
def __init__(self, config):
self.head_dim = config.head_dim
self.indexer_dim = config.indexer_dim # 通常 64
self.topk = config.sparse_topk # 通常 4096

# Indexer 投影矩阵
self.q_indexer_proj = nn.Linear(self.head_dim, self.indexer_dim)
self.k_indexer_proj = nn.Linear(self.head_dim, self.indexer_dim)

def compute_indices(self, q, all_k, block_table, cache_seqlens):
"""
q: [batch, h_q, d_qk] 当前 query
all_k: paged KV cache 所有历史 keys
block_table: [batch, max_blocks] paged cache 的 block 索引

返回:indices [batch, h_q, topk]
"""
batch_size, num_heads, head_dim = q.shape

# 1. 投影到低维空间
q_proj = self.q_indexer_proj(q) # [batch, h_q, indexer_dim]

# 2. 从 paged cache 加载所有 keys(FP8 反量化后)
all_k_proj = []
for b in range(batch_size):
seq_len = cache_seqlens[b]
k_block = load_from_paged_cache(all_k, block_table[b], seq_len)
k_proj = self.k_indexer_proj(k_block) # [seq_len, indexer_dim]
all_k_proj.append(k_proj)

# 3. 计算相关性分数
scores = torch.einsum('bhd,bkd->bhk', q_proj, all_k_proj)
# scores: [batch, h_q, seq_len]

# 4. 选择 top-k
topk_scores, topk_indices = scores.topk(k=self.topk, dim=-1)
# topk_indices: [batch, h_q, topk]

# 5. 转换为 paged cache 的索引格式
indices_in_kvcache = self.convert_to_paged_indices(
topk_indices, block_table, block_size=16
)

return indices_in_kvcache

Sparse Attention 计算流程

完整流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def flash_mla_sparse_attention(
q, # [batch, h_q, d_qk]
kv_cache_fp8, # paged FP8 KV cache
block_table, # [batch, max_blocks]
cache_seqlens, # [batch]
indices, # [batch, h_q, topk] ← 从 Indexer 来
is_fp8_kvcache=True
):
"""
FlashMLA Sparse Attention 核心函数
"""
batch_size, num_heads, head_dim = q.shape
topk = indices.shape[-1]

# Step 1: 根据 indices 从 paged cache 中 gather 需要的 KV
# 注意:这里只 gather topk 个,不是全部
focused_kv = []
for b in range(batch_size):
for h in range(num_heads):
token_indices = indices[b, h, :] # [topk]

# 从 paged cache 中 gather
kv_tokens = []
for idx in token_indices:
if idx == -1: # 无效索引(padding)
continue
block_idx = idx // 16
offset = idx % 16

# 从 block 中读取单个 token
kv_token = load_single_token(
kv_cache_fp8[b],
block_table[b, block_idx],
offset,
is_fp8=is_fp8_kvcache
)
kv_tokens.append(kv_token)

focused_kv.append(torch.stack(kv_tokens))

# focused_kv: [batch, h_q, topk, d_v]

# Step 2: 反量化(如果 FP8)
if is_fp8_kvcache:
focused_kv = dequantize_kv_fp8(focused_kv)

# Step 3: 计算 sparse attention
# Q @ K^T
scores = torch.einsum('bhd,btkd->bht', q, focused_kv[..., :d_qk])
# scores: [batch, h_q, topk]

# Softmax
scores = scores / math.sqrt(d_qk)
weights = torch.softmax(scores, dim=-1) # [batch, h_q, topk]

# 加权求和
out = torch.einsum('bht,btkd->bhd', weights, focused_kv[..., :d_v])
# out: [batch, h_q, d_v]

return out

与 Dense Attention 的对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# ============ Dense Attention ============
def dense_attention(q, all_k, all_v):
"""
q: [batch, h_q, d_qk]
all_k: [batch, seq_len, d_qk]
all_v: [batch, seq_len, d_v]
"""
# 计算所有 token 的 attention
scores = torch.einsum('bhd,bkd->bhk', q, all_k) # [batch, h_q, seq_len]
weights = torch.softmax(scores / math.sqrt(d_qk), dim=-1)
out = torch.einsum('bhk,bkd->bhd', weights, all_v)
return out

# 复杂度:O(batch × h_q × seq_len × d_qk)
# seq_len=128K 时:1 × 128 × 128K × 512 = 8.4B FLOPs


# ============ Sparse Attention (DSA) ============
def sparse_attention(q, all_k, all_v, indices):
"""
indices: [batch, h_q, topk] ← 预先选择的 top-k 个 token
"""
# 1. 根据 indices gather KV
focused_k = gather(all_k, indices) # [batch, h_q, topk, d_qk]
focused_v = gather(all_v, indices) # [batch, h_q, topk, d_v]

# 2. 只计算 topk 个 attention
scores = torch.einsum('bhd,btkd->bht', q, focused_k) # [batch, h_q, topk]
weights = torch.softmax(scores / math.sqrt(d_qk), dim=-1)
out = torch.einsum('bht,btkd->bhd', weights, focused_v)
return out

# 复杂度:O(batch × h_q × topk × d_qk)
# topk=4K 时:1 × 128 × 4K × 512 = 262M FLOPs
# 比 dense 快:8.4B / 262M ≈ 32x

DSA 的两种模式

DeepSeek-V3.2 支持两种 sparse attention 模式:

A. Prefill 阶段的 Sparse Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Prefill 阶段:处理整个 prompt
# 所有 token 的 KV 都在,可以计算完整的 indices

def sparse_prefill(q, kv, indices):
"""
q: [seq_len_q, h_q, d_qk]
kv: [seq_len_k, h_kv, d_qk]
indices: [seq_len_q, h_kv, topk]
"""
# 使用 flash_mla_sparse_fwd kernel
out, max_logits, lse = flash_mla_sparse_fwd(
q, kv, indices, sm_scale=1.0/math.sqrt(d_qk)
)
return out

特点

  • KV 是 BF16 格式(未量化)
  • 使用 flash_mla_sparse_fwd kernel
  • 一次性计算所有 query tokens 的 indices

B. Decode 阶段的 Sparse Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Decode 阶段:每次只生成 1 个 token
# 需要维护一个 indices cache(避免每次都重新计算)

class DecodeDSAIndexer:
def __init__(self):
self.indexer = LightningIndexer()
self.indices_cache = None # 缓存上一步的 indices

def decode_step(self, q, kv_cache, block_table, cache_seqlens):
"""
每个 decode step 调用一次
"""
# 策略 1:重用上一轮的 indices(大部分情况足够)
if self.indices_cache is not None:
indices = self.indices_cache

# 策略 2:每隔 N 步重新计算 indices
# 或者当 cache_seqlens 变化超过阈值时重新计算
if should_recompute():
indices = self.indexer.compute_indices(
q, kv_cache, block_table, cache_seqlens
)
self.indices_cache = indices

# 使用 indices 进行 sparse attention
out = flash_mla_with_kvcache(
q, kv_cache, block_table, cache_seqlens,
indices=indices, # ← 传入 indices
is_fp8_kvcache=True
)

return out

特点

  • KV 是 FP8 格式(量化后)
  • 使用 flash_mla_with_kvcache kernel(sparse 模式)
  • Indices 可以缓存,避免每步都重新计算

代码与论文的对应

论文概念 FlashMLA 代码 位置
Lightning Indexer DSAIndexer.compute_indices() inference/indexer.py
Top-k Selection scores.topk(k=topk, dim=-1) inference/indexer.py
Indices Tensor indices [batch, h_q, topk] flash_mla_interface.py
Paged Indices indices_in_kvcache tests/quant.py
Sparse Attention Kernel flash_mla_sparse_fwd flash_mla_cuda.cu
Sparse Decode Kernel flash_mla_with_kvcache(indices=...) flash_mla_cuda.cu

性能数据

根据 DeepSeek 官方数据(H800 SXM5):

场景 Dense Sparse (DSA) 提升
Prefill (640 TFLOPS) 450 TFLOPS 640 TFLOPS 1.42x
Decode (410 TFLOPS) 150 TFLOPS 410 TFLOPS 2.73x
显存占用 (128K) 32 GB 1 GB 32x

注意

  • Prefill 提升较小(因为本来就是 compute-bound)
  • Decode 提升巨大(memory-bound + 计算量减少)

Indices Tensor:稀疏性的关键

1
2
3
4
5
6
7
8
9
10
# indices tensor 形状:[batch_size, seq_len_q, topk]
# indices[i][j][k] = 第 i 个 batch、第 j 个 query 的第 k 个关键 token 的索引

# 示例:batch=1, seq_len_q=4, topk=2
indices = torch.tensor([
[[0, 3], # query token 0 只 attention token 0 和 3
[1, 2], # query token 1 只 attention token 1 和 2
[2, 3], # query token 2 只 attention token 2 和 3
[0, 1]], # query token 3 只 attention token 0 和 1
], dtype=torch.int32)

Paged KV Cache 与 Indices 的映射

FlashMLA 使用 paged KV cache(类似 vLLM 的分页管理):

1
2
3
4
5
6
7
8
9
10
11
12
# indices 编码了 page block 索引 + block 内偏移
# indices_in_kvcache[i][j][k] = page_block_idx * page_block_size + offset_in_block

# 示例:page_block_size = 16
# 如果 indices[i][j][k] = 35
# 则 page_block_idx = 35 // 16 = 2
# offset_in_block = 35 % 16 = 3

def decode_indices(indices_flat, page_block_size=16):
page_block_idx = indices_flat // page_block_size
offset_in_block = indices_flat % page_block_size
return page_block_idx, offset_in_block

稀疏 Attention 计算流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def sparse_attention(q, kv_cache, indices, sm_scale):
"""
q: [s_q, h_q, d_qk], bfloat16
kv_cache: paged KV cache (FP8 format)
indices: [s_q, h_kv, topk], int32
"""
# 1. 根据 indices 从 paged KV cache 中 gather 需要的 KV
# 注意:这里需要处理 FP8 反量化
focused_kv = gather_from_paged_cache(
kv_cache,
indices,
is_fp8=True # ← flashmla_kv 配置控制这里
)
# focused_kv: [s_q, topk, d_qk]

# 2. 计算 attention scores(只计算 topk 个)
# Q @ K^T
scores = torch.einsum('shd,stk->sht', q, focused_kv) * sm_scale
# scores: [s_q, h_q, topk]

# 3. Softmax(稀疏)
max_logits = scores.max(dim=-1, keepdim=True)
exp_scores = torch.exp(scores - max_logits)
lse = torch.log(exp_scores.sum(dim=-1))
attention_weights = exp_scores / exp_scores.sum(dim=-1, keepdim=True)

# 4. 加权求和
out = torch.einsum('sht,stk->shd', attention_weights, focused_kv)

return out, lse, max_logits

与 Dense Attention 的对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Dense Attention (O(n²) 复杂度)
def dense_attention(q, k, v):
scores = q @ k.transpose(-1, -2) # [s_q, s_kv]
weights = softmax(scores)
out = weights @ v
return out

# Sparse Attention (O(n × topk) 复杂度)
def sparse_attention(q, kv_cache, indices):
focused_kv = gather(kv_cache, indices) # 只 gather topk 个
scores = q @ focused_kv.transpose(-1, -2) # [s_q, topk]
weights = softmax(scores)
out = weights @ focused_kv
return out

# 复杂度对比(seq_len=128K, topk=4K)
# Dense: 128K × 128K = 16.4G 次计算
# Sparse: 128K × 4K = 0.5G 次计算(32x 加速)

论文算法与代码对应

DeepSeek-V2 论文中的 MLA 公式

论文中的 MLA 压缩 - 恢复流程:

$$
\begin{aligned}
\text{压缩:} & \quad C_K = X W_{cK}, \quad C_V = X W_{cV} \
\text{存储:} & \quad \text{KV Cache} = [C_K, C_V] \
\text{恢复:} & \quad K = C_K W_{uK}, \quad V = C_V W_{uV} \
\text{Attention:} & \quad \text{Attn}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
\end{aligned}
$$

FlashMLA 代码实现对应

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# FlashMLA 中的 MLA 实现(简化版)
class MLAAttention(nn.Module):
def __init__(self, hidden_dim, latent_dim):
super().__init__()
# 压缩矩阵
self.W_cK = nn.Linear(hidden_dim, latent_dim, bias=False)
self.W_cV = nn.Linear(hidden_dim, latent_dim, bias=False)

# 恢复矩阵
self.W_uK = nn.Linear(latent_dim, hidden_dim, bias=False)
self.W_uV = nn.Linear(latent_dim, hidden_dim, bias=False)

def forward(self, x, cache_seqlens, block_table, is_fp8_kvcache=True):
# 1. 压缩(Prefill 阶段)
c_k = self.W_cK(x) # [seq_len, latent_dim]
c_v = self.W_cV(x)

# 2. 量化(如果启用 FP8)
if is_fp8_kvcache: # ← flashmla_kv 配置
kv_fp8, scales = quantize_kv_fp8(torch.cat([c_k, c_v], dim=-1))
kv_cache = torch.cat([kv_fp8, scales], dim=-1)
else:
kv_cache = torch.cat([c_k, c_v], dim=-1) # BF16

# 3. 存储到 paged cache
store_to_paged_cache(kv_cache, block_table)

# 4. Decode 阶段:从 cache 读取并恢复
kv_loaded = load_from_paged_cache(block_table, cache_seqlens)

if is_fp8_kvcache:
# 反量化
kv_dequant = dequantize_kv_fp8(kv_loaded[..., :512], kv_loaded[..., 512:516])
c_k, c_v = kv_dequant.chunk(2, dim=-1)
else:
c_k, c_v = kv_loaded.chunk(2, dim=-1)

# 恢复
k = self.W_uK(c_k)
v = self.W_uV(c_v)

# 注意力计算(可能使用 sparse indices)
out = flash_mla_with_kvcache(q, k, v, indices)

return out

关键函数映射

论文概念 FlashMLA 函数 位置
KV 压缩 W_cK, W_cV mla_attention.py
KV 恢复 W_uK, W_uV mla_attention.py
FP8 量化 quantize_kv_fp8 tests/quant.py
FP8 反量化 dequantize_kv_fp8 tests/quant.py
Paged Cache block_table, cache_seqlens flash_mla_interface.py
Sparse Indices indices tensor flash_mla_interface.py
Decode Kernel flash_mla_with_kvcache flash_mla_cuda.cu
Prefill Kernel flash_mla_sparse_fwd flash_mla_cuda.cu

为什么 Decode 阶段必须用 FP8

内存带宽瓶颈

Decode 阶段是 memory-bound(内存带宽受限),而非 compute-bound:

1
2
3
4
5
6
Decode 阶段特点:
- 每次只生成 1 个 token
- 需要读取整个 KV Cache(128K 上下文)
- 计算量小(1 个 query vs 128K keys)

瓶颈:从 HBM 读取 KV Cache 的速度

FP8 带来的带宽节省

1
2
3
4
5
6
7
8
9
10
11
12
# 假设 H800 SXM5 的 HBM 带宽:3.35 TB/s

# BF16 KV Cache (2048 bytes/token)
# 读取 128K tokens: 128K × 2048 = 256 MB
# 理论延迟:256 MB / 3.35 TB/s = 76 μs

# FP8 KV Cache (656 bytes/token)
# 读取 128K tokens: 128K × 656 = 82 MB
# 理论延迟:82 MB / 3.35 TB/s = 24 μs

# 带宽节省:3.1x
# 实际吞吐提升:2.5-2.8x(考虑反量化开销)

为什么 Prefill 不用 FP8?

Prefill 阶段是 compute-bound(计算受限):

1
2
3
4
5
6
7
8
9
# Prefill 阶段特点:
# - 处理整个 prompt(可能 128K tokens)
# - 需要计算 O(n²) 的 attention 矩阵
# - 计算密集,带宽压力相对较小

# 使用 BF16 的原因:
# 1. 训练精度要求高
# 2. 计算瓶颈不在带宽
# 3. 量化/反量化开销不划算

实践指南:SGLang + GB200 配置

SGLang 配置 DeepSeek-V3.2 on GB200

根据 SGLang 官方 issue #21291,在 NVIDIA GB200 (SM100/Blackwell) 上部署 DeepSeek-V3.2 的配置:

1. Docker 镜像

1
2
3
4
5
6
7
8
9
10
11
12
# GB200 (SM100) 使用专用镜像
docker pull lmsysorg/sglang:dsv32

# 启动容器
docker run --gpus all --shm-size 32g \
-p 30000:30000 \
lmsysorg/sglang:dsv32 \
python -m sglang.launch_server \
--model deepseek-ai/DeepSeek-V3.2-Exp \
--tp 8 \
--dp 8 \
--enable-dp-attention

2. 关键配置参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
python -m sglang.launch_server \
--model deepseek-ai/DeepSeek-V3.2-Exp \

# 张量并行 (GB200 推荐 8 卡)
--tp 8 \

# 数据并行 (GB200 推荐 8)
--dp 8 \

# 启用数据并行注意力(NSA 后端必需)
--enable-dp-attention \

# FlashMLA 配置
--enable-flashmla \
--flashmla-kv True \ # ← 关键:启用 FP8 KV Cache

# DSA 稀疏注意力配置
--enable-sparse-attention \
--sparse-topk 4096 \ # 每个 query attention 4K tokens

# 上下文长度
--max-model-len 131072 \

# 服务端口
--port 30000

3. GB200 特殊注意事项

SM100 Kernel 限制

根据实测,GB200 (SM100) 只支持以下 kernel:

Kernel SM90 (H100) SM100 (GB200)
BF16 Dense Decode
FP8 Dense Decode
FP8 Sparse Decode
BF16 Sparse Prefill

这意味着

  • GB200 必须使用 FP8 Sparse Decode(flashmla_kv=True
  • Dense Decode 模式不可用
  • 这也是为什么 DeepSeek-V3.2 在 GB200 上默认使用 sparse 模式

4. 性能基准(GB200)

根据实测数据(Batch=32, TopK=512):

指标 数值
Decode 延迟 0.93 ms
Prefill 延迟 0.04 ms
显存占用 (128K) 82 MB/卡
吞吐量 (tokens/s) ~34,000

5. 客户端调用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import requests

# SGLang API 调用
url = "http://localhost:30000/generate"

payload = {
"text": "你好,请介绍一下 DeepSeek-V3.2",
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 2048,
"top_p": 0.95,
},
"stream": False
}

response = requests.post(url, json=payload)
result = response.json()

print(result['text'])

6. 监控与调试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 查看服务器日志
docker logs -f <container_id>

# 检查 FlashMLA 是否生效
# 日志中应该看到:
# "Using FlashMLA FP8 KV Cache"
# "Sparse attention enabled with topk=4096"

# 检查显存占用
nvidia-smi dmon -i 0

# 检查 kernel 启动
# 使用 nsys 性能分析工具
nsys profile --stats=true \
python -m sglang.launch_server \
--model deepseek-ai/DeepSeek-V3.2-Exp \
--tp 8 --dp 8

7. 常见问题

Q1: 启动时报 “SM100 not supported”

1
2
3
4
原因:SGLang 版本过旧
解决:更新到最新版本
pip install --upgrade sglang
docker pull lmsysorg/sglang:dsv32

Q2: 显存不足

1
2
3
4
原因:batch size 或 context length 过大
解决:
--max-model-len 65532 # 减少上下文长度
--dp 16 # 增加数据并行

Q3: 生成质量下降

1
2
3
可能原因:TopK 设置过小
解决:
--sparse-topk 8192 # 增加 topk(会增加延迟)

性能优化建议

GB200 最佳实践

  1. 启用 FP8 Sparse Decode

    1
    2
    --flashmla-kv True
    --enable-sparse-attention
  2. 选择合适的 TopK

    • 延迟敏感:--sparse-topk 2048
    • 质量优先:--sparse-topk 8192
    • 平衡:--sparse-topk 4096(推荐)
  3. 数据并行配置

    • 单节点 8 卡:--tp 8 --dp 8
    • 多节点:--tp 8 --dp 16+
  4. 显存优化

    1
    2
    --gpu-memory-utilization 0.9
    --max-num-batched-tokens 16384

参考资料

核心论文

  1. DeepSeek-V2 Technical Report (arXiv:2405.04434)

  2. DeepSeek-V3.2-Exp Technical Report (2025)

代码仓库

  1. FlashMLA (GitHub)

  2. DeepGEMM (GitHub)

  3. TileLang (GitHub)

技术博客

  1. FlashMLA Deep-Dive Blog (DeepSeek 官方)

  2. LMCache Documentation

推理框架

  1. vLLM DeepSeek-V3.2 支持

  2. SGLang DeepSeek-V3.2 支持

示例代码

  1. FlashMLA Kernel Benchmark

总结

FlashMLA 的 FP8 KV Cache 和 DSA 稀疏注意力代表了 LLM 推理优化的两个重要方向:

  1. 量化压缩:FP8 格式将 KV Cache 减少 68%+,直接缓解 decode 阶段的带宽瓶颈
  2. 稀疏计算:DSA 通过 indices tensor 实现 token-level 稀疏,将注意力复杂度从 O(n²) 降到 O(n × topk)

理解这些机制对于:

  • 正确配置推理服务(如 flashmla_kv 参数)
  • 集成 KV Cache 管理系统
  • 性能调优和故障排查

都至关重要。

随着更多模型采用类似技术(GLM-5、Qwen-3 等),掌握 FlashMLA 的原理将成为 LLM 部署工程师的必备技能。


参考资料

最后更新:2026-04-10

作者注:本文基于 FlashMLA 开源代码和 DeepSeek 技术报告整理,部分实现细节可能随版本更新而变化。