Marlin MoE Kernel 深度分析
引言
DeepSeek-V4 引入 MXFP4 量化后,MoE 层的计算效率成为推理性能的关键瓶颈。SGLang 的 Marlin Runner Backend 专门针对 INT4/MXFP4 量化权重优化 MoE 的 GEMM 计算。本文深入分析其实现原理、数据流以及设计权衡。
MoE 层的计算流程
一个标准的 MoE 层包含四个阶段:
1 | MoE Layer |
Runner Backend 只负责第 ③ 步,不同 backend 优化的是同一组计算在不同量化格式下的执行效率:
| Runner Backend | 权重格式 | 核心 kernel | 适用场景 |
|---|---|---|---|
| triton | FP8/BF16 | Triton fused MoE | 通用 FP8 |
| deep_gemm | FP8 block | DeepGemm | DeepSeek-V3 FP8 |
| marlin | INT4/MXFP4 | Marlin WNA16 | GPTQ/AWQ/MXFP4 量化 |
| flashinfer_mxfp4 | MXFP4 | FlashInfer MXFP4 | DeepSeek-V4 MXFP4 |
Marlin MoE 完整数据流
以 DeepSeek-V4 配置为例:hidden_size=7168, intermediate_size=3072, num_experts=384, topk=6, 权重为 MXFP4 格式。
Step 1: Block Size 启发式选择
1 | M, K = hidden_states.shape # M=tokens, K=7168 |
设计意图:
- Context 阶段(M 大):用大 block(64),让每个 block 尽量填满
- Generation 阶段(M 小):用小 block(8),避免最后一个 block 填充率过低
潜在问题:假设 token 在专家间均匀分布,但 Router 可能有热点。热点专家需要多个 block,最后一个 block 可能只填了 10%,浪费计算。
Step 2: Token-Expert 对齐
1 | sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( |
作用:
- 按
expert_id对 token 排序 - 按
block_size_m对齐(padding),让每个专家的 token 数是 block 的倍数 - 返回排序后的 token 索引和每个 block 对应的 expert_id
目的:让后续 Marlin GEMM kernel 能用 block-sparse 方式高效执行。
Step 3: W1 GEMM (gate + up 融合)
1 | intermediate_cache1 = moe_wna16_marlin_gemm( |
关键点:
mul_topk_weights=False:W1 输出是纯 GEMM 结果,不乘权重- W1 权重是 gate 和 up 融合 的:
[E, K/pack, 2*N*pack] - 输出
[M*topk, 2*N]:前半 N 是 gate,后半 N 是 up
为什么 W1 不乘权重?
假设 token 的 top-2 是 Expert 3 (weight=0.6) 和 Expert 7 (weight=0.4):
1 | W1 输出(不乘权重): |
如果 W1 就乘了权重,SwiGLU 激活函数作用在”已经被压缩的信号”上,精度会下降。
Step 4: SwiGLU + Clamp
1 | if clamp_limit is not None: |
输入 [M*topk, 2*N] 拆成 gate 和 up:
1 | gate = input[:, :N] # SwiGLU gate branch |
Clamp 的作用:防止激活值爆炸,避免量化误差被放大。DeepSeek-V4 的 swiglu_limit=10.0 是经验值。
Step 5: W2 GEMM (down projection)
1 | intermediate_cache3 = moe_wna16_marlin_gemm( |
关键变化:
top_k=1:W2 的输入已经是展开的 M×topk 个 token,每个只对应 1 个专家mul_topk_weights=True:在 GEMM 内部就乘上路由权重,融合减少一次 memory pass
Step 6: 跨专家归约
1 | if is_mxfp4_marlin: |
routed_scaling_factor=2.5(DeepSeek-V4)在这里乘进去。
MXFP4 特殊处理
DeepSeek-V4 的专家权重是 MXFP4 格式(4-bit 分块量化),Marlin kernel 对此有特殊路径:
1 | is_mxfp4_marlin = ( |
MXFP4 + E8M0 scale 要求激活必须是 BF16(不支持 FP16):
1 | if is_mxfp4_marlin and hidden_states.dtype == torch.float16: |
MXFP4 特殊处理:
- 不用
atomic_add(因为 BF16 的 atomic_add 在 SM < 90 上不支持) - 改用
torch.sum替代归约 - Scale 是 E8M0 格式(8-bit exponent-only),和普通 INT4 的 per-tensor/per-channel scale 不同
内存布局优化
1 | # 两个中间 buffer 共享底层存储 |
节省显存:W1 的输出 [M*topk, 2*N] 在被 SwiGLU 消费后就不需要了,W2 可以写入同一区域。
但有个细节:W2 GEMM 的输入是 intermediate_cache2(SwiGLU 输出,维度 N),不是 intermediate_cache1(维度 2N)。所以实际复用链是:
1 | intermediate_cache1 [M*topk, 2*N] → (SwiGLU) → intermediate_cache2 [M*topk, N] |
intermediate_cache2 如果能复用 intermediate_cache1 的后半部分(up 分支在 SwiGLU 后就没用了),可以进一步节省显存。
Marlin 在 SGLang 中的注册
1 |
|
“none” 是 A2A backend(无 EP,纯 TP),**”marlin”** 是 runner backend。
这意味着 Marlin MoE 目前只支持纯 TP 模式(每张卡有所有专家的 1/N 权重),不支持 EP(Expert Parallelism)。
为什么 Marlin 没做 EP?
MXFP4 权重大小不是瓶颈
- 384 专家 × 33MB/8 = 1.6GB/卡(TP=8)
- 如果 EP=8,每卡 48 专家 × 33MB = 1.6GB/卡
- 权重大小一样,EP 的优势(减少单卡显存)消失
通信量对比
- TP 的 AllReduce:
2 × B × hidden(W1 + W2 各一次) - EP 的 All-to-All:
topk × B × hidden=6 × B × hidden(topk=6) - EP 通信量是 TP 的 3 倍
- TP 的 AllReduce:
Marlin 是量化 kernel,EP 是通信问题
- 理论上可以接
deepep,但 MXFP4 + 高 topk 场景下收益不大 - SGLang 团队可能评估后觉得优先级不高
- 理论上可以接
完整数据流图
1 | hidden_states [M, 7168] (BF16) |
总结
| 方面 | 评价 | 建议 |
|---|---|---|
| Block size 选择 | 简单启发式,可能不够鲁棒 | 考虑按专家负载动态调整 |
| MXFP4 处理 | 强制转 BF16 有额外开销 | 统一模型 dtype,避免 forward 时转换 |
| SwiGLU clamp | 好的实践,防止激活爆炸 | 验证量化范围是否覆盖 [-10, 10] |
| 内存复用 | 做得好 | intermediate_cache2 也可复用 |
mul_topk_weights |
设计合理,W2 融合权重乘法 | 无 |
| EP 支持 | 仅 TP,未实现 EP | MXFP4 + 高 topk 下收益不大,暂不急需 |
Marlin MoE 是 MXFP4 量化 MoE 的高效实现,通过 block-sparse GEMM、SwiGLU 融合、内存复用等手段优化计算。在 DeepSeek-V4 的 MXFP4 + 高 topk 场景下,选择纯 TP 而非 EP 是理性的设计决策。
参考资料
- SGLang Marlin Runner 源码
- SGLang MoE 文档
- DeepSeek-V4 技术报告