ggaaooppeenngg

为什么计算机科学是无限的但生命是有限的

参考 Benchmarking Text Generation Inference

参考 SGLang issue 364

参考 LLM inference server performances comparison llama.cpp / TGI / vLLM

相关代码:

sglang bench

vLLM bech prefix cache

vLLM bench serving

TokenAttention和PagedAttention,感觉TokenAttention是个很离谱的设计,而Radix的话和PagedAttention的颗粒度不是完全对应的。

vLLM 的默认block size最多是32,虽然这个32对应的字符串长度不是固定的,一般一个Token平均对应4个字母,所以有效前缀大概120比较合适。

前缀重复度

为了能够测试不同数据集的前缀重复度,需要一种方法衡量对话的前缀重复度,如果前缀的重复度不高,可能测试结果不太能体现前缀缓存的优势。

对于所有的对话构造一个Radix树,每个树节点保存一个计数器记录经过该节点的字符串的数量。

计数重复前缀的数量,比如W这个前缀是比较多的因为很多英文问句都是Wh-开头的,而中文的话是比较随机的。

对于每个节点,在进行计数器过滤的时候,要一直遍历到某个节点的子节点都小于计数器N才结束,这样防止过滤出多个公共前缀的前缀,
因为较短的前缀肯定是被较长的前缀包含的。相当于对这棵树做剪枝,删除所有计数器小于过滤值的节点。

再从满足要求的所有被剪枝完的叶子结点中选择长度大于L的前缀。

对话数据集的前缀重复度 = 基于N剪枝的所有长度大于L的叶子前缀节点数 / 所有对话数量

压力测试数据集

  • databricks-dolly-15k 这个数据集的前缀重复度不高。
    只有两个前缀长度超过00,重复次数大于1,因为里面都是单轮的对话。
    (‘Extract all of the dates mentioned in this paragraph and list them using bullets in the format {Date} - {Description}’, 11) (‘Extract all of the names of people mentioned in this paragraph and list them using bullets in the format {Name}’, 15)
  • LMSYS-CHAT-1M
    一个parquet有16W个对话。前缀重复比较高的是30~40次。这样的对话有9483条,也就占总数的5%,重复前缀的平均长度只有300左右。
  • ShareGPT这是vLLM官方使用的一个压测数据集。压测脚本在。这个的比重也只有2%,重复前缀的平均长度是4K。

以上数据集可能对于前缀缓存的优势体现不太明显。

  • 测试工具

    • sglang inference benchmark
  • 测试参数

    • batch_size: 30
    • max_length: 4096
    • num_samples: 1000
  • 测试结果

    • TTFT
    • TBT
    • Throughput

构造数据集

用实际的数据集结果不是特别好,差异度不是很高,因为这些数据集的前缀重复度比重都不是很高。
没有特别好的现成的数据集,需要使用人工构造的方式去构造数据集。

sglang 的benchmark提供了 generated-shared-prefix dataset arguments相关的参数。
他是通过随机生成一个系统提示词再组合问题,但是Prompt是随机的。语言不是很明朗。但可能并不
影响测试效果。

比较理想的应该是认为构造一些长度的系统提示词加一些问题进行组合,这个可读性会更高一点,但是没那么灵活
不太好按要求生成指定上下文长度的提示词。

测试结果

结果来看,在batch size更大的情况下,TTFT会变得特别长,而TBT也会相应的增加一些但没有TTFT恐怖。
batch size变大以后,TTFT从300s变成了900s,而ITL则从0.2s变成了0.3s。
这和MoonCacke的论文是一致的。

测试一下PD分离的效果,使用vLLM的1P1D。
PD分离以后TTFT可以降低一个数量级,这个效果还是很明显的,直接降了一个数量级。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
============ Serving Benchmark Result ============
Backend: vllm
Traffic request rate: inf
Max reqeuest concurrency: not set
Successful requests: 47
Benchmark duration (s): 127.03
Total input tokens: 14545
Total generated tokens: 2993
Total generated tokens (retokenized): 2992
Request throughput (req/s): 0.37
Input token throughput (tok/s): 114.50
Output token throughput (tok/s): 23.56
Total token throughput (tok/s): 138.06
Concurrency: 24.49
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 66177.90
Median E2E Latency (ms): 61336.75
---------------Time to First Token----------------
Mean TTFT (ms): 39888.70
Median TTFT (ms): 22421.85
P99 TTFT (ms): 116090.20
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 491.86
Median TPOT (ms): 394.97
P99 TPOT (ms): 1917.39
---------------Inter-token Latency----------------
Mean ITL (ms): 419.69
Median ITL (ms): 275.52
P99 ITL (ms): 1766.40
==================================================

