Marlin MoE Kernel 深度分析

引言

DeepSeek-V4 引入 MXFP4 量化后,MoE 层的计算效率成为推理性能的关键瓶颈。SGLang 的 Marlin Runner Backend 专门针对 INT4/MXFP4 量化权重优化 MoE 的 GEMM 计算。本文深入分析其实现原理、数据流以及设计权衡。

MoE 层的计算流程

一个标准的 MoE 层包含四个阶段:

1
2
3
4
5
6
7
8
9
10
11
12
13
MoE Layer

├── ① Router → topk_ids, topk_weights

├── ② Token Dispatch (All-to-All) ← A2A backend

├── ③ Expert Compute ← Runner Backend (本文主角)
│ ├── W1 GEMM (gate + up 融合)
│ ├── SwiGLU 激活
│ ├── W2 GEMM (down projection)
│ └── Weighted sum reduce

└── ④ Token Combine ← A2A backend

Runner Backend 只负责第 ③ 步,不同 backend 优化的是同一组计算在不同量化格式下的执行效率:

Runner Backend 权重格式 核心 kernel 适用场景
triton FP8/BF16 Triton fused MoE 通用 FP8
deep_gemm FP8 block DeepGemm DeepSeek-V3 FP8
marlin INT4/MXFP4 Marlin WNA16 GPTQ/AWQ/MXFP4 量化
flashinfer_mxfp4 MXFP4 FlashInfer MXFP4 DeepSeek-V4 MXFP4

Marlin MoE 完整数据流

以 DeepSeek-V4 配置为例:hidden_size=7168, intermediate_size=3072, num_experts=384, topk=6, 权重为 MXFP4 格式。

Step 1: Block Size 启发式选择

1
2
3
4
5
6
7
8
9
M, K = hidden_states.shape  # M=tokens, K=7168
E = w1.shape[0] # num_experts=384
N = w2.shape[1] * 16 # 3072 (Marlin 打包因子 16)
topk = topk_ids.shape[1] # 6

# 启发式选择 M 方向分块大小
for block_size_m in [8, 16, 32, 48, 64]:
if M * topk / E / block_size_m < 0.9:
break

设计意图

  • Context 阶段(M 大):用大 block(64),让每个 block 尽量填满
  • Generation 阶段(M 小):用小 block(8),避免最后一个 block 填充率过低

潜在问题:假设 token 在专家间均匀分布,但 Router 可能有热点。热点专家需要多个 block,最后一个 block 可能只填了 10%,浪费计算。

Step 2: Token-Expert 对齐

1
2
3
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, block_size_m, global_num_experts
)

作用

  1. expert_id 对 token 排序
  2. block_size_m 对齐(padding),让每个专家的 token 数是 block 的倍数
  3. 返回排序后的 token 索引和每个 block 对应的 expert_id

目的:让后续 Marlin GEMM kernel 能用 block-sparse 方式高效执行。

Step 3: W1 GEMM (gate + up 融合)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
intermediate_cache1 = moe_wna16_marlin_gemm(
hidden_states, # [M, 7168]
intermediate_cache1, # output buffer
w1, # [E, 7168/pack, 2*3072*pack]
w1_scale, # 量化 scale
sorted_token_ids, # token → expert 映射
expert_ids, # 每个 block 的 expert_id
num_tokens_post_padded, # padding 后的总 token 数
topk_weights, # 路由权重
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False, # ← W1 阶段不乘路由权重
size_m=M, size_n=2*N, size_k=K,
)

关键点

  • mul_topk_weights=False:W1 输出是纯 GEMM 结果,不乘权重
  • W1 权重是 gate 和 up 融合 的:[E, K/pack, 2*N*pack]
  • 输出 [M*topk, 2*N]:前半 N 是 gate,后半 N 是 up

为什么 W1 不乘权重?

假设 token 的 top-2 是 Expert 3 (weight=0.6) 和 Expert 7 (weight=0.4):

1
2
3
4
5
6
7
8
W1 输出(不乘权重):
[Expert 3 的 gate/up, Expert 7 的 gate/up]

