MegaMoE:把通信和计算融合到一个 Mega Kernel

一句话概括

MegaMoE = 把 EP(Expert Parallelism)中的 All-to-All 通信和 MoE 计算融合到一个 kernel 里,让 NVLink 传输和 Tensor Core 计算时间上完全重叠。

传统 EP 里通信和计算串行,GPU 利用率只有 50-60%;MegaMoE 通过 Symmetric Memory + 细粒度 Scheduler + 单 Kernel 状态机,把利用率推到接近 100%。


传统 EP MoE vs MegaMoE

传统 EP(如 DeepEP)

1
2
3
4
时间 →

[dispatch all-to-all] → 等完 → [GEMM1] → [SwiGLU] → [GEMM2] → 等完 → [combine all-to-all]
NVLink 通信 空闲 计算 计算 计算 空闲 NVLink 通信

问题:通信和计算串行,GPU 要么在算要么在传,利用率约 50-60%。

MegaMoE

1
2
3
4
5
6
时间 →

┌───────────────────── 一个 Mega Kernel ─────────────────────┐
│ NVLink dispatch ←→ GEMM1 ←→ SwiGLU ←→ GEMM2 ←→ NVLink combine │
│ (通信和计算同时进行,流水线式重叠) │
└──────────────────────────────────────────────────────────────┘

通信隐藏在计算背后,GPU 利用率接近 100%。


核心实现机制

1. Symmetric Memory(对称内存)

传统 all-to-all:

1
GPU_0 → cudaMemcpyAsync → GPU_1(需要显式同步,退出 kernel)

Symmetric Memory:

1
2
3
所有 GPU 共享一片 symmetric buffer
GPU_0 直接写入 GPU_1 的 symmetric buffer(RDMA over NVLink)
GPU_1 看到数据就开始算,不需要全局 barrier

关键:不需要等所有 token 传完再开始算。传一批就算一批(streaming)。