双v100 LLAMA3.2:11b

python -m sglang_router.launch_router --worker-urls http://127.0.0.1:8081 http://127.0.0.1:8082

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
============ Serving Benchmark Result ============
Backend: vllm
Traffic request rate: inf
Max reqeuest concurrency: not set
Successful requests: 1000
Benchmark duration (s): 1247.16
Total input tokens: 289255
Total generated tokens: 184429
Total generated tokens (retokenized): 184388
Request throughput (req/s): 0.80
Input token throughput (tok/s): 231.93
Output token throughput (tok/s): 147.88
Total token throughput (tok/s): 379.81
Concurrency: 470.04
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 586218.50
Median E2E Latency (ms): 596155.97
---------------Time to First Token----------------
Mean TTFT (ms): 520113.99
Median TTFT (ms): 526194.47
P99 TTFT (ms): 1067230.41
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 363.05
Median TPOT (ms): 356.14
P99 TPOT (ms): 736.93
---------------Inter-token Latency----------------
Mean ITL (ms): 360.61
Median ITL (ms): 273.54
P99 ITL (ms): 1525.31
==================================================

双卡的并发的情况下,吞吐可以线性增长,但是相较于1P1D来说,prefill的时间没有改善。

笔者参考dynamo尝试实现了一个基于NCCL版本的P2P的xPyD的PD分离

基于8卡的L40进行了并发100个prompts的测试。每两卡之间是有一个NVLINK其他的卡之间全部是PCIe。

笔者对于kv 传输的group切分如下。

如果是2P4D的话就是这么划分:

如果是4P4D的话。

从测试结果可以看出来单机多卡的PD分离能够降低TBT(TPOT),一个TP的decode就已经超过8TP的decode了,这里主要是因为没有了prefill的干扰。
但是TTFT相对变大了,这个可能是TTFT多了一次传输的时间,具体原因不知道是不是我的实现方式不对,还是因为4TP的prefill就是要慢一些。
这个可能需要一个更综合性的tuning。

4P4D(4TP+4TP) V0调度器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
============ Serving Benchmark Result ============    
Backend: vllm
Traffic request rate: inf
Max reqeuest concurrency: not set
Successful requests: 100
Benchmark duration (s): 37.95
Total input tokens: 34965
Total generated tokens: 20654
Total generated tokens (retokenized): 20654
Request throughput (req/s): 2.64
Input token throughput (tok/s): 921.42
Output token throughput (tok/s): 544.29
Total token throughput (tok/s): 1465.71
Concurrency: 46.49
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 17642.84
Median E2E Latency (ms): 14298.02
---------------Time to First Token----------------
Mean TTFT (ms): 8147.33
Median TTFT (ms): 8393.70
P99 TTFT (ms): 8834.48
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 58.38
Median TPOT (ms): 50.79
P99 TPOT (ms): 162.16
---------------Inter-token Latency----------------
Mean ITL (ms): 46.27
Median ITL (ms): 42.21
P99 ITL (ms): 56.21
==================================================

8TP V0调度器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
============ Serving Benchmark Result ============    
Backend: vllm
Traffic request rate: inf
Max reqeuest concurrency: not set
Successful requests: 100
Benchmark duration (s): 27.86
Total input tokens: 29552
Total generated tokens: 24879
Total generated tokens (retokenized): 24875
Request throughput (req/s): 3.59
Input token throughput (tok/s): 1060.84
Output token throughput (tok/s): 893.10
Total token throughput (tok/s): 1953.94
Concurrency: 54.24
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 15109.72
Median E2E Latency (ms): 15397.55
---------------Time to First Token----------------
Mean TTFT (ms): 5398.04
Median TTFT (ms): 6015.93
P99 TTFT (ms): 7251.63
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 71.76
Median TPOT (ms): 42.68
P99 TPOT (ms): 307.25
---------------Inter-token Latency----------------
Mean ITL (ms): 39.23
Median ITL (ms): 32.54
P99 ITL (ms): 43.65
==================================================

