SGLang Pipeline Parallelism 深度分析
Pipeline Parallelism(PP)是多卡推理中的核心技术之一。SGLang 的 PP 实现有一套独立的事件循环和调度机制,和普通的 single-batch 路径完全不同。本文将从代码出发,深入分析 SGLang 的 PP 实现。
为什么需要 Pipeline Parallelism
单机推理的瓶颈
单卡推理时,GPU 利用率受限于模型大小和显存。一个 70B 模型单卡放不下,必须拆分到多卡。
模型并行方案对比
| 方案 | 原理 | 优点 | 缺点 |
|---|---|---|---|
| TP (Tensor Parallelism) | 按层切分权重 | 通信少,延迟低 | 受限于单节点卡数 |
| PP (Pipeline Parallelism) | 按层切分模型 | 支持任意层数 | 存在 pipeline bubble |
| DP (Data Parallelism) | 每卡完整模型 | 简单,吞吐高 | 显存需求不变 |
SGLang 通常 TP + PP 组合使用:节点内 TP,跨节点 PP。
SGLang PP 架构概览
PP 路径的独立事件循环
SGLang 的调度器有三条事件循环路径:
event_loop_normal— 普通单 batch 调度event_loop_overlap— 计算和通信 overlap 的调度event_loop_pp— PP 路径的独立事件循环
PP 路径有自己独立的事件循环,定义在 scheduler_pp_mixin.py 的 SchedulerPPMixin 类中。它和 normal/overlap 路径完全隔离,因为 PP 的通信模式(send/recv proxy tensors 跨 stage)和单机调度逻辑差异太大,硬塞进同一个 event loop 会很脏。
event_loop_pp 核心流程
1 | # scheduler_pp_mixin.py 第 72-145 行 |
关键设计:
- 每个 microbatch 槽位独立管理状态
- async send + recv 实现计算和通信的 overlap
- 主循环轮转处理每个 microbatch,填充 pipeline bubble
Microbatch:填充 Pipeline Bubble
问题:Pipeline Bubble
PP2(2 个 stage)如果只用一个 batch,会出现大量空闲:
1 | 时间 → |
Stage 0 做完 batch 0 后必须等 Stage 1 处理完才能做 batch 1(因为只有一个 batch 在 pipeline 里流转),一半的时间都在空闲。这就是 pipeline bubble。
解法:Microbatch
把多个 batch 同时塞进 pipeline,这些 batch 就叫 microbatch:
1 | 时间 → |
Stage 0 做完 mb0 不用等,立刻做 mb1;Stage 1 也紧接着处理,bubble 大幅减少。
Microbatch 数量的确定
1 | # scheduler_pp_mixin.py 第 514 行 |
pp_size= 2(PP2)→ 至少 2 个 microbatch 槽位pp_async_batch_depth→ 额外的 buffer 深度,进一步隐藏延迟
每个 microbatch 都有独立的状态:
1 | self.mbs = [None] * self.pp_loop_size # 当前 batch |
Microbatch vs 正常 Batch
| 维度 | 正常 batch | Microbatch(PP) |
|---|---|---|
| 数量 | 同时 1 个 | 同时 pp_size + async_depth 个 |
| 目的 | 普通调度 | 填充 pipeline bubble,提升利用率 |
| 状态管理 | 单个 running_batch | 数组 running_mbs[mb_id] 各自独立 |
| 本质 | 就是一个 batch | 也是普通 batch,只是多个在 pipeline 中交替执行 |
Microbatch 和 batch 在数据结构上没有区别(都是 ScheduleBatch),区别只在于 PP 需要多个 batch 同时在不同 stage 中流转,所以给它们编号叫 microbatch。
Microbatch 数量不是越多越好
1 | pp_loop_size = pp_size + pp_async_batch_depth |
- 太少(= pp_size):pipeline 容易断流,一个 stage 稍慢就导致下游空闲
- 太多:显存占用线性增长(每个 microbatch 都要有自己的 KV cache + activation),而且可能因为显存压力导致 batch size 被迫缩小,反而吞吐下降
调优经验:
pp_async_batch_depth = 1~2通常足够,再深收益递减- 显存充裕时可以试更大的 depth,但要注意 activation memory 的开销
- 如果 single batch 的 token 数很大(长上下文),microbatch 多了容易 OOM
PP Rank 的请求接收路径
recv_requests() 根据 PP rank 的不同,有两条完全不同的接收路径:
PP Rank 0(第一个 stage)
从上游组件直接接收用户请求,走两条 zmq 通道:
1 | # 1. 从 tokenizer 接收(zmq 通道) |
接收的是 tokenize 后的请求对象,类型为:
TokenizedGenerateReqInput— 生成请求TokenizedEmbeddingReqInput— embedding 请求- 其他控制类消息(flush、abort 等)
PP Rank > 0(后续 stage)
不直接从 tokenizer 接收,而是从前一个 PP rank P2P 转发过来:
1 | recv_reqs = point_to_point_pyobj( |
对应 event_loop_pp 中的流转
1 | # event_loop_pp 第 80-88 行 |
流转全图
1 | Tokenizer / RPC |
关键设计:所有 PP rank 收到的内容是一样的(同样的请求列表),这样每个 stage 都能为同一批请求做 get_next_batch_to_run() 调度,保证各 stage 的 batch 一致。
process_input_requests:请求路由器
recv_requests() 收到请求后,process_input_requests() 负责按类型分发到对应的 handler:
1 | # scheduler_pp_mixin.py 第 1545-1566 行 |
_request_dispatcher 是一个 TypeBasedDispatcher,根据请求类型路由到不同 handler:
核心推理请求:
| 请求类型 | Handler | 说明 |
|---|---|---|
TokenizedGenerateReqInput |
handle_generate_request |
生成请求 → 加入 waiting_queue |
TokenizedEmbeddingReqInput |
handle_embedding_request |
Embedding 请求 |
BatchTokenized* |
handle_batch_* |
批量请求 |
控制类请求:
| 请求类型 | Handler | 说明 |
|---|---|---|
AbortReq |
abort_request |
中止请求 |
FlushCacheReqInput |
flush_cache_wrapped |
清空 KV cache |
OpenSessionReqInput |
open_session |
开启会话 |
CloseSessionReqInput |
close_session |
关闭会话 |
ProfileReq |
profile |
性能分析 |
SlowDownReqInput |
slow_down |
降速 |
PauseGenerationReqInput |
pause_generation |
暂停生成 |
权重管理请求:
| 请求类型 | Handler | 说明 |
|---|---|---|
UpdateWeightsFrom* |
update_weights_from_* |
更新模型权重(disk/distributed/tensor/IPC) |
LoadLoRAAdapterReqInput |
load_lora_adapter |
加载 LoRA |
UnloadLoRAAdapterReqInput |
unload_lora_adapter |
卸载 LoRA |
一句话总结:process_input_requests 就是一个请求路由器——生成请求进入 waiting_queue 等待后续调度,控制请求立即处理并回送结果,所有 PP rank 都执行同样的分发逻辑保证状态一致。
_pp_recv_proxy_tensors:接收上游 Stage 的中间结果
_pp_recv_proxy_tensors(scheduler_pp_mixin.py 第 993-1004 行)负责从上一个 PP stage 接收 forward 的中间激活值:
1 | def _pp_recv_proxy_tensors(self: Scheduler) -> Optional[PPProxyTensors]: |
行为:
- PP Rank 0(first rank):返回
None,因为它是第一个 stage,没有上一级给它传 hidden states,直接用 embedding 层的输出开始 forward - PP Rank > 0:从前一个 PP rank 接收 tensor dict(
msg_type="proxy"),包装成PPProxyTensors
Proxy tensors 是什么:
看 profile_and_init_predictor 里的构造(第 609-622 行)就清楚了:
1 | proxy_tensors = { |
就是上一个 PP stage forward 输出的中间激活值:
hidden_states— 当前层的隐藏状态residual— 残差连接
在 event_loop_pp 中的位置:
1 | if self.cur_batch: |
对应的发送端:在 event_loop_pp 末尾(第 129-139 行),非 last rank 发送 proxy:
1 | if not self.pp_group.is_last_rank: |
一句话总结:_pp_recv_proxy_tensors 就是非首 rank 从前一个 PP stage 接收 forward 中间结果(hidden_states + residual),这样本 stage 的模型层才能接着算。这是 PP 的核心数据流——每个 stage 只有部分层,需要前一个 stage 的输出作为输入。
_pp_commit_send_output_work_and_preprocess_output_tensors:Output 的收发与预处理
这个函数(scheduler_pp_mixin.py 第 854-873 行)负责处理 forward 结果的 output tensors 的发送、接收和预处理:
1 | def _pp_commit_send_output_work_and_preprocess_output_tensors( |
核心逻辑在 _pp_send_recv_and_preprocess_output_tensors(第 1081-1116 行),做了三件事:
1. 发送 output(last rank → rank 0)
1 | # Last rank: 从 last_rank_comm_queue 取出 forward 结果,发给 rank 0 |
2. 接收 output(rank 0 从 last rank 收)
1 | if mbs[next_mb_id] is not None: |
3. 预处理:GPU→CPU 拷贝(copy stream 上)
1 | with self.copy_stream_ctx: |
Output tensors 里有什么:
看 _pp_prepare_tensor_dict(第 920-933 行):
1 | tensor_dict = { |
这是 last rank 采样后的最终结果,不是中间的 hidden states。
在 event_loop_pp 中的调用位置:
1 | for mb_id in range(self.pp_loop_size): |
pp_async_batch_depth 的影响:
pp_async_batch_depth |
output 处理时机 | 效果 |
|---|---|---|
> 0 |
forward 之前处理上一个 mb 的 output | output 处理和当前 forward 重叠,隐藏延迟 |
== 0 |
forward 之后再处理 | 串行执行,不重叠 |
一句话总结:_pp_commit_send_output_work_and_preprocess_output_tensors 负责把 last rank 的采样结果(next_token_ids、logprobs)从 GPU 传回 rank 0,并在 rank 0 上做 GPU→CPU 拷贝和结果解包。这是 pipeline 的”最后一公里”——把最终输出送回给 tokenizer/rpc 层。
数据流全图:
1 | Last Rank 中间 Rank Rank 0 |
PP 路径中的 HiCache(分层 KV Cache)
HiCache(Hierarchical Cache)是 SGLang 的多级 KV Cache 管理机制,把 KV cache 存在 L3(CPU 内存或磁盘),需要时再加载到 GPU。在 PP scheduler 中,HiCache 不在 event_loop_pp 本身,而是在 event loop 调用的子流程中发挥作用。
HiCache 在 PP 调度中的四个阶段
1. 请求入队时 — 预取(Prefetch)
1 | # process_input_requests → handle_generate_request → _add_request_to_queue → _prefetch_kvcache |
时机:请求刚进入 waiting_queue 时,立即发起异步预取,利用等待调度的时间把 KV cache 从外部存储搬到 GPU。
2. 调度组 batch 时 — 检查预取进度 + 加载
get_next_batch_to_run → _get_new_batch_prefill_raw 中有三处:
1 | # 2a. 检查异步事件完成(第 2282-2283 行) |
3. 请求被 abort 时 — 释放预取资源
1 | # 在多个 abort 路径中(第 1965、1997、3146 行) |
4. flush cache 时 — 等待异步操作完成
1 | # 第 2877 行 |
check_hicache_events():异步事件推进器
check_hicache_events() 是 HiCache 的核心——每轮调度前轮询一次,把已完成的 GPU↔️Host↔️L3 异步拷贝操作确认掉(解锁节点、释放资源、推进状态)。
1 | # tree_cache.py 第 1136-1144 行 |
① writing_check()(第 676-717 行)— GPU→Host 写入
KV cache 从 GPU 写到 Host 内存(备份),是异步的。这个函数检查哪些写入完成了:
1 | # 遍历 ack_write_queue 里的 CUDA event |
作用:GPU 上的 KV cache 备份到 Host 完成后,解锁节点(可以被驱逐释放 GPU 显存),如果开了 L3 存储还会继续往 L3 写。
② loading_check()(第 719-732 行)— Host→GPU 加载
KV cache 从 Host 内存加载回 GPU,也是异步的:
1 | for _, finish_event, ack_list in self.cache_controller.ack_load_queue: |
作用:之前从 Host 加载回 GPU 的 KV cache 完成后,解锁节点,这些 KV cache 就可以被 forward 使用了。
③ drain_storage_control_queues()(第 1146-1172 行)— L3 存储控制
处理三个队列:
1 | cc = self.cache_controller |
| 队列 | 作用 |
|---|---|
prefetch_revoke_queue |
取消不再需要的 L3→Host 预取 |
ack_backup_queue |
确认 Host→L3 备份写入完成 |
host_mem_release_queue |
释放不再需要的 Host 内存 |
HiCache 三级缓存的数据流
1 | check_hicache_events() 推进的异步操作 |
HiCache 数据流
1 | 请求进入 waiting_queue |
一句话总结:HiCache 在 PP 调度中贯穿请求的完整生命周期——入队时异步预取、调度时检查进度、abort 时释放资源、flush 时等待完成。核心思想是利用请求在 waiting_queue 中的等待时间,提前把 KV cache 从 L3 搬到 GPU,隐藏加载延迟。
Host 内存驱逐:LRU 最小堆
当 host 内存不足时,evict_host() 使用 LRU 最小堆来决定驱逐哪些节点。
1 | # 1. 收集候选节点 |
三元组比较:
| 位置 | 内容 | 含义 |
|---|---|---|
| 第 1 个 | get_priority(node) |
LRU 策略下就是 last_access_time,值越小 = 越久没访问 = 越优先驱逐 |
| 第 2 个 | _evict_tie_breaker(node) |
当 priority 相同时的打破平局:PP>1 用内容 hash(跨 rank 一致),PP=1 用 node.id |
| 第 3 个 | node |
实际要驱逐的节点对象 |
PP 场景下的 tie-breaker 设计:
- PP=1:用
node.id作为 tie-breaker,简单高效 - PP>1:用内容 hash(
node.hash_value)作为 tie-breaker,保证不同 rank 对相同节点计算出相同的优先级顺序,确保驱逐一致性
一句话总结:把所有可驱逐的 host 叶子节点放进 LRU 最小堆,循环弹出堆顶(最久没访问的)逐个驱逐,直到释放够 num_tokens 个 token 的 host 内存。tie-breaker 保证 PP 多 rank 场景下驱逐顺序一致。
PP + HiCache 的 Tree Cache 同步问题(PR #22878)
PP + HiCache 组合使用时存在一个经典问题:各 PP rank 的 writing_check() 独立消费 GPU→Host 的 write-ack,由于各 rank 的 batch 内容和异步进度不同,消费数量可能不一致,导致各 rank 的 radix tree 状态出现分歧(哪些节点已备份、哪些可以驱逐不一致)。
解决方案:上游 rank 的消费数作为下游的预算上限,通过 piggyback 在已有的 PP 请求转发中传递。
1. writing_check() 末尾,PP rank 0 记录本轮消费数
1 | # hiradix_cache.py |
2. 通过 recv_reqs 捎带传给下游(scheduler_pp_mixin.py)
1 | # 非 last rank 发送请求时,把 ack_count 捎带上 |
3. 下游 rank 接收并设置预算(scheduler.py)
1 | # recv_requests() 中,PP rank > 0 解包捎带的 ack_count |
4. 下游 rank 的 writing_check 受限于上游预算
1 | # writing_check() 中,下游 rank 不能消费超过上游的数量 |
数据流:
1 | PP Rank 0 PP Rank 1 |
设计亮点:不新增通信通道,而是 piggyback 在已有的 _pp_send_pyobj_to_next_stage 上,仅 +61 行代码解决问题。
一句话总结:把上游 PP rank 的 write-ack 消费数量捎带在 recv_reqs 的 pyobj 通信上传给下游,下游以此为上限限制自己的消费,确保所有 PP rank 的 radix tree 状态变更保持一致,防止缓存驱逐决策不同步。
PP + HiCache 一致性修复进度(Issue #22607):
| 任务 | PR 状态 |
|---|---|
| CP 同步 | PR #20460 |
| 逻辑时钟替代 time.monotonic() | PR #22759(解决出错 3) |
| PD 模块 | PR #22771 |
| Mooncake CP 写控制 + PP storage key 隔离 | 待做 |
| Channel A: Host tree event replay | 待做 |
| Channel B: Write-back count sync | PR #22878 ✅ |
| Channel C: L3 hit delegation | 待做 |
| Zero-hit deferred revoke | 待做 |
| PP prefill 诊断工具 | 待做 |
Channel A:Host Tree Event Replay
Channel B(write-back count sync)只限制 PP1 不超前于 PP0,但不能保证两个 rank 在每一轮都消费相同数量。如果 PP1 的 CUDA event 本地没完成,min(local_ready=0, budget=3) = 0,PP1 这轮仍然消费 0。这意味着 host tree 的状态变更时机在两个 rank 之间可以有偏差,一旦偏差积累,就会导致 host tree 结构不同。
Channel A 的思路:不指望两个 rank 独立做出相同决策,而是让 PP0 做决策,PP1 回放 PP0 的结果。
具体场景:PP2 下的 host tree 分歧
假设 PP2 配置:PP0 管 layer 0-29,PP1 管 layer 30-59,两个 rank 各自独立维护一棵 radix tree。
第 1 步:正常运行,两个 rank 一致
请求 A、B、C 进来,两个 rank 收到相同请求、做相同调度。节点 X(系统提示词的 KV cache)在两个 rank 的 GPU 上都存在:
1 | PP0 radix tree: root → X(value=GPU, host_value=None, hit_count=5) |
第 2 步:hit_count 达到阈值,触发 write_backup
1 | PP0: write_backup(X) → CUDA 异步拷贝启动 → X 加入 ongoing_write_through |
第 3 步:writing_check 完成时间分歧(关键!)
1 | PP0: CUDA event 已完成 → finish_count=1 → X.lock_ref=0 |
Channel B 此时 PP0 把 consumed_count=1 传给 PP1,PP1 的 budget 变成 1。但 PP1 本地 finish_count=0,min(0, 1)=0,这轮 PP1 仍然消费 0。Channel B 只保证 PP1 不会超过 PP0,不能强迫 PP1 追上 PP0。
第 4 步:GPU 内存压力 → PP0 驱逐 X
1 | PP0: X 在 evictable_leaves 中(lock_ref=0, backuped=True) |
第 5 步:PP0 的 Host 内存也满了 → evict_host 删除 X
1 | PP0: write_backup(Z) → host 内存不足 → evict_host() |
第 6 步:新请求匹配前缀 → crash
1 | PP0: _match_prefix_helper → 找不到 X → device_indices=[], host_hit_length=0 |
Channel A 如何解决
1 | PP0 (权威源) PP1 (回放) |
三个 Channel 的分工
1 | 问题链条: |
一句话总结:Channel B + 逻辑时钟是”努力让两个 rank 独立做出相同决策”,但由于异步操作的固有非确定性,无法 100% 保证。Channel A 是”放弃独立决策,直接同步结果”,从根本上消除分歧。
PP 路径的调度细节
get_next_batch_to_run
这个调度器在 PP 路径下要考虑的东西比 normal 多得多:
- 当前 stage 的显存 — 不能 OOM
- Microbatch 的依赖关系 — 上下游 stage 的进度
- Pipeline 的吞吐 — 尽量让每个 stage 都有活干
Async Send/Recv 的 Overlap
第 2 步 send reqs 和 recv proxy tensors 是 async 的,理想情况下计算和通信能 overlap。但如果 recv 阻塞了,整个 loop 就卡住了。
1 | Stage 0: [compute mb0] → [send hidden to S1] → [recv hidden from S1] → [compute mb1] |
通信和计算的重叠程度决定了 pipeline 的效率。
PP 路径 vs 非 PP 路径对比
非 PP 路径(event_loop_normal / event_loop_overlap)
1 | 时间 → |
调度一个 batch,forward 一个 batch,处理结果,再调度下一个。同一时刻只有一个 batch。
PP 路径(event_loop_pp)
1 | 时间 → |
多个 microbatch 同时在不同 stage 中流转,主循环轮转处理每个 microbatch。
总结
SGLang 的 PP 实现有几个关键设计:
- 独立事件循环 — PP 路径和 normal/overlap 路径完全隔离,避免复杂度交叉
- Microbatch 填充 bubble — 通过
pp_loop_size = pp_size + async_depth控制 pipeline 中的并发度 - Async 通信 — send/recv 异步化,尽量 overlap 计算和通信
- 独立状态管理 — 每个 microbatch 槽位有独立的 running_batch、last_batch 等状态
PP 的核心挑战始终是 如何减少 bubble、提升 GPU 利用率,同时控制显存开销。Microbatch 数量的调优是 PP 性能调优的关键。
参考
- SGLang 源码:
scheduler_pp_mixin.py - Pipeline Parallelism 论文:GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism