FlashInfer MXFP4 MoE 后端深度分析

背景

DeepSeek-V4 系列模型(V4 / V4-Pro)在 MoE(Mixture of Experts)层大量使用 MXFP4 量化权重,配合 FP8/BF16 激活 实现高效推理。在 SGLang / vLLM 中,flashinfer_mxfp4 是一个专为这种量化格式设计的 MoE Runner Backend。

本文将基于 H20 (SM90 Hopper) 实测数据和源码分析,完整解析 FlashInfer MXFP4 的实现原理,以及与 Marlin MoE 的架构差异。


MXFP4 量化格式

格式定义

MXFP4 是 OCP(Open Compute Project)标准的 4-bit 浮点格式:

1
2
每个权重元素: [1 sign | 2 exponent | 1 mantissa] → 4-bit 浮点数
每 32 个元素: 共享一个 8-bit 指数缩放因子 (E8M0 block scale)

与 INT4 的对比

维度 MXFP4 INT4 (Marlin)
表示方式 浮点 (E2M1) 定点整数
动态范围 大(指数缩放) 小(线性,需精确校准)
缩放粒度 每 32 元素(block scale) 每 128 元素(group quantization)
解量化 value = fp4 × 2^(scale - 127) value = int4 × scale[channel]

优势:浮点表示让 MXFP4 在训练和推理中都有更好的数值稳定性。


FlashInfer MXFP4 架构

调用链(H20 / SM90)

1
2
3
4
5
6
7
8
9
10
11
SGLang --moe-runner-backend flashinfer_mxfp4

Python: trtllm_fp4_block_scale_moe()

C++: get_cutlass_fused_moe_module() → 加载预编译 .so

CUTLASS 3.x: cute::gemm::kernel::DefaultGemmUniversal

├─ TMA (Tensor Memory Accelerator) → 异步搬权重 tile 到 SMEM
├─ WGMMA (Warpgroup MMA) → SM90 用 FP16 TC 模拟 FP4
└─ Warp Specialization → Producer warpgroup 专搬数据,Consumer warpgroup 专计算

核心特性

  1. Expert-First 遍历:按 expert 分组 token,而非按 token 遍历 expert
  2. 全 Fuse:W1 + W3 + SwiGLU + W2 四个步骤 fuse 成一个 kernel
  3. TMA + Warp Specialization:数据搬运和计算 overlap
  4. Grouped GEMM:一次 launch 处理所有 256 个 expert

完整 Pipeline(DeepSeek-V4 为例)

模型配置

1
2
3
4
5
DeepSeek-V4-Flash:
- 256 experts, topk=6
- hidden_dim=4096, moe_intermediate=2048
- 权重: MXFP4 (4-bit), 激活: BF16
- Router: Sigmoid + Group TopK (非 Softmax)

第一阶段:Routing(路由选择)

1
2
3
4
5
6
7
8
9
10
11
12
# 文件: trtllm_fused_moe_routing_deepseek.cu

# 输入: router_logits [num_tokens, 256]
scores = torch.sigmoid(router_logits + bias) # DeepSeek-V4 用 Sigmoid, 非 Softmax

# Group TopK: 256 experts → 8 组 × 32
group_scores = scores.view(num_tokens, 8, 32).max(dim=-1) # 每组最高分
top_groups = group_scores.topk(4).indices # 选 4 组 (128 experts)

# 从 128 个 expert 中选 top-6
topk_indices = ... # [num_tokens, 6]
topk_weights = ... # 归一化后的权重

为什么用 Sigmoid?

  • 独立评分(非 Softmax 的零和游戏)
  • 多 expert 可同时激活
  • 训练更稳定

第二阶段:Gather(Token 重排)

1
2
3
4
# 文件: cutlass_fused_moe_kernels.cuh → expandInputRowsKernel

# 输入: hidden_states [num_tokens, 4096] (按 token 顺序)
# 输出: expanded_input [num_tokens × 6, 4096] (按 expert 连续排列)

目的:把路由到同一个 expert 的 token 收集到一起,形成连续内存,供后续 GEMM 使用。

1
2
3
4
5
6
7
8
9
10
原始 (token 顺序):
[t0 | t1 | t2 | t3]

路由结果:
expert_3: [t0, t1, t3]
expert_7: [t0, t2]

Gather 后 (expert 顺序):
[e3: t0, t1, t3 | e7: t0, t2 | ...]
↑ 连续 ↑ 连续

TMA 优势:连续输入让 TMA 一次加载整个 tile,达到满带宽。


第三阶段:Compute(Fused GEMM)

1
# 文件: cutlass_fused_moe_kernels.cuh → CUTLASS Grouped GEMM

这是最核心的部分,W1 + W3 + SwiGLU + W2 全部 fuse

Step 3a: GEMM1 (Gate + Up Projection)