8TP V1 调度器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
============ Serving Benchmark Result ============    
Backend: vllm
Traffic request rate: inf
Max reqeuest concurrency: not set
Successful requests: 100
Benchmark duration (s): 24.07
Total input tokens: 21404
Total generated tokens: 20379
Total generated tokens (retokenized): 20377
Request throughput (req/s): 4.15
Input token throughput (tok/s): 889.06
Output token throughput (tok/s): 846.49
Total token throughput (tok/s): 1735.55
Concurrency: 41.30
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 9943.29
Median E2E Latency (ms): 9369.62
---------------Time to First Token----------------
Mean TTFT (ms): 2798.48
Median TTFT (ms): 2731.82
P99 TTFT (ms): 4323.56
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 55.11
Median TPOT (ms): 39.79
P99 TPOT (ms): 307.50
---------------Inter-token Latency----------------
Mean ITL (ms): 36.45
Median ITL (ms): 31.66
P99 ITL (ms): 341.00
==================================================

多机器配置

DeepSeek R1 8xH20 x2 台机器,每台机器RDMA配置16个 MT2910 Family [ConnectX-7] 做8个bond。

8TP x 2PP 的部署方案,如果后面EP支持的话可能会有更好的效果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
============ Serving Benchmark Result ============
Backend: vllm
Traffic request rate: inf
Max reqeuest concurrency: not set
Successful requests: 1000
Benchmark duration (s): 234.47
Total input tokens: 303481
Total generated tokens: 187870
Total generated tokens (retokenized): 186116
Request throughput (req/s): 4.26
Input token throughput (tok/s): 1294.33
Output token throughput (tok/s): 801.26
Total token throughput (tok/s): 2095.59
Concurrency: 363.04
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 85122.29
Median E2E Latency (ms): 82826.18
---------------Time to First Token----------------
Mean TTFT (ms): 31789.26
Median TTFT (ms): 17669.77
P99 TTFT (ms): 100110.92
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 770.73
Median TPOT (ms): 341.77
P99 TPOT (ms): 9445.55
---------------Inter-token Latency----------------
Mean ITL (ms): 284.74
Median ITL (ms): 214.68
P99 ITL (ms): 745.14
==================================================

sglang tp 16的配置,sglang不支持pp,sglang明显要快一些,主要原因应该是sglang支持了MTP,vLLM目前还没有。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
============ Serving Benchmark Result ============
Backend: sglang
Traffic request rate: inf
Max reqeuest concurrency: not set
Successful requests: 1000
Benchmark duration (s): 190.92
Total input tokens: 306113
Total generated tokens: 197108
Total generated tokens (retokenized): 195033
Request throughput (req/s): 5.24
Input token throughput (tok/s): 1603.38
Output token throughput (tok/s): 1032.43
Total token throughput (tok/s): 2635.81
Concurrency: 488.50
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 93263.23
Median E2E Latency (ms): 86230.17
---------------Time to First Token----------------
Mean TTFT (ms): 39722.57
Median TTFT (ms): 43590.80
P99 TTFT (ms): 60010.86
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 1529.43
Median TPOT (ms): 270.69
P99 TPOT (ms): 37619.47
---------------Inter-token Latency----------------
Mean ITL (ms): 276.88
Median ITL (ms): 158.45
P99 ITL (ms): 945.60
==================================================

dynamo对于vLLM的给懂是一个patch文件,vLLM给了一个PR方便在线对比。

PD分离的背景介绍

在大模型推理中,Prefill 和 Decode 的计算特性存在显著差异:

  • Prefill:计算密集型,计算比例更高。
  • Decode:访存密集型,访存比例更高。

Decode 的计算依赖于 Prefill 生成的 KV Cache。在没有 PD 分离的情况下,较长的 Prefill 通常会优先占用计算资源,导致长 Prompt 的 Prefill 时间过长,从而增加 Decode 的延迟。

Chunked Prefill 的局限性