实现层级

  • 硬件层:NVSwitch 全连接,任何 GPU pair 都有直达链路,memory controller 识别远端地址自动走 NVLink
  • 驱动/Runtime 层:CUDA VMM 把远端 GPU 物理内存映射到本地虚拟地址(cuMemMap / nvshmem_ptr
  • 应用层:MegaMoE 直接用普通指针读写远端 buffer,一条 PTX load/store 指令完成

注意:只有 symmetric allocation 的那块 buffer 地址一致,不是整个 GPU 显存对称。

2. 单 Kernel 状态机

传统方式多个 kernel 之间有全局同步点,无法实现 fine-grained overlap:

1
kernel_1 (dispatch) → kernel 边界(全局 barrier)→ kernel_2 (GEMM1) → ...

MegaMoE 的 scheduler/mega_moe.cuh 用一个状态机让所有 SM 在同一个 kernel 内自主领任务:

1
2
3
4
5
6
7
8
9
10
enum class BlockPhase { None, Linear1, Linear2 };

while (true) {
auto block = scheduler.get_next_block();
switch (block.phase) {
case Linear1: /* 做 W1/W3 GEMM */ break;
case Linear2: /* 做 W2 GEMM */ break;
case None: return;
}
}

没有 kernel 边界,没有全局 barrier,通信和计算可以交错到 warp 级别。

3. Wave-Based Expert Processing

256 个 expert 太多,无法同时处理(Token Pool 显存有限),分 wave 分批:

1
2
3
4
Wave 0: expert [0..31]   → 处理完 → 释放 pool
Wave 1: expert [32..63] → 处理完 → 释放 pool
...
Wave 7: expert [224..255]

每个 wave 内:

  • 所有 SM 抢 L1 block(gate + up projection)
  • SwiGLU 在 SMEM 中直接做激活
  • 所有 SM 抢 L2 block(down projection)

kNumExpertsPerWave 由 SMEM 容量、BLOCK_M、pool_capacity 共同决定,让 wave 内 expert 数刚好喂饱所有 SM 又不撑爆 Token Pool。

4. Arrival Count 轮询(细粒度 Overlap 的关键)

1
2
3
4
// 轮询等待:不是等所有 token 到齐,而是等够一个 BLOCK_M 就开始算
while (volatile_count < kNumSMs * kNumRanks) {
// spin-wait 直到其他 SM/rank 报告完成
}
  • 不需要全局 barrier(不需要所有 token 都收完才开始)
  • 收到一个 block 的 token 就立刻开始算
  • 这就是 streaming overlap 的本质

5. L1/L2 交错

1
2
Expert_0: [L1 block 0][L1 block 1][L1 block 2] → SwiGLU → [L2 block 0][L2 block 1]
Expert_1: [L1 block 0][L1 block 1] → SwiGLU → [L2 block 0]

L1 结果留在 Token Pool 中,L2 直接从 pool 读(在 L2 cache 中热),一个 wave 结束后才释放 pool 空间。

对应关系:

MegaMoE 术语 权重矩阵 操作 含义
L1 (Linear 1) W1 + W3(拼在一起) gate + up projection hidden → 2×intermediate
SwiGLU 无权重 SiLU(gate) × up 激活函数
L2 (Linear 2) W2 down projection intermediate → hidden

W1 和 W3 拼在一起做一次 GEMM,可以复用 TMA 加载的 activation tile,减少 GMEM → SMEM 搬运。


SM/Warp 级别并行

GPU 硬件层次:

1
2
3
4
GPU (整块芯片)
└── SM (Streaming Multiprocessor) × 132 (H20) / 192 (B200)
└── Warp × 最多 64 resident / SM
└── Thread × 32 / Warp(SIMT 锁步执行)

在 MegaMoE 中,一个 Mega Kernel launch 占满所有 SM,每个 SM 内部:

1
2
Warp Group A(若干 warp):负责通信(polling + NVLink store/load)
Warp Group B(若干 warp):负责计算(Tensor Core MMA)

Warp Group A 和 B 在 SM 内部并行执行——这就是”通信和计算在 warp 级别交错”的含义。


为什么必须 GB200/GB300

MegaMoE 依赖三个 GB200/GB300 独有(或大幅增强)的硬件能力:

硬件特性 GB200/GB300 H20 影响
GPU 间拓扑 72 GPU NVSwitch 全连接 8 GPU NVLink mesh H20 无法 symmetric memory 跨机
Symmetric Memory ✅ kernel 内 load/store 远端 ❌ 需要 NCCL API 无法做 in-kernel 通信
NVSwitch Reduction ✅ 硬件完成 combine ❌ SM 做 reduce combine 占 SM 资源
FP4 Tensor Core ✅ 原生指令(SM100) ❌ 软件模拟(SM90) 计算慢 → 通信占比低 → overlap 收益有限
NVLink 带宽/GPU 900 GB/s(18 ports) 450 GB/s(NV18 mesh) H20 通信带宽是瓶颈

为什么计算足够快才值得 overlap

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
假设: 通信时间 = T_comm, 计算时间 = T_compute

情况 A — 计算快 (FP4 Tensor Core):
T_compute = 2ms, T_comm = 5ms
不 overlap: 7ms
overlap: max(2, 5) = 5ms ← 省 28%

情况 B — 计算慢 (BF16 软件模拟):
T_compute = 20ms, T_comm = 5ms
不 overlap: 25ms
overlap: max(20, 5) = 20ms ← 只省 20%

而且 overlap 不是免费的:
- 需要 SM 专门做通信 warp(减少计算并行度)
- Symmetric Memory 的 load/store 比本地慢 10-20x
- Scheduler 状态机有开销

只有当计算非常快、通信成为瓶颈时,overlap 的净收益才大于开销。

和 DeepEP / flashinfer_mxfp4 的区别

MegaMoE vs DeepEP

DeepEP MegaMoE
通信粒度 整个 batch dispatch 完再算 tile 级别通信和计算交错
overlap 方式 不同 CUDA stream 异步(kernel 间) 同一个 kernel 内 tile 级 pipeline
权重格式 任意 Runner backend 都能接 只支持 DeepGEMM 格式(权重需预转换)
内存模型 普通 GPU buffer NVSHMEM 对称内存 (SymmBuffer)
batch 上限 无(取决于显存) 有(NUM_MAX_TOKENS_PER_RANK

MegaMoE vs flashinfer_mxfp4

flashinfer_mxfp4 MegaMoE
解决的问题 单卡 MoE 计算效率 多卡 EP 通信+计算 overlap
Fuse 范围 GEMM1 + SwiGLU + GEMM2 dispatch + GEMM1 + SwiGLU + GEMM2 + combine
通信 不涉及(单卡或假设已 gather 好) 融合 All-to-All 通信
适用场景 TP 模式(experts 复制在每卡) EP 模式(experts 分布在不同卡)
硬件要求 任何 Hopper/Blackwell 需要 NVLink + Symmetric Memory(GB200+)
状态 已发布,可用 开发中(DeepGEMM PR #304)

两者互补:flashinfer_mxfp4 管单卡 MoE kernel 效率,MegaMoE 管 EP 多卡通信+计算融合。


Python API 调用流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 1. 多进程初始化 + 分配对称内存
buffer = deep_gemm.get_symm_buffer_for_mega_moe(
group, # 进程组 (torch.distributed)
num_experts=256,
num_max_tokens_per_rank=8192,
num_topk=6,
hidden=4096,
intermediate_hidden=2048
)

# 2. 权重预处理(FP4 packing + 布局变换)
l1_transformed, l2_transformed = deep_gemm.transform_weights_for_mega_moe(
l1_weights, # W1, W3 (gate + up)
l2_weights # W2 (down)
)

# 3. 加载输入到 symmetric buffer
buffer.x[:N].copy_(x_fp8) # FP8 激活
buffer.x_sf[:N].copy_(x_sf) # 激活的 scale factor
buffer.topk_idx[:N].copy_(topk_idx) # 路由结果
buffer.topk_weights[:N].copy_(topk_weights)

# 4. 一次调用,完成 dispatch + GEMM1 + SwiGLU + GEMM2 + combine
y = torch.empty((N, 4096), dtype=torch.bfloat16, device='cuda')
deep_gemm.fp8_fp4_mega_moe(y, l1_transformed, l2_transformed, buffer)
# y 就是最终结果,全程只有这一个 kernel launch

总结:MegaMoE 的关键创新

机制 作用 依赖硬件
Symmetric Memory kernel 内直接 load/store 远端 GPU NVSwitch 全连接
Arrival Count 轮询 收到一个 block 就开始算,不等全部 低延迟 NVLink
Wave-based Scheduler SM 自主领任务,无全局 barrier SM 数量充足(Blackwell 192 SM)
L1/L2 交错 L1 结果留在 pool,L2 直接读 大 L2 cache(Blackwell 100MB+)
FP8×FP4 GEMM 计算足够快,才值得 overlap 通信 原生 FP4 Tensor Core(SM100)

一句话:MegaMoE = Symmetric Memory(消除通信 barrier)+ 细粒度 Scheduler(一个 block 到了就算)+ 单 Kernel 状态机(避免 kernel launch 开销),把 EP MoE 的 GPU 利用率从 ~60% 推到接近 100%。


局限性与适用边界

  • 权重格式锁定:只支持 DeepGEMM 的 FP4 权重布局,需要预转换
  • 激活函数限制:依赖 SwiGLU 无状态特性,L1 结果可直接给 L2 用;若 MoE 层有跨 token 依赖(如全局 norm),pipeline 会断
  • 硬件绑定:Symmetric Memory + NVSwitch 全连接 + 原生 FP4 三个条件缺一不可,H20 及更早硬件无法使用
  • 跨机不支持:Symmetric Memory 只在单机 NVSwitch 域内有效,跨机仍需回退到 DeepEP
  • 负载不均影响fetch_expert_recv_count() 的 spin-wait 在 expert 负载不均时可能导致 SM 空转

参考