1
2
3
4
5
6
对每个 expert_i:
M = expert_i 分到的 token 数 (例如 3)
input_i = expanded_input[offset : offset+M] # [M, 4096]

# W13 = [W1; W3] 拼接,一次 GEMM 算出
gate_up = input_i @ W13_expert_i # [M, 4096] → split → [M, 2048] + [M, 2048]

MXFP4 权重存储

1
2
3
// W1_expert_i: stored as [2048, 2048] int8
// 实际逻辑形状: [2048, 4096] FP4 (每 byte 存 2 个 FP4 元素)
// Scale: [2048, 64] float8_e8m0 (每 32 个元素一个 scale)

Step 3b: SwiGLU Activation

1
2
3
# 在 GEMM1 输出还在 SMEM/寄存器时立即执行(fuse 关键)
intermediate = SiLU(gate) * up # [M, 2048]
# DeepSeek-V4 还有 swiglu_limit=10.0 的 clamp

Step 3c: GEMM2 (Down Projection)

1
2
output_i = intermediate @ W2_expert_i  # [M, 4096]
# W2_expert_i: stored as [4096, 1024] int8 → 逻辑 [4096, 2048] FP4

CUTLASS Grouped GEMM 调度

1
2
3
4
5
6
7
8
9
10
11
12
// 一次 kernel launch 处理所有 256 个 expert
cutlass::gemm::kernel::GroupedGemmKernel<...>::run();

// 内部调度:
// SM 空闲 → 从 problem list 取下一个 expert
// expert_i 的 M=0 (没 token) → 跳过
// expert_j 的 M=12 → 分配 SM 算 GEMM

// TMA + Warp Specialization:
// Warpgroup 0 (Producer): TMA_LOAD(weight_tile, desc)
// Warpgroup 1 (Consumer): WGMMA(tile_C, tile_A, tile_B)
// Producer 搬下一个 expert 的权重时,Consumer 正在算当前 expert

Pipeline 示意(H20)

1
2
3
4
Cycle 0-100:  Producer 搬 expert_3 权重,Consumer 算 expert_1 (上一轮)
Cycle 100-200: Consumer 算 expert_3 (权重已到),Producer 搬 expert_7
Cycle 200+: Consumer 算 expert_7,Producer 搬 expert_5
→ 数据搬运和计算完全 overlap!

第四阶段:Scatter(结果写回)

1
2
3
4
# 文件: cutlass_fused_moe_kernels.cuh → finalizeMoeRoutingKernel

# 输入: expert_outputs [num_tokens × 6, 4096] (按 expert 排列)
# 输出: final_output [num_tokens, 4096] (恢复 token 原始顺序)

加权求和

1
2
3
for token_j:
output[j] = Σ (weight_k × expert_output[permute_map[j][k]])
k=1..6

示例

1
2
token_0 → expert_3 (0.6), expert_7 (0.4)
output[0] = 0.6 × out_0_e3 + 0.4 × out_0_e7

为什么叫 Scatter?

  • expert_3 的输出: [out_0_e3, out_1_e3, out_3_e3]
  • 需要写回: output[0], output[1], output[3](不连续!)
  • 这就是 scatter(分散写入),与 gather(聚集读取)相对

性能实测(H20, DeepSeek-V4-Flash, TP=4)

Benchmark 结果

Concurrency FlashInfer MXFP4 Marlin 差距
conc=1 20.2 tok/s/GPU 20.5 tok/s/GPU ≈持平
conc=2 42.1 tok/s/GPU 28.1 tok/s/GPU +50%
conc=4 32.5 tok/s/GPU 31.8 tok/s/GPU ≈持平

阶段耗时分解(估)

阶段 占比 说明
Routing ~2% 纯 element-wise,快
Gather ~3% 内存拷贝
GEMM1 (W13) ~42% 计算密集
Activation ~1% fused 在 GEMM1 后
GEMM2 (W2) ~48% 计算密集
Scatter ~4% 加权求和 + 写回

**GEMM 占 90%**,所以 TMA + Grouped GEMM 对总体性能影响最大。

为什么 conc=2 差距最大?

  1. Expert-First + Grouped GEMM:一次 launch 处理所有 expert,TMA 预取下一个 tile
  2. 中等 batch:每个 expert 分到 2-3 个 token → GEMM tile 够大,WGMMA 饱和
  3. Warp Specialization:Producer 搬数据,Consumer 计算,完美 overlap

Marlin 的劣势

  • Token-First:每个 token 的 topk expert 不同 → 权重反复换入换出
  • 无 TMA:用 cp.async 加载,线程要参与地址计算 → SM 利用率 < 50%
  • 无法全 fuse:中间结果 (gate, up, mid) 要写回 GMEM

为什么 conc=4 差距消失?