为了解决上述问题,Chunked Prefill 是一种常见的优化方案,但它也存在以下挑战:

  1. 适合超长 Context Length:Chunked Prefill 能有效降低中间显存的占用,但仅适用于超长上下文场景。
  2. 大 Chunked Prefill 的影响:当大 Chunk Prefill 和 Decode 同时出现在一个 Batch 中时,会显著拖慢 Decode 的速度。
  3. 小 Chunked Prefill 的退化:小 Chunk Prefill 容易退化为 Memory Bound,导致计算单元的利用率(MFU)下降。

因此,PD 分离的优势在于可以分别优化 Prefill 和 Decode 的性能。

Dynamo 的 PD 分离策略

Dynamo 使用了一种条件 PD 分离策略:

  • 仅在满足特定条件时,Prefill 才会远程计算;否则,仍然在本地进行 PD 混合计算。

PD 分离的实现主要包括两部分:

  1. 模型的切分与传输:通过合理的切分策略和高效的传输机制实现计算资源的分离。
  2. 高效的异步传输或存储引擎:这是性能优化的关键,尤其是 KV Cache 的传输或存储。

KV Cache 传输的挑战

以 A100 为例,在 8B LLAMA 模型下,Prefill 的计算速度可达 1 万 tokens/s,这会产生约 3GB 的 KV Cache 数据,给传输带宽带来极大压力。
对于计算速度更快的 H100,这种传输需求会进一步增加,对带宽提出了更高的要求。

xPyD 的主要设计概要

以下是 xPyD 的主要设计概要,不同框架可能在 PD 负载均衡方式、KV 传输和存储上有所不同。

KV Cache 的切分

TP 条件下的切分

在 Tensor Parallel (TP) 条件下,KV Cache 按照 head 进行切分。例如,对于 Qwen 小模型,其 KV head 数为 8,Q head 数为 40,hidden size 为 5120。
一个 token 的大小计算如下:
8 * 5120 / 40 = 1024

如果 P/D TP比例为 2,则 P 会沿着 head 切分为两部分。传输引擎会将切分后的 tensor 发送给每个 D。
因此,一个 token 在 D 上的大小为:
8 / 2 * 5120 / 40 = 512

需要注意,上述计算未包含数据宽度。如果数据类型为 FP8,则宽度为 1;如果为 BF16,则宽度为 2。

PP 条件下的切分

在 Pipeline Parallel (PP) 条件下,切分相对简单,直接沿着层进行切分即可。例如,如果 P/D PP比例为 2,且模型有 64 层,则:

  • 层 0-31 分配给 D0
  • 层 32-63 分配给 D1

这种切分方式无需对 tensor 进行额外处理。

DP 和 EP 条件下的切分

  • EP:EP 主要用于 FFN,与注意力机制无关,因此无需考虑 EP 条件下的切分。
  • DP:DP 条件下无需切分。如果 P/D 比例为 2,直接将 P 的副本同时发送给两个 D 即可。

传输方式

P2P 传输

P2P 传输采用点对点方式:

  • P 向 D 建立 RDMA 连接,并申请 RDMA 的 VRAM。
  • P 直接将数据发送给 D。

KV Cache Store

KV Cache Store 属于 Pooling 模式,支持以下功能:

  • 中间存储:P 将数据存储到 KV Store,D 从 KV Store 中领取数据。
  • 缓存优化:Store 也可以基于 P2P 实现,支持显存缓存、内存或 SSD 上的 KV Cache Swap。
  • Prefix Cache 共享:通过中间存储,可以实现跨请求的 Prefix Cache 共享。

PD 顺序

  • Dynamo 的策略:Dynamo 采用先 Decode 后 Prefill 的条件PD策略。请求首先发送到 Decode 实例。如果是短的 Prefill,Decode 实例会直接计算,无需触发远端 Prefill;如果需要远端 Prefill,则会触发远端计算。
  • 其他框架的策略:大多数框架采用先 Prefill 后 Decode 的策略。请求先由 Prefill 计算出第一个bonus token,然后转交给 Decode 继续计算。

Dynamo 实现分析

本文仅分析 Dynamo 对 vLLM 本身的一些改动,不涉及 Dynamo 在上层的工作,例如全局基于消息队列的 PrefillQueue 是在上层实现的。
vLLM 在被请求时会告知自己是否是 Prefill 请求,或者是否需要远程 Prefill 的 Decode 请求,这一层逻辑由 Dynamo 在上层完成。