经过 SwiGLU:
[Expert 3 的 activated, Expert 7 的 activated]

W2 输出(乘权重):
Expert 3 输出 × 0.6 + Expert 7 输出 × 0.4

如果 W1 就乘了权重,SwiGLU 激活函数作用在”已经被压缩的信号”上,精度会下降。

Step 4: SwiGLU + Clamp

1
2
3
4
5
if clamp_limit is not None:
# DeepSeek-V4: swiglu_limit=10.0
swiglu_limit_func(intermediate_cache2, intermediate_cache1, clamp_limit)
else:
silu_and_mul(intermediate_cache1, intermediate_cache2)

输入 [M*topk, 2*N] 拆成 gate 和 up:

1
2
3
4
5
gate = input[:, :N]      # SwiGLU gate branch
up = input[:, N:] # SwiGLU up branch

# SiLU(clamp(gate)) * clamp(up)
output = F.silu(torch.clamp(gate, max=10)) * torch.clamp(up, -10, 10)

Clamp 的作用:防止激活值爆炸,避免量化误差被放大。DeepSeek-V4 的 swiglu_limit=10.0 是经验值。

Step 5: W2 GEMM (down projection)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
intermediate_cache3 = moe_wna16_marlin_gemm(
intermediate_cache2, # [M*topk, 3072]
intermediate_cache3, # output buffer
w2, # [E, 3072/pack, 7168*pack]
w2_scale,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=1, # ← 注意这里是 1
mul_topk_weights=True, # ← W2 阶段乘路由权重
size_m=M*topk, size_n=K, size_k=N,
).view(-1, topk, K) # reshape 为 [M, topk, K]

关键变化

  • top_k=1:W2 的输入已经是展开的 M×topk 个 token,每个只对应 1 个专家
  • mul_topk_weights=True:在 GEMM 内部就乘上路由权重,融合减少一次 memory pass

Step 6: 跨专家归约

1
2
3
4
5
6
if is_mxfp4_marlin:
# MXFP4:用 torch.sum(atomic_add 不支持 BF16 on SM < 90)
output = torch.sum(intermediate_cache3, dim=1) # [M, topk, K] → [M, K]
else:
# 普通 INT4:用 CUDA kernel 做 sum + scale
moe_sum_reduce(intermediate_cache3, output, routed_scaling_factor)

routed_scaling_factor=2.5(DeepSeek-V4)在这里乘进去。

MXFP4 特殊处理

DeepSeek-V4 的专家权重是 MXFP4 格式(4-bit 分块量化),Marlin kernel 对此有特殊路径:

1
2
3
4
5
6
7
is_mxfp4_marlin = (
num_bits == 4
and w1_zeros is None # 无 zero point
and w2_zeros is None
and w1_scale.dtype == torch.float8_e8m0 # scale 是 E8M0 格式
and w2_scale.dtype == torch.float8_e8m0
)

MXFP4 + E8M0 scale 要求激活必须是 BF16(不支持 FP16):

1
2
if is_mxfp4_marlin and hidden_states.dtype == torch.float16:
marlin_hidden_states = hidden_states.to(torch.bfloat16) # 强制转 BF16

MXFP4 特殊处理

  • 不用 atomic_add(因为 BF16 的 atomic_add 在 SM < 90 上不支持)
  • 改用 torch.sum 替代归约
  • Scale 是 E8M0 格式(8-bit exponent-only),和普通 INT4 的 per-tensor/per-channel scale 不同

内存布局优化

1
2
3
4
5
6
7
# 两个中间 buffer 共享底层存储
intermediate_cache13 = torch.empty(
(M * topk * max(2*N, K),), # 按 W1 和 W2 输出中较大的分配
device=device, dtype=dtype,
)
intermediate_cache1 = intermediate_cache13[:M*topk*2*N].view(-1, 2*N)
intermediate_cache3 = intermediate_cache13[:M*topk*K].view(-1, K)

