FlashInfer MXFP4 MoE 后端深度分析
背景
DeepSeek-V4 系列模型(V4 / V4-Pro)在 MoE(Mixture of Experts)层大量使用 MXFP4 量化权重,配合 FP8/BF16 激活 实现高效推理。在 SGLang / vLLM 中,flashinfer_mxfp4 是一个专为这种量化格式设计的 MoE Runner Backend。
本文将基于 H20 (SM90 Hopper) 实测数据和源码分析,完整解析 FlashInfer MXFP4 的实现原理,以及与 Marlin MoE 的架构差异。
MXFP4 量化格式
格式定义
MXFP4 是 OCP(Open Compute Project)标准的 4-bit 浮点格式:
1 | 每个权重元素: [1 sign | 2 exponent | 1 mantissa] → 4-bit 浮点数 |
与 INT4 的对比
| 维度 | MXFP4 | INT4 (Marlin) |
|---|---|---|
| 表示方式 | 浮点 (E2M1) | 定点整数 |
| 动态范围 | 大(指数缩放) | 小(线性,需精确校准) |
| 缩放粒度 | 每 32 元素(block scale) | 每 128 元素(group quantization) |
| 解量化 | value = fp4 × 2^(scale - 127) |
value = int4 × scale[channel] |
优势:浮点表示让 MXFP4 在训练和推理中都有更好的数值稳定性。
FlashInfer MXFP4 架构
调用链(H20 / SM90)
1 | SGLang --moe-runner-backend flashinfer_mxfp4 |
核心特性
- Expert-First 遍历:按 expert 分组 token,而非按 token 遍历 expert
- 全 Fuse:W1 + W3 + SwiGLU + W2 四个步骤 fuse 成一个 kernel
- TMA + Warp Specialization:数据搬运和计算 overlap
- Grouped GEMM:一次 launch 处理所有 256 个 expert
完整 Pipeline(DeepSeek-V4 为例)
模型配置
1 | DeepSeek-V4-Flash: |
第一阶段:Routing(路由选择)
1 | # 文件: trtllm_fused_moe_routing_deepseek.cu |
为什么用 Sigmoid?
- 独立评分(非 Softmax 的零和游戏)
- 多 expert 可同时激活
- 训练更稳定
第二阶段:Gather(Token 重排)
1 | # 文件: cutlass_fused_moe_kernels.cuh → expandInputRowsKernel |
目的:把路由到同一个 expert 的 token 收集到一起,形成连续内存,供后续 GEMM 使用。
1 | 原始 (token 顺序): |
TMA 优势:连续输入让 TMA 一次加载整个 tile,达到满带宽。
第三阶段:Compute(Fused GEMM)
1 | # 文件: cutlass_fused_moe_kernels.cuh → CUTLASS Grouped GEMM |
这是最核心的部分,W1 + W3 + SwiGLU + W2 全部 fuse。
Step 3a: GEMM1 (Gate + Up Projection)
1 | 对每个 expert_i: |
MXFP4 权重存储:
1 | // W1_expert_i: stored as [2048, 2048] int8 |
Step 3b: SwiGLU Activation
1 | # 在 GEMM1 输出还在 SMEM/寄存器时立即执行(fuse 关键) |
Step 3c: GEMM2 (Down Projection)
1 | output_i = intermediate @ W2_expert_i # [M, 4096] |
CUTLASS Grouped GEMM 调度
1 | // 一次 kernel launch 处理所有 256 个 expert |
Pipeline 示意(H20):
1 | Cycle 0-100: Producer 搬 expert_3 权重,Consumer 算 expert_1 (上一轮) |
第四阶段:Scatter(结果写回)
1 | # 文件: cutlass_fused_moe_kernels.cuh → finalizeMoeRoutingKernel |
加权求和:
1 | for token_j: |
示例:
1 | token_0 → expert_3 (0.6), expert_7 (0.4) |
为什么叫 Scatter?
- expert_3 的输出: [out_0_e3, out_1_e3, out_3_e3]
- 需要写回: output[0], output[1], output[3](不连续!)
- 这就是 scatter(分散写入),与 gather(聚集读取)相对
性能实测(H20, DeepSeek-V4-Flash, TP=4)
Benchmark 结果
| Concurrency | FlashInfer MXFP4 | Marlin | 差距 |
|---|---|---|---|
| conc=1 | 20.2 tok/s/GPU | 20.5 tok/s/GPU | ≈持平 |
| conc=2 | 42.1 tok/s/GPU | 28.1 tok/s/GPU | +50% |
| conc=4 | 32.5 tok/s/GPU | 31.8 tok/s/GPU | ≈持平 |
阶段耗时分解(估)
| 阶段 | 占比 | 说明 |
|---|---|---|
| Routing | ~2% | 纯 element-wise,快 |
| Gather | ~3% | 内存拷贝 |
| GEMM1 (W13) | ~42% | 计算密集 |
| Activation | ~1% | fused 在 GEMM1 后 |
| GEMM2 (W2) | ~48% | 计算密集 |
| Scatter | ~4% | 加权求和 + 写回 |
**GEMM 占 90%**,所以 TMA + Grouped GEMM 对总体性能影响最大。
为什么 conc=2 差距最大?
- Expert-First + Grouped GEMM:一次 launch 处理所有 expert,TMA 预取下一个 tile
- 中等 batch:每个 expert 分到 2-3 个 token → GEMM tile 够大,WGMMA 饱和
- Warp Specialization:Producer 搬数据,Consumer 计算,完美 overlap
Marlin 的劣势:
- Token-First:每个 token 的 topk expert 不同 → 权重反复换入换出
- 无 TMA:用
cp.async加载,线程要参与地址计算 → SM 利用率 < 50% - 无法全 fuse:中间结果 (gate, up, mid) 要写回 GMEM
为什么 conc=4 差距消失?
1 | conc=4: prefill=50k tokens |
Prefill 主导后,MoE decode 的优化被稀释。
与 Marlin MoE 的架构对比
核心差异
| 维度 | FlashInfer MXFP4 | Marlin (SGLang) |
|---|---|---|
| 量化格式 | MXFP4 (浮点4位) | INT4 / MXFP4 |
| 遍历顺序 | Expert-First | Token-First |
| 融合程度 | W1+W3+SwiGLU+W2 全 fuse | W1+W3 部分 fuse,W2 单独 |
| 数据加载 | TMA (硬件 DMA) | cp.async / LDG (线程参与) |
| 调度方式 | Warp Specialization (Producer/Consumer) | 单 warp 既搬又算 |
| Kernel 架构 | CUTLASS 3.x Grouped GEMM | 手写 CUDA,Ampere 设计 |
| TMA 利用 | ✅ 异步 tile 预取 | ❌ 无 TMA (用 cp.async) |
| 多 expert 并行 | Grouped GEMM 一次 launch | 逐 expert 串行或小并行 |
Expert-First vs Token-First
**Token-First (Marlin)**:
1 | for token_0: |
**Expert-First (FlashInfer)**:
1 | for expert_3: |
为什么 Expert-First 能全 fuse?
- 固定权重 → W1/W3/W2 常驻 SMEM,不换出
- 变 batch → 同一 expert 的多个 token 连续处理
- 中间结果 (gate, up, mid) 全在寄存器/SMEM,不写 GMEM
TMA + WGMMA 的硬件优势
TMA (Tensor Memory Accelerator)
SM90 (Hopper) 引入的硬件 DMA 引擎:
- 异步搬数据从 HBM 到 SMEM,不占用 CUDA core
- 一条
tcgen05.1d指令触发,后台独立运行 - 支持多维 tensor 描述符(TMA Descriptor)
vs 传统 LDG:
1 | LDG (传统): |
WGMMA (Warpgroup MMA)
SM90 引入的 warpgroup 级矩阵乘指令:
- 一个 warpgroup = 4 个 warp = 128 线程
- 直接读写 TMEM(Tensor Memory,SM90 新增寄存器文件)
- 比传统
mma.sync高一个抽象层级
SM100 (Blackwell) 的进化:
1 | SM90: TMA → SMEM → WGMMA → TMEM |
Blackwell 的 TMA_GATHER4 还能一次加载 4 个不连续 token(专为 MoE 的 sparse routing 优化)。
总结
FlashInfer MXFP4 的核心优势
- Expert-First + 全 Fuse:中间结果不落地 GMEM,省 ~2× hidden_dim × batch 带宽
- TMA + Warp Specialization:搬运和计算 overlap,SM 利用率最大化
- Grouped GEMM:一次 launch 处理 256 experts,launch overhead 最小化
- Blackwell 未来兼容:cute_dsl_fused_moe_nvfp4 已支持 SM100 原生 FP4 TC
适用场景
| 场景 | 推荐后端 |
|---|---|
| H20/H100, conc=2~8 | ✅ FlashInfer MXFP4 |
| A100, 或 conc=1 调试 | Marlin |
| Blackwell (B200) | FlashInfer NVFP4 (cute_dsl) |
| 追求最大吞吐 | FlashInfer MXFP4 |
一句话
FlashInfer MXFP4 在 Hopper 上通过 Expert-First + TMA + Warp Specialization + 全 Fuse,把 MoE 推理的瓶颈从内存带宽转移到了计算,实现了中等并发下 50% 的吞吐提升。
参考资料
- FlashInfer GitHub
- CUTLASS 3.x Documentation
- Introducing Machete - Red Hat Developer
- DeepSeek-V4 源码分析(本文基于 DeepSeek-V4-Flash 实测)
- NVIDIA Hopper Architecture White Paper
如果你对 MoE 推理优化感兴趣,可以看看我之前写的 Marlin MoE Kernel 深度分析,对比两种实现的差异。