Dynamo 基于 vLLM V0 的调度器实现。V0 的调度器主要将 Prefill 和 Decode 明确分开,而 V1 的调度器则考虑了 Chunked Prefill,不再在调度器内部区分 Prefill 和 Decode 两种 sequence。

vLLM V0 调度器回顾

在 vLLM V0 的实现中,enginestep 方法会调用调度器,调度器负责给出需要执行的 sequence group request,然后调用模型执行器(model executor)执行这些请求。执行完成后,调度器会处理结果并更新被调度请求的状态。

LLMEngine

LLMEngine 是 vLLM 的核心组件之一,负责协调调度器和模型执行器的工作。Dynamo 在此基础上进行了扩展,以支持 PD 分离的功能。

增加_finished_transfers_finished_prefills用于保存prefill的传输结束的request和decode接收传输的request,这两个变量会传给调度器。

remote_prefill_requests保存在remote prefill中的requests。

调度结束以后会拆分出running request和remote prefill requests。

对于running request逻辑不变,但是remote prefill requests会在model execution执行前先发出去。方法是给seq的remote_prefill_request_callback添加remote_prefill_request
这个callback是dynamo上层的worker设置的,他对应的是向全局的消息队列PrefillQueue发送Prefill消息。

通过比较computed_block_nums是否等于block_table标记完成并且放入到本地的_finished_prefills当中。
这是decode视角:把调度器调度的prefill的request发送出去不在本地计算。

到prefill视角来看,会在memory_transfer_reqs中加入已经计算好的computed block和requestid等信息构造的MemoryTransferRequest。

excute_model_req会增加一个需要传输的requests。

execute_model_req.memory_transfer_requests = memory_transfer_reqs

然后开始执行model_excutorexcute_model

根据执行结果返回的request_notif_counterrequest_done_counter更新对应的_finished_prefills_finished_transfers

上层需要初始化LLMEngine的NIXL Agent。

Schedule

调度器的改动其实比较简单,对于prefill角度,prefill结束的释放掉request,decode角度把remote prefill结束的变成running走原本的decode流程。

除了running之外额外增加了一个remote_prefilling的queue用于管理在远端prefill的请求,他和running queue的关系是有相似性的,
比如判断有无未完成的请求时会同时看running和remote_prefilling,但他们的区别在于是不是在本地running。
也增加了prefill_sending用户标记正在传输的prefill。

调度主体会接收finished_prefillsfinished_transfers用于D标记的远端prefill结束(已经传到了本地)和P标记已经完成传输的requests。

remote_prefilling的中的request在finished_prefills中时代表prefill结束,会把状态设置为running并且开始decode调度。
prefill_sending的中的request在finished_transfers中时代表prefill传输结束,会free掉这个request。

prefill_sendingfinished_transfers是一对,是对于prefill instance来说的。
remote_prefillingfinished_prefills是一对,是对于decode instance来说的。

seq_group中会添加is_remote_decode这个用于标记这个请求只在自己这里prefill,decode要在decode instance上做,这个标记是上层的worker设置的,不在vLLM层。

每个sequence group会添加一个标记,seq_group.remote_prefill_params.is_remote_prefill,标记了就加入到remote_prefilling队列中
不然就走老的流程。这也是上层决定。

EventManager

只和 router 负载均衡有关。

worker 上有 KVPublisher 负责发送 kvcache 的创建和删除事件,同时 KvMetricsPublisher 用于发送监控指标(如队列排队长度等)。
router 上则包含 KVIndexer,用于收集 kvcache 的事件并建立前缀树,同时 KvMetricsAggregator 用于收集监控指标。

路由策略基于 KV match rate - Load 的最大值,旨在平衡负载与 kvcache 的匹配度。

PrefixCachingBlockAllocator加入了event_managerKVCacheEventManager实现就不细说了,就是一个事件收发器。

在evict block的时候,发送删除事件event_manager.enqueue_removed_event(content_hash_to_evict)
当block被填满变成 immutable block 的时候,发送分配事件event_manager.enqueue_stored_event(block.prev_block, block)

NIXL Transfer

数据传输的切分

假设:

  • PS(prefill parallel size)=1
  • DS(decode parallel size)=1
  • PTP(prefill tensor parallel size)=2
  • DTP(decode tensor parallel size)=4