节省显存:W1 的输出 [M*topk, 2*N] 在被 SwiGLU 消费后就不需要了,W2 可以写入同一区域。

但有个细节:W2 GEMM 的输入是 intermediate_cache2(SwiGLU 输出,维度 N),不是 intermediate_cache1(维度 2N)。所以实际复用链是:

1
2
3
intermediate_cache1 [M*topk, 2*N] → (SwiGLU) → intermediate_cache2 [M*topk, N]

intermediate_cache3 [M*topk, K] ← (W2 GEMM 输出)

intermediate_cache2 如果能复用 intermediate_cache1 的后半部分(up 分支在 SwiGLU 后就没用了),可以进一步节省显存。

Marlin 在 SGLang 中的注册

1
2
3
@register_fused_func("none", "marlin")
def fused_experts_none_to_marlin(dispatch_output, quant_info, runner_config):
...

“none” 是 A2A backend(无 EP,纯 TP),**”marlin”** 是 runner backend。

这意味着 Marlin MoE 目前只支持纯 TP 模式(每张卡有所有专家的 1/N 权重),不支持 EP(Expert Parallelism)。

为什么 Marlin 没做 EP?

  1. MXFP4 权重大小不是瓶颈

    • 384 专家 × 33MB/8 = 1.6GB/卡(TP=8)
    • 如果 EP=8,每卡 48 专家 × 33MB = 1.6GB/卡
    • 权重大小一样,EP 的优势(减少单卡显存)消失
  2. 通信量对比

    • TP 的 AllReduce:2 × B × hidden(W1 + W2 各一次)
    • EP 的 All-to-All:topk × B × hidden = 6 × B × hidden(topk=6)
    • EP 通信量是 TP 的 3 倍
  3. Marlin 是量化 kernel,EP 是通信问题

    • 理论上可以接 deepep,但 MXFP4 + 高 topk 场景下收益不大
    • SGLang 团队可能评估后觉得优先级不高

完整数据流图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
hidden_states [M, 7168] (BF16)

▼ moe_align_block_size(topk_ids)
│ → sorted_token_ids, expert_ids (按专家排序+对齐)

▼ moe_wna16_marlin_gemm (W1, gate+up fused)
│ 每个 block 读对应 expert 的 MXFP4 权重 → 反量化 → GEMM
│ [M, 7168] × [E, 7168/pack, 2*3072*pack] → [M*6, 6144]
│ mul_topk_weights=False (不乘权重)

▼ SwiGLU + clamp (swiglu_limit=10.0)
│ gate = clamp([:, :3072], max=10)
│ up = clamp([:, 3072:], -10, 10)
│ → SiLU(gate) * up → [M*6, 3072]

▼ moe_wna16_marlin_gemm (W2, down)
│ [M*6, 3072] × [E, 3072/pack, 7168*pack] → [M*6, 7168]
│ mul_topk_weights=True (内部乘路由权重)
│ → reshape [M, 6, 7168]

▼ torch.sum(dim=1) 或 moe_sum_reduce
│ [M, 6, 7168] → [M, 7168] (× routed_scaling_factor=2.5)

▼ output [M, 7168] (BF16)

总结

方面 评价 建议
Block size 选择 简单启发式,可能不够鲁棒 考虑按专家负载动态调整
MXFP4 处理 强制转 BF16 有额外开销 统一模型 dtype,避免 forward 时转换
SwiGLU clamp 好的实践,防止激活爆炸 验证量化范围是否覆盖 [-10, 10]
内存复用 做得好 intermediate_cache2 也可复用
mul_topk_weights 设计合理,W2 融合权重乘法
EP 支持 仅 TP,未实现 EP MXFP4 + 高 topk 下收益不大,暂不急需

Marlin MoE 是 MXFP4 量化 MoE 的高效实现,通过 block-sparse GEMM、SwiGLU 融合、内存复用等手段优化计算。在 DeepSeek-V4 的 MXFP4 + 高 topk 场景下,选择纯 TP 而非 EP 是理性的设计决策。

参考资料