1
2
3
4
5
conc=4: prefill=50k tokens
→ Prefill 时间: ~1500ms (占 85%+)
→ Decode MoE 时间: ~265ms (占 15%-)

即使 MoE 快 50%,总体提升也只有: 265ms × 50% / 1765ms ≈ 7.5%

Prefill 主导后,MoE decode 的优化被稀释。


与 Marlin MoE 的架构对比

核心差异

维度 FlashInfer MXFP4 Marlin (SGLang)
量化格式 MXFP4 (浮点4位) INT4 / MXFP4
遍历顺序 Expert-First Token-First
融合程度 W1+W3+SwiGLU+W2 全 fuse W1+W3 部分 fuse,W2 单独
数据加载 TMA (硬件 DMA) cp.async / LDG (线程参与)
调度方式 Warp Specialization (Producer/Consumer) 单 warp 既搬又算
Kernel 架构 CUTLASS 3.x Grouped GEMM 手写 CUDA,Ampere 设计
TMA 利用 ✅ 异步 tile 预取 ❌ 无 TMA (用 cp.async)
多 expert 并行 Grouped GEMM 一次 launch 逐 expert 串行或小并行

Expert-First vs Token-First

**Token-First (Marlin)**:

1
2
3
4
5
for token_0:
expert_3: W1 → W3 → SwiGLU → W2 ← 写回 GMEM
expert_7: W1 → W3 → SwiGLU → W2 ← 换权重
for token_1:
... # 重复上述,权重反复换入换出

**Expert-First (FlashInfer)**:

1
2
3
4
5
6
7
for expert_3:
# 固定权重 W1/W3/W2,常驻 SMEM
gate_up = input_e3 @ W13_e3 # 连续 token 一起算
mid = SiLU(gate) * up # 在寄存器/SMEM 直接算
output = mid @ W2_e3 # 中间结果不落地
for expert_7:
... # 换一次权重,算所有路由到它的 token

为什么 Expert-First 能全 fuse?

  • 固定权重 → W1/W3/W2 常驻 SMEM,不换出
  • 变 batch → 同一 expert 的多个 token 连续处理
  • 中间结果 (gate, up, mid) 全在寄存器/SMEM,不写 GMEM

TMA + WGMMA 的硬件优势

TMA (Tensor Memory Accelerator)

SM90 (Hopper) 引入的硬件 DMA 引擎

  • 异步搬数据从 HBM 到 SMEM,不占用 CUDA core
  • 一条 tcgen05.1d 指令触发,后台独立运行
  • 支持多维 tensor 描述符(TMA Descriptor)

vs 传统 LDG

1
2
3
4
5
6
7
8
LDG (传统):
线程发 load → 等 HBM (~400 cycles) → 数据到 SMEM → 继续算
→ 线程 stall,SM 有空闲

TMA:
线程发 TMA_LOAD → 立即返回 → 可以去 setup 下一个 tile
→ TMA 后台搬,线程继续算上一轮结果
→ SM 利用率 80%+

WGMMA (Warpgroup MMA)

SM90 引入的 warpgroup 级矩阵乘指令

  • 一个 warpgroup = 4 个 warp = 128 线程
  • 直接读写 TMEM(Tensor Memory,SM90 新增寄存器文件)
  • 比传统 mma.sync 高一个抽象层级

SM100 (Blackwell) 的进化

1
2
SM90: TMA → SMEM → WGMMA → TMEM
SM100: TMA_GATHER4 → TMEM (跳过 SMEM) → UTCMMA (原生 FP4 TC)

Blackwell 的 TMA_GATHER4 还能一次加载 4 个不连续 token(专为 MoE 的 sparse routing 优化)。


总结

FlashInfer MXFP4 的核心优势

  1. Expert-First + 全 Fuse:中间结果不落地 GMEM,省 ~2× hidden_dim × batch 带宽
  2. TMA + Warp Specialization:搬运和计算 overlap,SM 利用率最大化
  3. Grouped GEMM:一次 launch 处理 256 experts,launch overhead 最小化
  4. Blackwell 未来兼容:cute_dsl_fused_moe_nvfp4 已支持 SM100 原生 FP4 TC

适用场景

场景 推荐后端
H20/H100, conc=2~8 ✅ FlashInfer MXFP4
A100, 或 conc=1 调试 Marlin
Blackwell (B200) FlashInfer NVFP4 (cute_dsl)
追求最大吞吐 FlashInfer MXFP4

一句话

FlashInfer MXFP4 在 Hopper 上通过 Expert-First + TMA + Warp Specialization + 全 Fuse,把 MoE 推理的瓶颈从内存带宽转移到了计算,实现了中等并发下 50% 的吞吐提升


参考资料


如果你对 MoE 推理优化感兴趣,可以看看我之前写的 Marlin MoE Kernel 深度分析,对比两种实现的差异。