存在两个 kvgroupP0,D0,D1P1,D2,D3

计算可得:
TPM(tensor parallel multiplier) = DTP/PTP = 2

由此可以推导出以下关系:

  • kv_rank:表示每个 P 或者 D 实例的序号,在本示例中分别为 01
  • p_kv_group_rank = kv_rank
  • d_kv_group_rank = PS + (kv_rank - PS) * TPM + rank % TPM
  • kv_world_size = PS + DS * TPM
  • p_kv_global_rank = kv_rank * PTP + rank
  • d_kv_global_rank = PS * PTP + (kv_rank - PS) * DTP + rank

关于集合通信的端口号的 base

  • p_port_offset_base = 2 * rank + 1
  • d_port_offset_base = 2 * rank//TPM + 1

以下是具体的参数值表格:

role rank kv_rank kv_group_rank kv_world_size kv_global_rank port_offset_base
P 0 0 0 3 0 1
P 1 0 0 3 1 3
D 0 1 1 3 2 1
D 1 1 2 3 3 1
D 2 1 1 3 4 3
D 3 1 2 3 5 3

初始化 rank 0 收集all_gather所有的parallel_config

每有一个kv_role = kv_producerkv_producers_parallel_size就+1。

第一步当然要支持xPyD的配置,比如下面的配置。

P的并行规模

1
kv_producers_parallel_size: Optional[int] = None

P的TP

1
kv_producers_tensor_parallel_size: Optional[int] = None

P的PP

1
kv_producers_pipeline_parallel_size: Optional[int] = None

D的TP

1
kv_consumers_tensor_parallel_size: Optional[int] = None

D的PP

1
kv_consumers_pipeline_parallel_size: Optional[int] = None

属性函数
其中tensor_parallel_multiplierp的tp // d的tp

然后总结一下各种pp,tp,rank的关系。

  1. TP的倍率由D的TP地板除P的TP:tensor_parallel_multiplier = self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size
  2. D的并行规模是整体并行规模减去P的并行规模: kv_consumers_parallel_size = self.kv_parallel_size - self.kv_producers_parallel_size
  3. kv_world_size = self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier

发送的入口是send_kv_caches_and_hidden_states

这个函数的主要工作是根据TP的大小,从Prefill worker上切出Decode worker上需要的tensor
并且发送给Decode worker。

在这个函数中根据自己的rank计算对应的D的rank。
PP的话比较容易,直接按层的range就行了,TP的话需要在给定层的range下做TP的切分。从代码来看dynamo只支持PP=1。

笔者写了一个简易版的示例带代码方便查看对应的shape

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
# kv_cache 是一个包含张量的列表
# head 数量是 8, hidden_state per head 也是8
# key_cache 是一个连续的空间,有slot_mapping做索引
key_cache = torch.randn(10, 8, 8)
print("原始 key_cache shape")
print(key_cache.shape)
# output: torch.Size([10, 8, 8])
# 这次request使用的是第0和5条的key cache
current_slot_mapping = [0,5]
# prefill tp=2
p_tp = 2
# decode tp=4
d_tp = 4
tp_multiplier = d_tp // p_tp
# 考虑decode worker 的 rank 0 的情况
target_rank = 0
# num_heads_per_rank = 1 也就是每个decode rank分到一个head
num_heads_per_rank = 8 // p_tp // d_tp
# 计算head的range
head_start = target_rank * num_heads_per_rank
head_end = head_start + num_heads_per_rank
# 按 p_tp reshape
key_cache = key_cache.reshape(-1,8//p_tp,8)
print("按 prefill tp 切分")
print(key_cache.shape)
# output: torch.Size([20, 4, 8])
print("decode 选择器(在第0维度)", current_slot_mapping, ",head range(在第1维度)", str(head_start)+":"+str(head_end))
# output: decode 选择器(在第0维度) [0, 5] ,head range(在第1维度) 0:1
d_key_cache = key_cache[current_slot_mapping, head_start:head_end]
print("获取 d key cache")
print(d_key_cache.shape)
# output: torch.Size([2, 1, 8])

相对应的接收函数的入口recv_kv_caches_and_hidden_states 没啥特殊处理,直接已经切好,收到以后直接cache住就行。