ggaaooppeenngg

为什么计算机科学是无限的但生命是有限的

Dynamo 发布以后,我大概速览了一些设计文档,并且提取了一些关键点,并对比一些其他方案的异同点。

Smart Router

worker 上有 KVPublisher 负责发送 kvcache 的创建和删除事件,同时 KvMetricsPublisher 用于发送监控指标(如队列排队长度等)。
router 上则包含 KVIndexer,用于收集 kvcache 的事件并建立前缀树,同时 KvMetricsAggregator 用于收集监控指标。

路由策略基于 KV match rate - Load 的最大值,旨在平衡负载与 kvcache 的匹配度。

KVPublisher 应该是侵入式实现,需要给vLLM打这个patch才能实现,需要修改代码才能捕获这些事件。所以光从他的依赖来看,应该是只支持了vLLM,其他的支持估计还没开源出来。

sgl-router 完全不依赖 worker 的信息,仅通过路由自身的请求实现可过期的前缀匹配。虽然这种方式的匹配精度不如直接获取信息,但实现上更为解耦。

vllm-router 则基于 vLLM 的 Prometheus 接口,通过 /metrics 获取监控指标,其前缀匹配是通过 block hash 的近似度实现的。

llumnix 支持请求的重调度功能,可以将排队中的请求重新分配。

aibrix gateway 同时支持基于树和哈希的匹配方式,并且支持用tokenizer使用 token 进行前缀匹配,而不像 sgl-router 基于字符的匹配。

从 Dynamo 的 Indexer 实现来看,其基于 block 级别的 radix tree,事件通过 Component 的 publish 机制进行分发然后触发radix tree的更新。

条件 PD 分离

并非所有请求的 prefill 阶段都需要在 prefill instance 中计算。如果 prefill 很短,或者 decode instance 的 KV 缓存命中率较高,通常在 decode instance 中直接完成 prefill 更为高效。Dynamo 的分解设计充分考虑了这些场景,并提供了一个灵活的框架,能够在多种条件下实现卓越性能。

在 Decode Instance(在 Dynamo 中称为普通的 worker)上,需要决定是否执行分离操作。如果需要PD分离,则将 prefill 请求交给 prefill worker,通过 prefill queue 进行处理。当 prefill queue 完成后,再通过 prefill queue 将结果传回 worker,开始 decode 阶段。

具体而言,只有在满足以下两个条件时,才会向远程 prefill instance 发送请求:

  1. 没有前缀缓存命中的 prompt 长度超过设定阈值。
  2. Prefill queue 的长度小于设定阈值。

这种条件化的 PD 分离设计,使得 Dynamo 能够在动态工作负载下实现高性能。

Prefill Queue

Prefill Queue 是一个基于 NATS Stream 的全局消息队列。

在这一部分中,最具挑战性的是 KV Cache 的传输。Mooncake 开源了其 TransferEngine,而 vLLM 提供了一些 KV Connector 和 KVStore 的抽象。可以推测 Dynamo 也在 vLLM 的基础上实现了相关功能,可以看到在这个patch中,给vLLM的kv connector实现了一个DynamoNixlConnector。

The key to high-performance disaggregation is efficient KV transfer. Dynamo leverages NIXL to transfer KV cache directly from the VRAM of the prefill engine to the VRAM of the decode engine.

Dynamo 的 KV Cache 传输是通过直接 RDMA(远程直接内存访问)实现的。

为了减少 Memory Descriptors(RDMA 的描述对象)的大小,Dynamo 采用了以下两种优化:

  1. Memory Descriptors 缓存
    每个 Worker(对应传统的 Decode Instance,但在 Prefill 较短时也会执行 Prefill)在初始化并分配所有 KV 缓存池后,会将所有块的 Memory Descriptors(也称为 NIXL 元数据)存储在分布式键值存储 ETCD 中。当 Prefill Worker 第一次服务来自 Worker 的远程预填充请求时,会从 ETCD 加载这些 Memory Descriptors 并缓存到该 Worker 中。因此,在发出 Prefill 请求时,只需要传递 KV 块 ID,而无需传递完整的 Memory Descriptors。这一优化的具体作用可能需要进一步分析 NIXL 的传输过程才能完全理解。

  2. 显存分配优化
    Dynamo 在 Prefill 过程中提升了显存分配能力,通过分配连续的内存块并将其合并为更大的块,从而减少 KV 块的总数。这种合并的具体效果需要结合实现NIXL细节进一步评估。

此外,对于不同 KV 布局(例如由于不同的 TP 导致的 Decode 和 Prefill 布局差异),Dynamo 使用了一个高性能内核。在 NIXL 读取之后和写入之前,该内核会将 KV 块转置为 KV Receiver 中的匹配布局。这可能是为了将 KV Cache 分块传输到不同的 TP 上。

由于引入了 ETCD,Dynamo 支持动态调整 Worker 和 Prefill Worker 的数量。

和其他方案对比

Mooncake 的设计在架构上更加分离,主要通过一个调度器(scheduler)来负责 kvcache 的传输调度,并直接决定 P 和 D 之间的 P2P 传输,基于其 TransferEngine 实现了以下功能:

  1. 基于 kvcache 的前缀匹配分配 prefill 请求
    如果 prefill 节点上缓存了足够的前缀(由 kvcache_balancing_threshold 控制),则选择预估 TTFT(Time to First Token)最小的实例:
    TTFT = min(T_queue + T_prefill)
    如果 prefill 节点上缓存不足,则选择:
    TTFT = min(T_queue + T_prefill + T_transfer)
    其中 T_transfer 指的是将最长匹配的 KVCache 从其他实例拷贝到当前实例的预估时间。

  2. 高频使用的 kvcache P2P 传输
    Scheduler 负责 kvcache 的传输调度,例如从一个 prefill 节点传输到另一个 prefill 节点(适合Prefix Cache),或者从 prefill 节点传输到 decode 节点,decode到其他decode节点(适合多轮对话)。

  3. 基于负载均衡的 decode 请求分配
    通过负载均衡的方式预估 TBT(Time to Best Throughput),从而优化 decode instance 的请求分配。

Mooncake 的设计在模块划分上更加清晰,调度器(scheduler)与各个组件的职责分离明确。

相比之下,Dynamo 的入口在 worker(相当于 Mooncake 中的 decode instance),由 worker 决定是否将 prefill 请求交给 prefill instance。Dynamo 的特点包括:

  • Worker 也可以执行 prefill 操作(即 decode instance 有时也会承担 prefill 的职责)。
  • 引入了全局队列(queue)来处理 kvcache 的计算和计算就绪信息。
  • 提供了 NIXL 传输引擎,但仅支持 P 到 D 的 kvcache 传输,相对实现更为直白。

AIBrix 的现状

AIBrix 目前尚未实现 PD 分离功能,相关文档和白皮书中未提及此功能。

依赖与工程复杂度

从 Dynamo 的依赖项来看,其使用了 ai-dynamo-vllm v0.7.2,这是对 vLLM v0.7.2 的定制化补丁版本,需修改 vLLM 以支持 Publisher 功能。

Dynamo 的工程栈相对复杂,依赖消息队列和 ETCD,但其 PD 分离设计较为直白,例如仅支持 P 到 D 的传输。相比之下,Mooncake 的设计更注重架构分离,尽管目前未实现 offload 功能,但其 P2P kvcache pool 的设计为未来扩展提供了可能性。

关键问题

俗话说得好,关键问题是问题的关键。无论是 Mooncake 还是 Dynamo,其核心目标都是提高传输效率和 kvcache 的利用率。Dynamo 的实现更简化,而 Mooncake 则在架构设计上更具层次感。

KVCache 管理

KVCache Offload

当显存不足时,可以将 KVCache 卸载到更低级别的存储中,例如内存、磁盘,甚至对象存储。
管理器的核心在于结合驱逐策略,在以下两种情况之间取得平衡:

  • 过度缓存:可能引入查找延迟。
  • 缓存不足:导致查找失败和 KV 缓存的重新计算。

V1 单机版本

V1 版本支持将 KVCache 卸载到磁盘,同时使用 CPU 的内存作为缓存。在需要加载时,从磁盘读取数据回显存。

V2 分布式版本

V2 版本将扩展为分布式架构,形成一个全局的 KVCache 池。

Mooncake 的实现

Mooncake 的 KVCache Pool 完全基于显存的 P2P 传输,不涉及 offload 操作。它通过开源的 TransferEngine,将缓存节点上的 KVCache 调度到需要缓存的节点上。

AIBrix 的实现

AIBrix 提供了一个分布式 KVCache Pool,基于 Vineyard 的分布式内存存储。通过 Vineyard 实现 KVCache 的共享,但与专门的传输引擎相比,其传输效率可能稍逊一筹。

NIXL

NIXL 通过简化的同步和批处理以及简化的源和目标抽象简化了数据搬迁。
NIXL 能够在不同类型的内存和快速存储中抽象数据搬迁,而其他数据搬迁库通常只支持一层内存。
这些增强带来了显着的性能提升,加速了第一个词元的时间(TTFT)和整体吞吐量。

NIXL的地位应该是和Mooncake的TransferEngine相当的,至于两者谁的效果更好可能要具体看一下。

总结

看设计的话,感觉还是Mooncake更漂亮一点,层次分得较清楚,不额外依赖什么中间件,kvcache pool的这个设计虽然是纯P2P的,应该后面也可以去做offload之类的。
dynamo就显得更具有工程具体性,并且实现相对来说是要更简单一些,毕竟依赖了message queue又依赖了etcd,把一些复杂度转移给了中间件,入口从worker(or decode instance)可以自己直接prefill短prompt肯定也是做了很多tradeoff才给出了一个不完全分离的条件PD分离的实现。

参考 OpenAI Triton 主页

参考 Triton论文

参考 GPU MODE Lecture 14: Practitioners Guide to Triton

从Trinton主页引用的话

现代 GPU 的架构大致可以分为三个主要组件 ——DRAM、SRAM 和 ALU—— 在优化 CUDA 代码时必须考虑每个组件:

  1. 来自 DRAM 的内存传输必须合并为更大的事务,以利用现代内存接口的大型总线宽度。
  2. 数据在被再次使用之前,必须手动存储到SRAM中,并且要对数据进行管理,以便在检索数据时尽量减少共享内存存储体冲突的情况。
  3. 计算必须在流式多处理器(SM)之间和内部仔细分区和调度,以促进指令 / 线程级并行性并利用专用 ALU(例如Tensor Core)。

这几句话可能比较抽象,下面给一下这几个组件的指标可能感受更直观,参考Which GPU(s) to Get for Deep Learning: My Experience and Advice for Using GPUs in Deep Learning
其中指出:

  • 全局内存访问(高达 80GB):~380 个周期
  • L2 缓存:~200 个周期
  • L1 缓存或共享内存访问(每个流式多处理器高达 128 kb):~34 个周期
  • 融合乘法和加法,ab+c(FFMA): 4 个周期
  • Tensor Core 矩阵乘法:1 个周期

每个操作总是由 32 个线程组成的Warp执行,Warp中的线程必须相互等待。GPU 上的所有内存操作都针对warp进行了优化。

根据Simplifying CUDA kernels with Triton: A Pythonic Approach to GPU Programming的说法,GPU中的HBM(High Bandwidth Memory)等价于我们讲的Global Memory,SRAM对应的是L1和L2 Cache对应的是Shared Memory,这几个词在一些文档中可能会有不同的叫法,但是意思是一样的。

A100中的内存带宽约为 2TB/s,L1 缓存带宽:~100-200 TB/s 理论带宽,L2 缓存带宽:~4-7 TB/s 理论带宽。

再看OpenAI的三条说明的意思就是:

  1. 因为DRAM很大,比较容易占满总线带宽,所以尽量合并传输的事务可以减少传输的时间,让高速公路跑满。
  2. 如果数据要重复利用,反复参与计算,尽量让他们在SRAM当中能够缓存住,比如L1的读取只要34个cyle,能比从L2中快6到7倍。
  3. 尽量跑满并行度,并且利用更高效的计算单元,比如Tensor Core。

这个是OpenAI给出的GPU架构的简图,我们需要明确不同内存,缓存,和执行单元的周期之间的关系就比较好理解GPU计算当中的性能瓶颈。

Triton的目标其实就是优化 HBM -> SRAM -> 寄存器 的带宽,这在Torch里面直接实现不了,通过一些融合算子是可以减少写回到HBM的。

Triton的文档给出的很多实现的代码,可能都不太奏效了,笔者自己测试下来并没有超过torch本身的实现,
可能torch本身也再不断改进吧,这些差别很快就超越了,但是在一些写自定义融合算子方面应该还是比较有优势的。

Triton的使用

和CUDA对应的关系:

  • 程序(Program):处理数据块的kernel实例。
  • PID(程序 ID):等同于 CUDA 中的块 ID。
  • 向量化操作:在多个数据点上同时操作(triton不需要用户关心向量操作的并行化)。

先给出变量和修饰器的解释,大部分文档都混在注释里面不是很好阅读,我觉得先介绍一些简单概念再看代码会比较好一点。

@triton.jit 装饰器表示这个函数会被编译。
tl.constexpr 代表常量表达式,可以让编译器在编译期间直接求值,可以当作常量使用了。
BLOCK_SIZE对于GPU来说是比较固定的,因为一个block是有threads数上限的。
通过执行cuda-samples中的deviceQuery
可以发现L40的显卡BLOCK_SIZE最大是1024,大部分显卡应该都是这个固定大小。

1
2
3
4
5
6
7
8
9
10
Total number of registers available per block: 65536  
Warp size: 32
L2 Cache Size: 100663296 bytes(96MB)
Maximum number of threads per multiprocessor: 1536
Maximum number of threads per block: 1024
Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
Max dimension size of a grid size (x,y,z): (2147483647, 65535, 65535)
Total amount of shared memory per block: 49152 bytes(48KB)
Total shared memory per multiprocessor: 102400 bytes (100KB)
Total number of registers available per block: 65536

pid = tl.program_id(axis=0) 应该是对应的CUDA中的threadIdx.x的作用,对应block的一维下标,
pid = tl.program_id(axis=1) 应该是对应的CUDA中的threadIdx.y的作用,对应block的二维下标。

autotune是一个黑盒优化,通过内部的小benchmark的方式去基于key的变量,优化configs里面的参数。
下面是一个只有两个配置的搜索空间,当n_elements的值发生变化的时候,会自动选择最优的配置。

1
2
3
4
5
@triton.autotune(configs = [
triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
], key = ['n_elements'])
@triton.jit

BLOCK_SIZE表示的是一个program负责的BLOCK大小,放在triton的语境下更像是L2 Cache的大小
但是cuda当中的block是包含n个thread的,表示的是并行线程的大小
笔者的这个说法可能不太精准,但是这两种风格导致Cuda写一些element wise的操作比较合适.
每个element wise的操作都是一个thread,这样可以充分利用GPU的并行性。
而triton比较适合一些Reduce操作,例如对数据(也就是矩阵)切BLOCK,然后每个kernel去负责一个block,
他的好处就是比如softmax这样的在行上做reduce操作会比较直观,而矩阵乘法也可以沿着MxK,KxN的维度,沿着不同的维度切块。
Triton能够帮你把矩阵乘法优化得很不错,虽然可能还比不上精准手写的Cuda算子。

Triton的范式和CUDA的Single Instruction, Multiple Thread (SIMT)不一样,官网给出了一个简化的例子。

这是CUDA like的写法,每个threadId.x代表的线程只算一个element

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
BLOCK = 512

# This is a GPU kernel in Numba.
# Different instances of this
# function may run in parallel.
@jit
def add(X, Y, Z, N):
# In Numba/CUDA, each kernel
# instance itself uses an SIMT execution
# model, where instructions are executed in
# parallel for different values of threadIdx
tid = threadIdx.x
bid = blockIdx.x
# scalar index
idx = bid * BLOCK + tid
if id < N:
# There is no pointer in Numba.
# Z,X,Y are dense tensors
Z[idx] = X[idx] + Y[idx]


...
grid = (ceil_div(N, BLOCK),)
block = (BLOCK,)
add[grid, block](x, y, z, x.shape[0])

Triton文档中的Matrix乘法简化来说就是并行计算M*N个block(沿着K所代表的维度)。
这是Triton的写法,每个Program负责了一个block,他就少了一个block的切分维度:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
BLOCK = 512

# This is a GPU kernel in Triton.
# Different instances of this
# function may run in parallel.
@jit
def add(X, Y, Z, N):
# In Triton, each kernel instance
# executes block operations on a
# single thread: there is no construct
# analogous to threadIdx
pid = program_id(0)
# block of indices
idx = pid * BLOCK + arange(BLOCK)
mask = idx < N
# Triton uses pointer arithmetics
# rather than indexing operators
x = load(X + idx, mask=mask)
y = load(Y + idx, mask=mask)
store(Z + idx, x + y, mask=mask)


...
grid = (ceil_div(N, BLOCK),)
# no thread-block
add[grid](x, y, z, x.shape[0])

笔者比较疑惑,单就这两个代码他们的并行度貌似是不一样的,难道是把block那一层隐式的放在了loadstore当中,他的loadstore其实是隐含了并行能力的。
援引知乎的文章Triton中一直是以Block为中心来计算,直到Lowering到LLVM和PTX才会转为Thread为中心的计算,而这些对于使用Block抽象进行编程的用户来说都是无感的。是符合笔者预期的,triton简化了CUDA的写法,block具体的线程数的,每个线程处理多少元素,triton自己会去帮你处理。

当使用 triton 的时候,x = tl.load(x_ptr + offsets, mask=mask)时,我们正在加载到 L2 缓存 或者叫 SRAM 中。

根据Torch的blog,以及参考 OpenAI/Triton MLIR 迁移工作简介,Triton编译的过程是@triton.jit装饰器通过遍历提供的 Python 函数的抽象语法树(AST)来工作,以便使用通用的 SSA 构造算法即时生成 Triton-IR。
然后,生成的 IR 代码被我们的编译器后端简化、优化和自动并行化,最后被转换成高质量的 LLVM IR,最终是 PTX(Nvidia GPU的汇编),可以在最近的Nvidia GPU上执行。

矩阵乘法

这段代码是基于K切BLOCK,比上面的代码要好理解一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@triton.jit
def simple_mm(a, b, o, k, n,
K_BLOCK_SIZE: tl.constexpr = 64,
) -> None:
# a -> Matrix of size M x K and b -> Matrix of size K x N
# K is the common inner dimension
num_blocks = k//K_BLOCK_SIZE + 1
row_id = tl.program_id(0)
col_id = tl.program_id(1)

# Lets pick one column and one row and do a dot product
# Like the 1-D example we dont want to look at the entire row/column
# We are making use of the fact that each row/column will be of the size
# 'k' which is the inner common dimension of these matrices
# But this will only be a part of the dot product so we have to keep track of many to cover the entire column or row.

# What we are going to do is to access block size elements from the column
# and the row and compute the dot product and keep adding to a value till
# we run out of numbers
value = 0.
for k_id in range(num_blocks):
row_start = row_id * k + k_id * K_BLOCK_SIZE
row_offsets = tl.arange(0, K_BLOCK_SIZE) + row_start
# The masks are a little more trickier as we cant just see if its
# less than 'k'. We need to account for the row we are in
row_masks = row_offsets < (row_id + 1) * k
row = tl.load(a + row_offsets, mask=row_masks) # Load this into the GPU SRAM

col_start = (K_BLOCK_SIZE * k_id)
col_offsets = n * (tl.arange(0, K_BLOCK_SIZE) + col_start) + col_id # 0, n, 2n || 3n, 4n, 5n for a block size of 3 for eg
col_masks = col_offsets/n < k
col = tl.load(b + col_offsets, mask=col_masks)
value += tl.sum(row * col)

output_offset = row_id * n + col_id
tl.store(o + output_offset, value)

BLOCK_SIZE 和 GROUP_SIZE 的优化。

一次计算的时候尽量用满L2 Cache,所以可以把多个BLOCK放到一个GROUP里面,这个GROUP变成了grid的切分,但在GROUP里面我们
再去做BLOCK级别的计算,要计算好对应的线性空间中的stride。

Softmax

下面的代码只能在n_cols小于BLOCK_SIZE的数据上运行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import triton
import triton.language as tl

@triton.jit
def softmax(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
# row index
m = tl.program_id(0)
# col indices
# this specific kernel only works for matrices that
# have less than BLOCK_SIZE columns
BLOCK_SIZE = 1024
n = tl.arange(0, BLOCK_SIZE)
# the memory address of all the elements
# that we want to load can be computed as follows
X = X + m * stride_xm + n * stride_xn
# load input data; pad out-of-bounds elements with 0
x = tl.load(X, mask=n < N, other=-float('inf'))
# compute numerically-stable softmax
z = x - tl.max(x, axis=0)
num = tl.exp(z)
denom = tl.sum(num, axis=0)
y = num / denom
# write back to Y
Y = Y + m * stride_ym + n * stride_yn
tl.store(Y, y, mask=n < N)

import torch
# Allocate input/output tensors
X = torch.normal(0, 1, size=(583, 931), device='cuda')
Y = torch.empty_like(X)
# SPMD launch grid
grid = (X.shape[0], )
# enqueue GPU kernel
softmax[grid](Y, Y.stride(0), Y.stride(1),
X, X.stride(0), X.stride(1),
X.shape[0] , X.shape[1])

从一些大模型的训练技术报告来看有一些比较有代表性的挑战,比如 Meta 的 Research Super Compute (RSC) 和 X 的 Grok Infra。这些技术报告中提到了一些关键的技术挑战和解决方案,包括 GPU 架构与互联、存储系统、训练的稳定性等。

X Grok Infra

Grok-1.5 Infra 的技术报告中可以窥见,Grok-1.5 在基础设施方面具有以下核心优势:

  1. 先进的分布式训练框架:基于 JAX、Rust 和 Kubernetes 的技术栈,不仅确保了高性能,还能快速适配和训练新的模型架构。
  2. 卓越的可靠性和可用性:通过自研的训练协调器,系统能够智能地检测并隔离故障节点,大幅降低训练任务中断的风险。
  3. 高效的存储与数据处理:在检查点存储、数据加载和训练作业重启等环节都进行了深度优化,将训练过程中的停机时间降至最低。

Meta Reasearch Super Compute

另一个典型案例是 Meta 的 Research Super Compute (RSC) 超算集群,在这上面训练了Llama3.2,有一份92页的技术报告,RSC的相关Talk,以及里面用到的MAST论文调度器:

算力规模

已升级至 16,000 张 H100 GPU,算力获得质的飞跃。每个服务器配备了 8 块 GPU 和 2 块 CPU。在服务器内部,八块 GPU 通过 NVLink 连接。

网络互联

采用双网络方案:

  • NVIDIA Quantum InfiniBand,带宽高达 1600 Gb/s,RoCE(RDMA over Converged Ethernet)作为补充互联方案。

网络拓扑

  • 底层网络(第一个层):每个机架(rack)包含 16 块 GPU,分散在两个服务器上,并通过一个 Minipack2 顶层网络(ToR)交换机连接。
  • 中间网络(第二层):192 个这样的机架通过 Cluster Switches 连接,形成一个包含 3,072 块 GPU 的 Pod。这种设计确保了从任何两个 GPU 之间的通信都有满速带宽,没有过度订阅。
  • 顶层网络(第三层):八个这样的 Pod 通过 Aggregation Switches 连接,形成一个包含 24,000 块 GPU 的集群。然而,顶层网络的连接没有保持满速带宽,而是存在过度订阅比例为 1:7。

负载均衡

  • Collective library 将 16 个网络流中的两个 GPU 之间的数据传输从一个流变为 16 个流。
  • Enhanced-ECMP(E-ECMP)协议 通过在 RoCE(Rdma over Converged Ethernet)报头中添加额外的字段,进行 hash 计算,从而有效地在不同网络路径上平衡 16 个流。

拥塞控制

使用深度缓冲区(deep-buffer switches)来解决在 Spine(Gangidi et al., 2024)中由于集体通信模式引起的暂时拥堵和缓冲问题。

存储系统

采用自研的 Tectonic 文件系统,通过 FUSE 提供标准的 Linux 文件系统接口,确保高效的数据访问。

  • 存储容量:240 PB,基于 7,500 台 SSD servers
  • 支持的最大吞吐量:7 TB/s
  • 支持的可持续吞吐量:2 TB/s
  • 检查点写入:非常时断时续,导致存储网络饱和
  • 检查点的目标:因为 checkpoint 非常大,最小化 GPU 停顿时间,加快检查点频率也变得非常重要

总结

从这些实践可以看出,现代 AI 基础设施主要围绕三大核心要素展开:

  • 计算能力(以 GPU 为核心)
  • 网络互联(RoCE 或 InfiniBand)
  • 存储系统

而在上层的编排调度领域,系统的容错能力和可靠性则成为关键考量因素。

GPU 架构与互联

在当前AI训练领域,主流的GPU型号主要是NVIDIA的A100、H100和H200系列,它们按照发布时间依次提供了更强大的算力和更优化的架构设计。关于GPU的详细架构,特别是其拓扑结构,可以参考这篇深度解析文章

GPU互联拓扑

GPU之间的互联拓扑结构主要取决于不同总线间的传输特性,GPU之间可以通过NVIDIA专有的NVLink高速互联技术直接通信。在现代GPU集群中,主要有以下几种互联方式:

  1. NVSwitch架构:通过NVIDIA的交换架构实现所有GPU之间的全互联
  2. 走网卡,如果卡之间没有NVSwitch的话,可以绕过CPU走网卡:
    1
    GPU0 -> PCIe -> IB(InfiniBand) -> PCIe -> GPU1
    这种通信模式由NCCL(NVIDIA Collective Communications Library)负责协调和优化。

GPU分配策略

NVIDIA开源的go-gpuallocator库提供了一系列基于拓扑关系的GPU分配策略。例如,其中的NewStaticDGX1Policy专门针对DGX-1标准配置优化。考虑到单机环境下GPU组合的可能性有限,这种基于静态规则的分配策略已经能够很好地满足需求。

这些分配策略的核心目标是最小化跨总线和跨NUMA节点的通信开销,确保GPU间通信尽可能利用最高带宽的数据通路,从而提供最优的训练性能。

跨节点的通信

在分布式训练场景下,跨节点通信需要经过更长的数据传输路径:

1
GPU -> NIC -> 叶层交换机 -> 核心交换机 -> NIC -> GPU

这种通信模式面临两个主要的优化方向:

  1. 本地化优化:尽可能将相关联的GPU任务分配在物理位置相近的节点上,以减少网络延迟。

  2. 负载均衡:避免将所有任务集中在同一交换机下,防止出现网络拥塞。过度集中可能导致局部带宽饱和,反而降低整体训练效率。

这种权衡本质上是一个网络流优化问题。通过图论中的网络流算法,可以在通信延迟和带宽利用率之间找到最优平衡点,从而实现更高效的跨节点通信。

一个分布式训练的带宽瓶颈来源于带宽最低的那条路径。

利用 Kubernetes Pod 亲和性优化网络拓扑

在 Kubernetes 环境下,我们可以通过 Pod 亲和性(Affinity)和规则来优化 GPU 任务的分配。主要可以从以下几个方面入手:

拓扑感知调度:使用 topologyKey 确保相关联的 Pod 被调度到网络拓扑上接近的节点:
例如同一个分布式训练任务(training-group = group1)尽让分配在一个机架上,同交换机,同核心交换机也是类似的。

1
2
3
4
5
6
7
8
9
10
11
12
affinity:
podAffinity:
preferredDuringSchedulingIgnoredDuringExecution:
- weight: 50
podAffinityTerm:
labelSelector:
matchExpressions:
- key: training-group
operator: In
values:
- group1
topologyKey: topology.kubernetes.io/rack # 同机架优先

这种方案的优势在于:

  • 配置简单,易于理解和维护
  • 充分利用 Kubernetes 原生能力,无需额外组件
  • 可以根据实际需求灵活调整权重和策略

存储系统

AI训练中的存储系统面临着两个主要挑战:

1. 海量小文件问题

AI训练数据集通常包含大量的小文件,这对传统文件系统的性能和管理造成了巨大压力。一些现代分布式文件系统提供了很好的解决方案,例如 Meta 的 Tectonic 和与其架构类似的 JuiceFS,它们采用了以下优化方案:

元数据管理优化

  • 使用元数据库管理文件结构,将 ls 命令转化为简单的字符串前缀匹配操作
  • 避免了传统 Linux 文件系统依赖 inode 管理的方式
  • 解决了 inode 臃肿问题(在传统系统中,一个 inode 的大小可能与文件本身相当)

2. Checkpoint 存储挑战

分布式训练中的 checkpoint 文件体积巨大,这在大语言模型训练中尤为明显:

  • 以 LLaMA-2-70B 为例,单个完整的 checkpoint 就需要 140GB 存储空间(FP16格式)
  • 训练过程中需要定期保存 checkpoint,累积存储需求可能达到 TB 甚至 PB 级别
  • 需要存储系统能够提供高带宽和低延迟的读写性能,同时保证数据的可靠性

这些挑战要求存储系统具备:

  • 强大的扩展性
  • 高效的数据压缩能力
  • 智能的数据分层存储机制
  • 可靠的数据备份和恢复能力

训练的稳定性

在大规模 AI 训练中,硬件故障是一个常见问题。特别是新型号显卡往往会有较高的故障率,再加上传统的硬件错误,这些都可能导致训练中断。因此,快速识别错误并恢复训练成为了一个关键挑战。目前主流的解决方案主要有以下两种:

基于 torchrun 的弹性训练

torchrun 提供了两种容错机制:简单重试和弹性训练。

  1. 简单重试模式
    通过 --max-restarts 参数配置重试次数:

    1
    2
    3
    4
    5
    6
    7
    8
    torchrun \
    --nnodes=$NUM_NODES \
    --nproc-per-node=$NUM_TRAINERS \
    --max-restarts=3 \
    --rdzv-id=$JOB_ID \
    --rdzv-backend=c10d \
    --rdzv-endpoint=$HOST_NODE_ADDR \
    YOUR_TRAINING_SCRIPT.py [script args...]
  2. 弹性训练模式
    通过设置 nnodes 的范围来支持动态节点数:

    1
    2
    3
    4
    5
    6
    7
    8
    torchrun \
    --nnodes=1:4 \ # 支持1-4个节点的动态伸缩
    --nproc-per-node=$NUM_TRAINERS \
    --max-restarts=3 \
    --rdzv-id=$JOB_ID \
    --rdzv-backend=c10d \
    --rdzv-endpoint=$HOST_NODE_ADDR \
    YOUR_TRAINING_SCRIPT.py [script args...]

弹性训练模式需要配置服务发现机制,默认使用 c10d 作为内置的节点发现服务,也支持使用 etcd 等外部服务。

当节点发生变化时,系统会自动处理以下场景:

  • 节点离开(缩容):系统通知 agent,停止现有 workers,重新组建 WorkerGroup,使用新的 RANK 和 WORLD_SIZE 启动所有 workers
  • 节点加入(扩容):接纳新节点,按照相同流程重组 WorkerGroup

基于 DeepSpeed 的弹性训练

DeepSpeed 提供了更细粒度的弹性训练配置:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
{
"elasticity": {
"enabled": true,
"max_train_batch_size": "seqlen",
"micro_batch_sizes": 8,
"min_gpus": 1024,
"max_gpus": "fixed_linear",
"min_time": "seqlen",
"version": 8,
"ignore_non_elastic_batch_info": 1024,
"num_gpus_per_node": "fixed_linear",
"model_parallel_size": MODEL_PARALLEL_SIZE
}
}

DeepSpeed 的特点是:

  • 支持动态调整 batch size
  • 以 GPU 为粒度进行弹性伸缩(而不是节点级别)
  • 提供更丰富的训练参数配置

弹性训练控制器

要实现完整的弹性训练支持,控制器需要:

  1. 依赖服务发现机制进行节点注册和健康检查
  2. 动态调整弹性策略(如 min_nodes、max_nodes 等参数)

对于简单的降级场景,通过静态配置即可实现:

  • 将 max_nodes 设置为总资源规格
  • 将 min_nodes 设置为最小运行要求(如设置为 1:4 表示支持 1-4 张显卡的动态伸缩)

节点的问题发现

在大规模语言模型(LLM)预训练过程中,常见的硬件异常包括:

  1. GPU ECC 错误:当 GPU 发生不可纠正的显存 ECC(Error Correcting Code)错误时,通常需要重置 GPU 或重启节点来清除这个错误。

  2. Infiniband(IB)/NCCL 问题:这类问题通常源于硬件故障,如网卡损坏或网络抖动,可能导致训练速度下降或任务异常中断。

  3. 任务挂起(Hang):通常与 IB/NCCL 问题相关,需要人工检测和处理。

  4. GPU 掉卡:此时一般会触发 CUDA 错误或程序异常退出,可能需要重置 GPU 或重启节点来解决。

  5. 机器异常:包括 GPU 之外的硬件异常,如硬盘、CPU 等,甚至整机故障,可能需要更换硬件或进行系统维护。

  6. 机器配置异常:例如,某台机器意外启用了 MIG(多实例 GPU),可能影响训练任务的正常运行。

  7. 集群维护:集群中的其他任务或系统维护、升级,可能需要暂停当前训练任务。

可以使用node-promblem-detector
node-problem-detector 是一个用于在集群管理栈的上游层次中使各个节点问题可见的守护进程。它在每个节点上运行,检测节点问题并将其报告给 apiserver。

监控和容错是一个比较难的问题,需要结合硬件和软件的特性,以及业务需求,进行综合考量。
特别是万卡集群,MFU 只有 50%左右。

在训练 OPT-175B 模型的过程中,Meta团队使用了 992 个 80GB 的 A100 GPU,每个 GPU 实现了约 147 TFLOP/s 的性能,对应的机器浮点利用率(MFU)约为 47%(147/312)。

为了应对可能的硬件故障,团队额外准备了 12 台备用机器,以便在出现问题时进行替换。在训练期间,平均每天约有 2 台机器发生故障,即每台机器每天发生故障的概率约为 1.61%。

整个训练过程持续了约 2 个多月,包括从 2021 年 10 月 20 日到 2021 年 11 月 11 日的测试阶段,以及从 2021 年 11 月 11 日到 2022 年 1 月 6 日的正式训练阶段,正式训练约 57 天。

根据预估,实际训练时间应为约 25 天,但由于各种问题,实际有效训练时间仅占总时间的约 44%。在前期,由于各种问题,团队至少手动重启了 35 次任务。为减少人工干预,后续引入了自动重启机制,但由于硬件故障,仍触发了超过 70 次重启,平均每天需要重启一次任务。

这些经验表明,在大规模模型训练中,硬件故障和其他问题会显著影响训练效率。为此,团队采取了多种措施来应对这些挑战,包括准备备用硬件、引入自动重启机制等,以确保训练过程的顺利进行。

这个问题在用新的卡的时候会有更多问题。

总结

To train our largest Llama 3 models, we combined three types of parallelization: data parallelization, model parallelization, and pipeline parallelization. Our most efficient implementation achieves a compute utilization of over 400 TFLOPS per GPU when trained on 16K GPUs simultaneously. We performed training runs on two custom-built 24K GPU clusters. To maximize GPU uptime, we developed an advanced new training stack that automates error detection, handling, and maintenance. We also greatly improved our hardware reliability and detection mechanisms for silent data corruption, and we developed new scalable storage systems that reduce overheads of checkpointing and rollback. Those improvements resulted in an overall effective training time of more than 95%. Combined, these improvements increased the efficiency of Llama 3 training by ~three times compared to Llama 2.

LLAMA3 的技术博客揭示了许多令人振奋的优化成果,这些优化背后蕴含着大量值得深入研究的技术细节。虽然我们可能难以直接接触如此大规模的训练集群及其面临的挑战,但这些技术进展仍然为整个 AI 基础设施领域提供了宝贵的参考和启发。

CacheBlend的主要目标是在一些RAG场景下,多个文档Chunk之间不能像多轮对话那样构成Prefix Cache。

对于位置编码来说,RoPE得到的注意力是绝对位置无关的,所以两个下三角放到对应的位置就可以。但是,两个Chunk之间的交叉注意力机制实际上是空的,如果单纯这样使用会丢失交叉注意力的信息。

论文中提到了一个例子。

在比较两个球员进球数的场景中,就损失了球员之间的交叉信息。

为了弥补下面那个空的矩形的注意力,同时尽量节省计算量,根据一些insights提出了一种选择性计算的方法。

注意力矩阵是稀疏的,其中差异较大(颜色较深)的部分对最终结果贡献较大。另外,对于多层的Transformer,前几层的注意力差异往往会一直保持下去。因此,CacheBlend会完全重算第一层的注意力,然后标记差值较大的token,用它们的注意力来代表两个Chunk之间的完全注意力矩阵。当然,这还是有损的。

其中的权衡在于选择多少比例的token来代表两个Chunk之间的注意力矩阵,这个比例是可以调整的。

一个平衡点在于从异构存储中加载KVCache和重计算KVCache的时间。如果选择性计算token的时间大于加载的时间,则可以将这个过程流水线化。在计算当前层时,加载下一层的KVCache。

比率r%刚好满足计算时间和加载时间相等。

尽管第一层高差异化的token的注意力比较重要,但只看第一层似乎不太合理。CacheBlend会用一个比r%更大的范围选择token,然后每一层逐渐递减r%,以容许更多的可能性。只从第一层选择r%可能会丢失一些重要信息。

总体来说,利用交叉注意力的稀疏性选择性重计算Chunk之间的交叉注意力,并平衡加载和计算,使得计算过程和加载是并行的,时间损耗无影响。在精度上,选择一个较大的r%再每层缩小到理想r%来满足准确性。最终实验效果显示,与完全重算的交叉注意力相比,结果还是很接近的。

vLLM的PD分离

vLLM的PD分离是指vLLM的Prefill和Decode分离到不同的实例中执行。

新增配置

新增 KVTransferConfig 配置,决定了实例的类型。如果是 prefill 则 role 为 producer,如果是 decode 则 role 为 consumer,并且要设定传输的方法。

  • is_kv_transfer_instance 判断是否是 PD 分离的实例。

代码实现

./vllm/worker/model_runner.py 中:

  1. 在计算之前执行 need_rev_kv,检查是否是 consumer,且当前 run 是不是 prefill。然后调用 get_kv_transfer_group().recv_kv_caches_and_hidden_states
  2. 在计算之后执行 need_send_kv,检查是否开启配置 producer,并且当前 run 是不是 prefill(对于以前未分离的结构来说,decode实例要经历prefill阶段,
    但是prefill已经被prefill实例做掉了,所以要等着接受prefill的KVCache,不需要重复计算prefill了)。然后调用 get_kv_transfer_group().send_kv_caches_and_hidden_states

KVTransfer 实例

get_kv_transfer_group 会返回一个 KVTransfer 的实例,是一个全局实例,初始化方式如下。其中的 rank 0 代表 prefill,rank 1 代表 decode。

1
2
3
4
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config)

Transfer 实现

Transfer 的实现在 vllm/distributed/kv_transfer,这种解耦的设计是为了对接多种实现,比如 Mooncake 的开源 TransferEngine。Transfer 内部会调用 connector 的 send 和 recv 方法,这个方法是一个抽象方法,需要子类实现。目前有两种实现:Mooncake 的 transfer 和 PyNccl 的 transfer。

1
2
3
4
5
6
7
8
9
10
11
12
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
KVConnectorFactory.register_connector(
"PyNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")

KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")

Connector 依赖

Connector 依赖 kv_pipe 的实现。

  • from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator 用来实现 PyNccl 的 kv_pipe。其中的 Send 和 Recv 会依赖 NCCL 的集合通信实现。
  • 如果是 Mooncake pipe,import mooncake_vllm_adaptor as mva 这个模块,基于 ZeroMQ 的通信,通过 pickle 去序列化 tensor。
1
2
3
4
def _send_impl(self, tensor: torch.Tensor) -> None:
"""Implement the tensor sending logic."""
value_bytes = pickle.dumps(tensor)
self.transfer_engine.send_bytes(value_bytes)

KV Lookup Buffer 实现

另外还有一种 kv_lookup_buffer 的实现,抽象的接口是非阻塞 insert 和阻塞的 drop_select

  • Producer 调用 insert,consumer 调用 drop_select。目前 SimpleBuffer 也是基于 Pipe 去实现的,insert 变 send,drop_select 变 recv。
  • 如果有一些中心化的 KVCacheBuffer 的话可能可以不用基于 Pipe 的实现。比如可以基于分布式的 LMCache?

Prefill 启动

vLLM 目前的实现是基于 connector 的。Prefill 的启动时通过设置 max_token 为 1 来执行,当生成了 bonus token 以后转而去调用 decode 的实例。

DeepSeek V3分析

During the pre-training stage, training DeepSeek-V3 on each trillion tokens requires only 180K
H800 GPU hours, i.e., 3.7 days on our cluster with 2048 H800 GPUs.

DeepSeek实现了非常便宜的训练成本,是一个700B的MoE模型。

基础设施

  • 计算集群:在配备 2048 个 NVIDIA H800 GPU 的集群上训练,节点内通过 NVLink 和 NVSwitch 连接,节点间使用 InfiniBand 互连。
  • 训练框架:基于 HAI - LLM 框架,采用 16 路管道并行(PP)、64 路专家并行(EP)和 ZeRO - 1 数据并行(DP)。设计 DualPipe 算法减少管道气泡并重叠计算与通信,开发高效的跨节点全对全通信内核,优化内存占用,无需使用昂贵的张量并行(TP)。
  • FP8 训练:提出 FP8 混合精度训练框架,对多数计算密集型操作采用 FP8 精度,对部分关键操作保留原始精度,引入细粒度量化策略、提高累积精度、采用 E4M3 格式及在线量化,还降低了内存和通信开销。
  • 推理与部署:部署在 H800 集群上,通过分离预填充和解码阶段确保服务水平目标(SLO)和高吞吐量。预填充阶段最小部署单元为 4 节点 32 个 GPU,采用特定并行策略和冗余专家策略确保负载均衡;解码阶段最小部署单元为 40 节点 320 个 GPU,采用相应并行策略和冗余专家策略,并探索动态冗余策略。
  • 硬件设计建议:针对通信硬件,期望未来硬件能卸载通信任务,统一网络接口;针对计算硬件,建议提高 FP8 GEMM 累积精度、支持细粒度量化、在线量化和转置 GEMM 操作。

并行度配置

prefill阶段,attention模块采用4路张量并行+8路数据并行,moe模块采用32路专家并行。这样并行的目的是在满足首token时延的要求下,最大化系统吞吐(和训练任务类似)。

decode阶段,DeepSeek-V3采取320路专家并行(256个小专家+64个热点专家),有效降低解码时延,并缓解负载不均衡的问题。

DeepSeek-V3 采用了多种并行策略,包括 16 路流水线并行(PP),这一策略有助于提高训练效率,加快模型的处理速度。同时,还应用了 64 路专家并行(EP),且在 8 个节点上进行,能够充分发挥多节点的计算优势。此外,ZeRO-1 数据并行(DP)也被运用到训练中,进一步提升了模型的训练效果。

ZeRO-1 优化器被切分到不同的GPU上。 《大模型动力引擎——PyTorch性能与显存优化手册》有提到这个优化,总结的很好。

假设我们有N=64块GPU进行数据并行训练,在ZeRO-1阶段,优化器的状态量首先被分散存储到所有GPU中,此时单张GPU上的内存使用量骤降到(4+4+8/64)*7.5=60.9GB。ZeRO-2阶段进一步地将模型的梯度也分散存储,此时单张GPU上的内存使用量便是(4+(4+8)/64)7.5=31.4GB。而ZeRO-3阶段将模型的参数也分散存储到N个节点,此时每张GPU的内存消耗只有(4+4+8)/647.5=1.875GB。从单卡需要120GB到仅需不到2GB内存,这个优化效果是不是有点惊艳?不过需要再次强调的是,这样巨大的显存优化是有代价的,显存切分的程度越高,相应的通信开销也会增加。因此,根据实际需求合理地进行显存切分是非常重要的。

MLA

采用类似 LoRA 的架构,借助一个低秩矩阵 “compressed laten vector”,kvcache 仅需对低秩的 key-value 对以及附带旋转位置编码(RoPE)的 key 进行缓存。

MoE

除了针对 Top k、routed experts 运用添加了激活函数的加权求和方式外,还额外引入了 shared experts。在 gate 的激活函数里增添一个 bias,以此来化解 balance 失衡的难题,在训练阶段,通过调节这个 bias 对 balance 状况予以奖惩,这一调节过程被称作 bias update speed。
就一个 batch、一个序列而言,每个 token 倘若倾向于特定的一些 expert,那么未被选中的 expert 实际上仅相当于训练了极小的 batch size,或者极短的序列,正因如此,才有了这样一种策略,用以平衡 expert 的 batch size 以及序列当中的 token 数量,毕竟序列通常都很长。
DeepSeek-V3 着重凭借辅助损失策略达成负载均衡,与此同时,引入互为补充的序列平衡损失,以防单个序列内部出现极度不平衡的现象。

MTP

类似于 speculative decoding,它同样会计算多个 token,不过具体方式存在一定差异。其 embedding 与 output head 是共用的,这一点和 sd 里的 Medusa 有所不同,Medusa 是由多个头来推测不同位置,而 MTP 则是依靠多个相同的头(只是 attention 有别)去推断不同位置。

MTP 的核心目的在于提升主模型的性能表现,在推理阶段能够直接将 MTP 模块舍去,主模型依旧可以独自正常运作。不仅如此,MTP 模块还能够应用于推测解码环节,以此进一步优化生成延迟问题,让整个流程更加高效流畅。

DualPipe

双流水线pipeline的优化。它实现了前向和后向过程中计算与通信阶段的重叠,有效解决了跨节点专家并行带来的通信负载问题。

FP8

能够不依赖硬件能力做FP8精度的训练,这个点是非常厉害的。

首先,为提高模型训练速度,大部分核心计算操作(尤其是 GEMM 运算),均采用 FP8 精度实现。这些 GEMM 运算接收 FP8 格式的张量输入,输出 BF16 或 FP32 格式的结果。如图6所示,线性运算相关的三个 GEMM 操作,包括 Fprop(前向传播)、Dgrad(激活值反向传播)和 Wgrad(权重反向传播),均采用 FP8 执行。这种设计策略理论上将计算速度提升至原有 BF16 方法的两倍。同时,FP8 格式的 Wgrad GEMM 使得激活值能够以 FP8 格式存储用于反向传播,显著降低了内存使用量。

LoRA一般的设定是认为微调任务应该只需要在一个更小的子空间去训练即可不需要复用基座模型的大空间,从而实现低成本的微调。
LoRA的前提是问题是不是在子空间能得到最优解。在线性回归 y=W x 中,如果最优 W * 是高秩的,那么对 W 施加低秩假设永远不会导致最优解,无论使用什么优化器。

Gradient Low-Rank Projection (GaLore) 允许全参数学习,但比 LoRA 等常见的低秩自适应方法更具内存效率。

使用单个批处理大小从头开始预训练 LLaMA7B 模型至少需要 58 GB 内存(14GB 用于可训练参数,42GB 用于 Adam 优化器状态和权重梯度,2GB 用于激活函数)。这使得训练在消费级 GPU 上不可行,例如具有 24GB 内存的 NVIDIA RTX 4090。

他证明梯度可能具有低秩结构,如果我们能够在优化器状态中保留梯度的一个小 “核心” 的梯度统计信息,而不是完整的梯度本身,那么内存消耗就可以大幅降低。这就引出了 GaLore 策略。
他的关键思想是利用权重矩阵 W 的梯度 G 上做LoRA,而不是试图将权重矩阵本身近似为低秩。他的核心逻辑用Torch写出来如下:

1
2
3
4
5
6
7
8
9
for weight in model.parameters():
grad = weight.grad
# original space -> compact space
lor_grad = project(grad)
# update by Adam, Adafactor, etc.
lor_update = update(lor_grad)
# compact space -> original space
update = project_back(lor_update)
weight.data += update

Sequence Parallelism

假设有4个chunk,切四份。

初始化状态,每个GPU都有自己的 Qn Kn,可以计算出对应的注意力矩阵,然后类似AllReduce的方式传递切分的K。

第一步环形传递K,然后再算一次注意力矩阵。

第二步环形传递K,然后再算一次注意力矩阵。

第三步全部传完,得到完整的Sn。

然后 Sn 和 Vn 的计算也是类似的,经过三次环形传递Vn,然后每一份可以单独和小s的那一份做乘法。

所以K和V的传播都要经历 3 次(N-1)的集合通信。

LLM推理的核心在于KVCache的调度。

  1. 尽可能多次重用KV缓存,以减少所需的计算资源;
  2. 每批次最大化token数量,从而改善Model FLOPs Utilization (MFU)。

如果从远程内存获取KVCache,会增加数据传输时间,从而延长TTFT(Time To First Token)。因此,当本地KVCache的增量计算时间少于传输时间时,可以复用本地的KVCache,即使它不是最匹配的。而增大batch意味着系统处理的大批量数据,导致TBT(Token Between Token)延长,可以将负载均衡到低负载的Decode Instance。

架构

Mooncake的架构图主要分为三个部分:Prefill Instance,Decode Instance,Conductor。

  1. Cache-aware Prefill Scheduler:负责调度Request到Prefill Instance,主要考虑load和KVCache的复用率。
  2. KVCache Balance Scheduler:负责从匹配最多前缀的P2P传输KVCache到Instance(Decode和Prefill)。
  3. Load-balance Decoding Scheduler:负责负载均衡调度Request到Decode Instance。

Prefill Instance要满足TTFT SLO,最小化MFU,保证KVCache < DRAM。
Decode Instance要满足TBT SLO,保证KVCache < VRAM。
Inter-Node Transfer基于RDMA的P2P,这也是一个较大的开销。

Mooncake的方法总结如下:

  1. 转移可重用的KVCache,将尽可能多的可重用KVCache转移至Prefill Instance,减少增量计算的时间。
  2. Prefill Instance Pool分层并分块处理,并持续输出给对应的Decode Instance。分层指的是Layer-wise KVCache的异步保存,分块指的是Chunked Pipeline Parallelism。
  3. 独立的Decode Instance Pool加载KVCache,通过连续批处理解码tokens。

Mooncake的主要特点是将prefill和decode拆开,并调度KVCache块。

Reject Policy:如果一个请求不能在服务水平内完成其完整的执行,那么就应该尽早拒绝这个请求,基于这个理念需要设计一些拒绝策略,被称作Overloaded-Scheduling。

KVCache的复制

KVCache的调度主要是利用KV Cache(VRAM,DRAM),利用RDMA带宽。

下图是一个Prefill和Decode分离的计算过程。

如果了解vLLM中的prefill和decode以及管理block的方法,这个图其实很简单。

首先通过Hash判断block是否相同,例如很多系统提示词都是一样的,这部分的复用率很高。

Prefill Instance已经有了ABCDE(这里是一个P2P的过程,但我看开源的版本有个KVCache Store的WIP,不知道后面会不会有一个中心化的KVCache Store的组件)。然后计算了FGHI,存入了KV Cache(在CPU mem上),论文里面提到这个prefill在超过prefill_chunk tokens数量会做chunked prefill。

接着通过Messenger以RDMA的方式发给Decode Instance。Decode Instance基于ABCDEFGHI的prompt对应的KV Cache开始decode的过程。

根据请求模式,它可以使用缓存淘汰算法,如LRU(最近最少使用),LFU(最不常用的),或基于请求特征的算法。这些KVCache块在CPU和GPU之间的传输由一个独立的(GPUDirect)RDMA组件Messenger处理。这种架构还使我们能够为外部用户提供KVCache缓存API,从而实现更高的缓存重用性。

Mooncake已经开源了他的代码,目前只有Transfer Engine。

基于这个架构,Conductor的主要功能是:

  1. 根据当前的KVCache分布和工作负载,分发请求。
  2. 复制或交换某些KVCache块,以便于未来推理。如果某些块的数据在未来被频繁访问,Conductor可能会将其复制到其他节点上,从而提高推理效率。

Mooncake的一个争论点是,是否需要在存在chunked prefill的情况下采用这种分离架构。毕竟,chunked prefill可以填补许多pipeline中的气泡,并且能让prefill和decode节点相对统一,只需要关心一种instance,对于scheduler比较友好。

  1. 不分离的优点:

    • 所有节点被视为平等,使调度更简单;
    • 将chunked prefill内联到解码批处理中可以提高解码批次的计算强度,从而提高MFU。
  2. 分离的优点:

    • 长文本的跨节点并行和VRAM的节省。长文本输入是输出的10倍甚至100倍,对于相同的模型来说,prefill需要多节点配置才能满足显存需求。prefill阶段可以进行layer-wise prefill,每次保存大量KVCache,而decode阶段每次只需保存一个KVCache。因此,prefill阶段可以通过layer-wise prefill来减少VRAM占用。

是这么理解么?异步的Store KVCache可以节省保存的时间,但这是Prefill和Decode分离的理由么?Decode阶段应该是不保存KVCache?

然而,经过仔细考虑,论文决定保持Mooncake的分离架构。只有在请求的prefill可以不进行chunking且不影响TBT SLO的情况下,才会将其内联到解码批次中。我们这样决定的主要原因有两个:

  1. Prefill节点需要不同的跨节点并行设置来处理长上下文 (§5.1)。

  2. 这为节省VRAM提供了独特的机会 (§5.2)。

  3. 大模型需要部署在多机上,进行TP后,每一层都需要进行一次基于RDMA的reduce,这个过程开销巨大。虽然有一些Sequence Parallelism的方法,但效果并不理想,且无法避免跨节点通信。而Mooncake采用的是CPP(Chunked Parallelism Pipeline),将序列按prefill_chunk大小切分,交给prefill pool的不同节点,这些节点被切分成更小的节点池(pipelined prefill node group)。

疑问:他们是pipe的不同部分?还是完全对等的?目前感觉是PP是分layer做Pipe,而CPP是sequence分chunked做pipe。24引用的论文中提到的Sequence Pipeline可以再看一下,应该对理解这个有帮助。

  1. Layer-wise prefill,这有点像airllm项目,在计算过程中动态加载KVCache。在每次注意力计算时,KVCache是异步加载的,计算当前层时可以异步加载下一层,并且当前层结束后可以异步保存当前层。论文中认为KVCache的保存时间可以被完全省略(相较于加载计算保存的线性循环)。这样也可以降低VRAM的占用。

调度算法

  1. 选择Prefill实例

    • 如果Prefill节点上缓存了足够的前缀(由kvcache_balancing_threshold控制),则选择预估TTFT最小的实例:TTFT = min(T_queue + T_prefill)
    • 如果Prefill节点上缓存不足,则选择TTFT = min(T_queue + T_prefill + T_transfer)最小的实例,其中T_transfer指的是有最长匹配的KVCache的实例拷贝到当前实例的预估时间。
  2. 选择Decode实例

    • 通过负载均衡的方式预估TBT。
    • 如果TBT和TTFT不满足SLO,则拒绝请求,并触发KVCache的传输。
  3. 预测模型

    • 预估模型用于预测传输时间和决策传输。
    • 数据传输时间难以预测,因为它不仅取决于数据大小,还依赖于当前网络状态,特别是当发送节点处于拥塞状态时。
  4. KVCache复制

    • 热门的KVCache块需要被复制以确保高可用性。
  5. 调度器目标

    • 保证低Cache负载和高Cache命中率。
  6. 高负载情况下的策略

    • 请求可能不会被直接发送给缓存最长前缀的实例,而是转发给备选实例。备选实例会主动从缓存持有者处检索KV缓存并存储本地。
    • 当最佳的远程前缀匹配长度不超过当前本地可重用前缀的阈值时,系统优先使用本地缓存,而不是从远程实例获取令牌。

这些策略不仅减少了请求的Prefill时间,还自动复制热点缓存,使其在多台机器上更广泛地分布。

拒绝策略

论文提到了一种基于预测的拒绝策略。Prefill和Decode的负载节奏是相反的,可能在Decode负载高时,Prefill负载较低。此时如果拒绝请求,会导致Decode负载下降,而Prefill完成后Decode负载又会升高,进而再次拒绝请求。引入预测拒绝策略后,可以使Prefill过程更加平滑,减少频繁拒绝请求的情况,从而减小负载节奏的波动。

参考Pytorch版本的llama-from-scratch。原文中的RmsNorm的平均值多算了一个维度,这里改成了正确的版本。

首先需要下载TinyShakespeare数据集,这是一个莎士比亚文字数据集。本文档将大致遵循论文的布局,并跳过一些明显的步骤,比如设置虚拟环境和安装依赖项。

我们最终将实现的内容预览:

1
2
3
4
5
6
7
8
9
10
11
12
13
14

println!(generate(llama, MASTER_CONFIG, 500, device)[0])

ZELBETH:
Sey solmenter! 'tis tonguerered if berryishdd, and What his stabe, you, and, but all I pilJefals, mode with,
Vurint as steolated have loven OlD the queen'd refore
Are been, good plmp:

Proforne, wift'es swleen, was no bunderes'd a a quain beath!
Tybell is my gateer stalk smen'd as be matious dazest brink thou
lord
Enves were cIUll, afe and whwas seath This a is, an tale hoice his his onety Meall-tearn not murkawn, fase bettizen'd her,
To belacquesterer? baxewed wupl usweggs yet tall
An

实现过程中可能涉及一些 Rust 的使用方法,与 Python 有所不同,这里不做过多说明,具体的 Rust 语法可以参考其他文档。

迭代工作:从小模块开始,保持确定性,然后逐步构建

  1. 创建所有需要的辅助函数,以便定量测试模型(数据拆分、训练、绘制损失)。
  2. 从论文中挑选出不同的组件,然后逐一实现,边训练边评估。

确保你的层按预期工作

  1. 经常使用 .shape()assert 是你的朋友。
  2. 先在不进行矩阵乘法的情况下计算结果,然后使用 candle 函数使其高效。
  3. 有一个测试来确保你的层是正确的。例如,RoPE 嵌入有一个特定的属性,你可以测试它。对于 Transformer,你可以通过查看注意力矩阵来测试注意力是否正常工作。
  4. 在各种批次、序列和嵌入大小上测试你的层。即使它适用于一种大小,它可能不适用于其他大小,这将在推理时导致问题。

关于 Llama

Llama 是一种基于 Transformer 的语言建模模型。它是一个自回归模型,也称为 CausalModel,模型会将输出中的 token 加入到输入中,不断迭代推理,直到超过上下文长度或遇到停止符。Meta AI 开源了 Llama,并明确表示他们的目标是使模型在推理时更高效,而不是优化训练成本。

接下来,我们将加载库并开始实现。

1
2
3
4
5
6
7
8
9
10
11
12
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::ops::softmax;
use candle_nn::{
embedding, linear, loss, AdamW, Embedding, Init, Linear, Module, Optimizer, ParamsAdamW,
VarBuilder,
};

use core::f32;
use rand::Rng;
use std::collections::HashMap;
use std::fs;
use std::time;

设置数据集

虽然 Llama 在 1.4T 个标记上进行训练,但我们的数据集 TinyShakespeare,即莎士比亚所有作品的集合,大约只有 100 万个字符。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
use std::collections::HashMap;
use std::fs;

fn main() {
// Read the entire content of the file
let lines = fs::read_to_string("./input.txt")
.expect("Failed to read the file");

// Create a sorted set of unique characters
let mut vocab: Vec<char> = lines.chars().collect();
vocab.sort_unstable();
vocab.dedup();

// Create itos and stoi mappings
let itos: HashMap<usize, char> = vocab.iter().enumerate().map(|(i, &ch)| (i, ch)).collect();
let stoi: HashMap<char, usize> = vocab.iter().enumerate().map(|(i, &ch)| (ch, i)).collect();

// Print the first 30 characters of the file
println!("{}", &lines[..30.min(lines.len())]);
}
First Citizen:
Before we proce

他们使用了SentencePiece字节对编码分词器,但我们将只使用一个简单的字符级分词器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use std::collections::HashMap;
use std::fs;
struct Vocab {
itos: HashMap<u32, char>,
stoi: HashMap<char, u32>,
}

impl Vocab {
fn new(itos: HashMap<u32, char>, stoi: HashMap<char, u32>) -> Vocab {
Vocab { itos, stoi }
}
fn decode(&self, ids: &[u32]) -> String {
ids.iter().map(|&id| self.itos[&id]).collect()
}
fn encode(&self, text: &str) -> Vec<u32> {
text.chars().map(|ch| self.stoi[&ch]).collect()
}
fn len(&self) -> usize {
self.itos.len()
}
fn build(lines: &str) -> Self {
// Create a sorted set of unique characters
let mut vocab: Vec<char> = lines.chars().collect();
vocab.sort();
vocab.dedup();

// Create itos and stoi mappings
let itos: HashMap<u32, char> = vocab
.iter()
.enumerate()
.map(|(i, &ch)| (i as u32, ch))
.collect();
let stoi: HashMap<char, u32> = vocab
.iter()
.enumerate()
.map(|(i, &ch)| (ch, i as u32))
.collect();
Self { itos, stoi }
}
}

fn main() {
// Read the entire content of the file
let lines = fs::read_to_string("./input.txt").expect("Failed to read the file");

let vocab = Vocab::build(&lines);
println!("vocab size = {}", vocab.len());
println!("{}", vocab.decode(&vocab.encode("hello")));
}
vocab size = 65
hello

由于数据集较小,我们无需担心内存存储问题。

我们创建了一个 config 对象来存储基本的模型参数。这样可以提高代码的可读性,并且便于修改配置。Rust 是强类型语言,因此所有变量都有明确的类型。

1
2
3
let mut modeConfig = ModelConfig {
vocab_size: vocab.len(),
}
1
2
let dataset = Tensor::from_slice(&vocab.encode(&lines), (lines.len(),), &Device::Cpu).unwrap();
println!("{:?}", dataset.shape());
[1115394]

让我们创建一个方法 get_batches 来生成训练数据和目标的批次。我们将使用相同的方法来生成验证和测试数据,通过 split 参数来控制。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
fn get_batches(
dataset: &Tensor,
split: &str,
batch_size: usize,
context_length: usize,
) -> Result<(Tensor, Tensor)> {
let len_of_dataset = dataset.shape().dim(0).unwrap() as f32;
// 按照 0.8 0.1 0.1 的比例切分训练集, 验证集和测试集
let batch_data = match split {
"val" => &dataset.i((0.8 * len_of_dataset) as usize..(0.9 * len_of_dataset) as usize)?,
"test" => &dataset.i((0.9 * len_of_dataset) as usize..)?,
_ => &dataset.i(..(0.8 * len_of_dataset) as usize)?,
};
// 生成随机index
let mut rng = rand::thread_rng();
let data_len = batch_data.shape().dim(0)?;
let indices: Vec<usize> = (0..batch_size)
.map(|_| rng.gen_range(0..data_len - context_length - 1))
.collect();
let mut x_batches = Vec::with_capacity(batch_size);
let mut y_batches = Vec::with_capacity(batch_size);

for idx in indices {
let x = batch_data.i(idx..(idx + context_length))?;
// y 是 x 后面的一个字符
let y = batch_data.i((idx + 1)..(idx + context_length + 1))?;
x_batches.push(x);
y_batches.push(y);
}
// stack 和 cat 的区别是, stack 是在新的维度上堆叠, cat 是在已有的维度上堆叠
let x_tensor = Tensor::stack(&x_batches, 0)?;
let y_tensor = Tensor::stack(&y_batches, 0)?;
Ok((x_tensor, y_tensor))
}
// in fn main
modeConfig.context_length = 16;
modeConfig.batch_size = 8;
let batch = get_batches(
&dataset,
"train",
modeConfig.batch_size,
modeConfig.context_length,
)?;
println!(
"batch size {}, context_length {}",
batch.0.shape().dim(0)?,
batch.0.shape().dim(1)?
);
for i in 0..modeConfig.batch_size {
println!(
"{:?}, {:?}",
vocab.decode(&batch.0.i(i)?.to_vec1()?),
vocab.decode(&batch.1.i(i)?.to_vec1()?),
);
}

":\nBut, that I'll", "\nBut, that I'll "
"ng?\nWhy, then th", "g?\nWhy, then the"
"s so blind, but ", " so blind, but s"
"thy offices,\nSo ", "hy offices,\nSo r"
"ords, how plainl", "rds, how plainly"
"IET:\nHere's such", "ET:\nHere's such "
"wer\nTo take off ", "er\nTo take off s"
" hurry from the ", "hurry from the f"

实现论文有趣的一点在于,模型“工作”有两个方面:编译(你的张量是否在各层之间匹配)和训练(损失是否下降)。
我们还要定义评估模型的方法。我们希望在定义模型之前就这样做,因为我们希望在训练模型时使用它来评估模型的性能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
fn evaluate_loss(
model: &SimpleBrokenModel,
dataset: &Tensor,
vocab: &Vocab,
config: ModelConfig,
) -> Result<HashMap<String, f32>> {
let mut out = HashMap::new();
for split in ["train", "val"] {
let mut losses = Vec::new();
for _ in 0..10 {
let (xs, ys) = get_batches(&dataset, split, config.batch_size, config.context_length)?;
let (_, loss) = model.forward(&xs, Some(&ys))?;
let loss = loss.unwrap();
losses.push(loss.to_scalar::<f32>()?);
}
let avg_loss = losses.iter().sum::<f32>() / losses.len() as f32;
out.insert(split.to_owned(), avg_loss);
}
Ok(out)
}

设置一个可以工作的简单模型

这是一个带有嵌入的基本前馈神经网络。它是我们将要开始的基础模型,然后我们将逐步替换其部分内容,直到最终得到 Llama 论文中描述的模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
struct SimpleBrokenModel {
embedding: Embedding,
mlp: Sequential,
config: ModelConfig,
}

impl SimpleBrokenModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
// 潜入层
let embeds = self.embedding.forward(x)?;

// 线性和激活层
let logits = self.mlp.forward(&embeds)?;

// 如果提供了targets就计算loss,不然视为推理,计算logits就可以。
if let Some(targets) = targets {
// 负的似然函数
// -log(x) 越大,loss 越小
// y = [0, 0 , 0, 0, 1, ...,0,0]
// y' = [4, 5, 6, 7, 8, ...,11,12 ]
// 这个 cross_entropy 帮我们做了一个 log softmax
// y' = [0.1, 0.12, 0.13, 0.64, ..., 0,0]
// loss = -log(0.64)
// 当 -log(q) = 4.17 q = 0.015 大概 1/64,vocab_size = 65,所以基本是在瞎猜。
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}
// VarBuilder是用来构建参数的,我们目前不加载和保存模型参数,但是candle的用法必须基于这个。
// vb.pp 会在参数树中加入参数的前缀,这样可以方便的查看参数的结构。
fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
})
}
}
// in fn main
modeConfig.d_model = 128;
modeConfig.batch_size = 32;
let (xs, ys) = get_batches(
&dataset,
"train",
modeConfig.batch_size,
modeConfig.context_length,
)?;
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let model = SimpleBrokenModel::load(vb, modeConfig)?;
let (logits, loss) = model.forward(&xs, Some(&ys))?;
println!("{:?} {:?}", logits, loss);
let mut params_count: usize = 0;
for (name, var) in varmap.data().lock().unwrap().iter() {
println!("{}: {:?}", name, var.elem_count());
params_count += var.elem_count();
}
println!("params count: {}", params_count);
Tensor[dims 32, 16, 65; f32] Some(Tensor[5.266067; f32])
model.fc2.weight: 8320
model.embed_tokens.weight: 8320
model.fc1.bias: 128
model.fc2.bias: 65
model.fc1.weight: 16384
params count: 33217

在这一点上,我们必须开始关注张量的形状,并让矩阵的维度匹配。查看我们模型定义中的这一行:

1
2
3
4
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;

我们必须调整 logitstargets 张量的形状,以便在比较时它们的维度匹配。我们使用 reshape 方法来实现这一点。
() 参数的意思是“从其他维度推断这个维度”。所以,在这种情况下,我们是在说“将 logitstargets 重新调整为具有相同行数的形状,并使用所需的列数来实现这一点”。这是处理批量数据时的常见模式。

让我们训练我们的 SimpleBrokenModel 以确保梯度流动。在确认这一点之后,我们可以替换它的部分内容以匹配 Llama,再次训练并跟踪我们的进展。在这一点上,我开始记录我的训练运行日志,这样如果我搞砸了,我可以轻松地回到之前的运行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
fn train(
config: ModelConfig,
model: &SimpleBrokenModel,
opt: &mut AdamW,
dataset: &Tensor,
vocab: &Vocab,
) -> Result<()> {
let mut start_time = std::time::Instant::now();
for epoch in 0..config.epochs {
let (xs, ys) = get_batches(&dataset, "train", config.batch_size, config.context_length)?;
let (_, loss) = model.forward(&xs, Some(&ys))?;
opt.backward_step(&loss.unwrap())?;
if epoch % config.log_interval == 0 {
let batch_duration = start_time.elapsed().as_secs_f32();
let loss = evaluate_loss(&model, dataset, vocab, config)?;
let val_loss = loss.get("val").unwrap();
let eta = batch_duration * (config.epochs - epoch) as f32;
let eta = eta.round();
println!(
"Epoch: {epoch} | Loss: {val_loss} | Time: {batch_duration} | ETA in seconds {eta}"
);
start_time = time::Instant::now();
}
}
Ok(())
}
// in fn main
modeConfig.log_interval = 10;
modeConfig.epochs = 100;
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), ParamsAdamW::default())?;
train(modeConfig, &model, &mut opt, &dataset, &vocab)?;
let out = evaluate_loss(&model, &dataset, &vocab, modeConfig);
println!("{:?}", out);
Epoch: 10 | Loss: 3.9159875 | Time: 6.5813394 | ETA in seconds 599
Epoch: 20 | Loss: 3.26492 | Time: 6.3639965 | ETA in seconds 515
Epoch: 30 | Loss: 2.9944448 | Time: 6.3596206 | ETA in seconds 452
Epoch: 40 | Loss: 2.8793342 | Time: 6.357106 | ETA in seconds 388
Epoch: 50 | Loss: 2.7827232 | Time: 6.3562865 | ETA in seconds 324
Epoch: 60 | Loss: 2.764416 | Time: 6.352279 | ETA in seconds 260
Epoch: 70 | Loss: 2.7196321 | Time: 6.356127 | ETA in seconds 197
Epoch: 80 | Loss: 2.7631993 | Time: 6.357493 | ETA in seconds 134
Epoch: 90 | Loss: 2.696882 | Time: 6.358631 | ETA in seconds 70
Epoch: 100 | Loss: 2.670012 | Time: 6.3603354 | ETA in seconds 6
Ok({"train": 2.591057, "val": 2.6625311})
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
fn generate(model: &SimpleBrokenModel, vocab: &Vocab, max_tokens: usize) -> Result<()> {
// batch size 5, initial token = 0
let mut token_ids = Tensor::zeros((5, 1), DType::U32, &Device::Cpu).unwrap();
for _ in 0..max_tokens {
let (logits, _) = model.forward(&token_ids, None)?;
assert!(logits.shape().dims() == [token_ids.dim(0)?, token_ids.dim(1)?, 65]);
let last_step_logits = logits.i((.., logits.dim(1)? - 1))?;
assert!(last_step_logits.shape().dims() == [token_ids.dim(0)?, 65]);
let probs = softmax(&last_step_logits, last_step_logits.dims().len() - 1)?;
assert!(probs.shape().dims() == [token_ids.dim(0)?, 65]);
let next_token = probs.argmax(probs.dims().len() - 1)?;
assert!(next_token.shape().dims() == [token_ids.dim(0)?]);
token_ids = Tensor::cat(&[token_ids, next_token.reshape(((), 1))?], 1)?;
}
let lines = fs::read_to_string("./input.txt").expect("Failed to read the file");
for v in &token_ids.to_vec2()? {
let text = vocab.decode(v);
println!("{}", text);
}
Ok(())
}
// fn in main
generate(&model, &vocab, 10, device)?;
['\nFind!\nD:\nAr t,\nLis sthte o t l',
 '\nAnd ronnot ar\nBE:\nKINRDYOrspr;',
 '\nI t athe momyengthend thanswal',
 '\nFis t bp he\nLacarn.\nA:\nYOMI wi',
 '\nWh ly sck\nB-de pll t\nHERIns ou']

这还算不错,但也不算太好。不过现在我们有了一个可以训练到验证损失的工作模型。因此,我们将在此基础上迭代我们的模型,使其更接近 Llama。

Llama 具体细节

Llama 对原始 Transformer 进行了三项架构修改:

  1. 用于预归一化的 RMSNorm
  2. 旋转嵌入 RoPE
  3. SwiGLU 激活函数

我们将逐一添加每个修改到我们的基础模型,并进行迭代。

RMSNorm

在 Vaswani 2017 中,原始的 Transformer 使用了 BatchNormalization。在 Llama 中,作者使用了 RMSNorm,这是一种在不进行中心化的情况下通过方差来缩放向量的方法。此外,虽然 Vaswani 将归一化应用于注意力层的输出(后归一化),但 Llama 将其应用于输入之前(前归一化)。

这篇文章对于RMSNorm有一个很好的解释。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
pub struct RmsNorm {
scale: Tensor,
eps: f64,
}
impl RmsNorm {
fn new(size: usize, vb: VarBuilder) -> Result<Self> {
Ok(RmsNorm {
scale: vb.get_with_hints(size, "weight", Init::Const(1.))?,
eps: 1e-6,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_sqr = x.sqr()?;
assert!(x_sqr.shape().dims() == x.shape().dims());
let norm_x = (x.mean(D::Minus1)? + self.eps)?.sqrt()?;
assert!(norm_x.shape().dims() == [x.shape().dim(0)?, x.shape().dim(1)?]);
let x_normed = x.broadcast_div(&norm_x.reshape((
norm_x.shape().dim(0)?,
norm_x.shape().dim(1)?,
(),
))?)?;
assert!(x_normed.shape().dims() == x.shape().dims());
let x = (x_normed.broadcast_mul(&self.scale))?;
Ok(x)
}
}

let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);

let rms_norms = RmsNorm::new(2, vb)?;
// (2,3,2)
let batch = Tensor::new(
vec![
vec![vec![1f32, 1f32], vec![1.2f32, 2f32], vec![3f32, 3f32]],
vec![vec![4f32, 43f32], vec![5f32, 5f32], vec![61f32, 6f32]],
],
&Device::Cpu,
)?;
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let out = rms_norms.forward(&batch)?;
Tensor[dims 2, 3, 2; f32]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
struct SimpleBrokenModel {
embedding: Embedding,
mlp: Sequential,
rms_norm: RmsNorm,
config: ModelConfig,
}

impl SimpleBrokenModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
// Embedding
let embeds = self.embedding.forward(x)?;
// RMSNorm
let normed_embeds = self.rms_norm.forward(&embeds)?;
// Linear layers and activation
let logits = self.mlp.forward(&normed_embeds)?;

// Calculate loss if targets are provided
if let Some(targets) = targets {
// 负的似然函数
// log(x) 越大,loss 越小
// y = [0, 0 , 0, 0, 1, ...,0,0]
// y' = [4, 5, 6, 7, 8, ...,11,12 ]
// 这个 cross_entropy 帮我们做了一个 log softmax
// y' = [0.1, 0.12, 0.13, 0.64, ..., 0,0]
// loss = -log(0.64)
// -log(q) = 4.17 q = 0.015 大概 1/64,vocab_size = 65,所以基本是在瞎猜。
// println!("{:?}", targets.shape());
// println!("{:?}", logits.shape());
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}
// VarBuilder是用来构建参数的。
fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
rms_norm,
})
}
}

Epoch: 10 | Loss: 4.1559505 | Time: 6.779387 | ETA in seconds 617
Epoch: 20 | Loss: 4.14648 | Time: 6.7727704 | ETA in seconds 549
Epoch: 30 | Loss: 4.1364665 | Time: 6.776428 | ETA in seconds 481
Epoch: 40 | Loss: 4.125594 | Time: 6.772582 | ETA in seconds 413
Epoch: 50 | Loss: 4.120083 | Time: 6.7661977 | ETA in seconds 345
Epoch: 60 | Loss: 4.1099877 | Time: 6.760399 | ETA in seconds 277
Epoch: 70 | Loss: 4.0996284 | Time: 6.7623253 | ETA in seconds 210
Epoch: 80 | Loss: 4.0902996 | Time: 6.761824 | ETA in seconds 142
Epoch: 90 | Loss: 4.0833025 | Time: 6.76845 | ETA in seconds 74
Epoch: 100 | Loss: 4.070025 | Time: 6.7624454 | ETA in seconds 7
Ok({"train": 4.072861, "val": 4.0711236})

从这里得到的结果来看,范化以后,模型的表现并没有提升,所以我们需要继续迭代,只是梯度的下降变得比较平滑了。

Rotary Embeddings

RoPE 是一种用于 Transformer 的位置编码方法。在《Attention is All You Need》中,作者提出了两种位置编码方法:学习的和固定的。在 RoPE 中,作者通过旋转嵌入来表示序列中标记的位置,并在每个位置使用不同的旋转角度。

其中的 cos 和 sin 值可以预先计算并缓存,避免重复计算,后续会统一存放在一个缓存结构中。

RoPE 将 hidden_state 中每两个 x 组成的向量与旋转矩阵相乘来实现位置编码。

1
2
3
4
[x0, x1, .... ,xn]
y0 = x0 * cos(theta) - x1 * sin(theta)
y1 = x0 * sin(theta) + x1 * cos(theta)
[y0, y1, ...., yn]

nd_model 的一半。
theta 是一个根据位置得到的固定值,计算公式为:
theta = m / 10000^(2i/n)
其中,m 是在序列中的位置,i 是在 d_model 中的位置。

这个公式的含义是将特征向量中的 x0x1 进行一个固定的旋转,这个旋转不是通过学习得到的,而是预先计算的。它可以用于表示相对位置信息。

1
2
3
```Rust
pos_index = 0, 中的 x0,x1
pos_index = 1, 中的 x0,x1

隔了一个恒定的调度旋转。

freq_cis缓存住提前算好的cos和sin的值,这部分不用重复计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
struct Cache {
cos: Tensor,
sin: Tensor,
}

impl Cache {
fn new(context_length: usize, n_elem: usize, vb: VarBuilder) -> Result<Cache> {
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), vb.device())?;
let idx_theta = Tensor::arange(0, context_length as u32, vb.device())?
.to_dtype(DType::F32)?
.reshape((context_length, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let freq_cis_real = idx_theta.cos()?;
let freq_cis_imag = idx_theta.sin()?;

let cos = freq_cis_real.reshape((context_length, n_elem / 2, 1))?;
let sin = freq_cis_imag.reshape((context_length, n_elem / 2, 1))?;
Ok(Cache { cos, sin})
}
}

rope计算的时候

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
fn apply_rotary_emb(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
// println!("shape of cache.cos {:?}", cache.cos.shape());
let cos = cache.cos.i(..seq_len)?;
let sin = cache.sin.i(..seq_len)?;
// println!("cos shape {:?}", cos.shape());
let cos = cos.broadcast_as((b_sz, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, n_embd / 2, 1))?;
// println!("broadcast cos shape {:?}", cos.shape());
let x = x.reshape((b_sz, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, n_embd))?;
Ok(rope)
}

Self Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
struct AttentionModel {
embedding: Embedding,
mlp: Sequential,
rms_norm: RmsNorm,
config: ModelConfig,
self_attention: SelfAttention,
cache: Cache,
}

impl AttentionModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
// 嵌入层
let embeds = self.embedding.forward(x)?;
// 范化层
let normed_embeds = self.rms_norm.forward(&embeds)?;
// 自注意力层
let y = self.self_attention.forward(&normed_embeds, &self.cache)?;
// 线性和激活层
let logits = self.mlp.forward(&normed_embeds)?;

if let Some(targets) = targets {
// 负的似然函数
// -log(x) 越大,loss 越小
// y = [0, 0 , 0, 0, 1, ...,0,0]
// y' = [4, 5, 6, 7, 8, ...,11,12 ]
// 这个 cross_entropy 帮我们做了一个 log softmax
// y' = [0.1, 0.12, 0.13, 0.64, ..., 0,0]
// loss = -log(0.64)
// 例如 -log(q) = 4.17 q = 0.015 大概 1/64,vocab_size = 65,所以基本是在瞎猜。
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}

fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
let self_attention = SelfAttention::load(vb.pp("model.self_attention"), config)?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
rms_norm,
self_attention,
cache: Cache::new(config.context_length, config.d_model, vb)?,
})
}
}

提示:了解训练时张量维度与推理时张量维度的区别。

虽然在训练时,你可以期望张量维度与模型参数紧密匹配,例如 batch.shape = (config['batch_size'], config['context_window'], config['d_model']),但在推理时,你可能需要处理单个示例,例如 batch.shape = (1, 1, config['d_model'])。因此,你需要确保在 forward 传递中进行索引时,使用从输入派生的形状,而不一定是模型参数。

MultiHeadRopeAttention

让我们为这个单一的注意力头设置一个多头注意力层,看看训练时会发生什么。

这里实现的是GQA的注意力头。n_kv_head=1时就是MQA,n_kv_head>1n_kv_head<n_head时就是GQA,n_kv_head=n_head时就是原本的MHA。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
struct MultiHeadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
}

impl MultiHeadAttention {
fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
let q_proj = linear(config.d_model, config.d_model, vb.pp("model.q_proj"))?;
let k_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.k_proj"),
)?;
let v_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.v_proj"),
)?;
let o_proj = linear(config.d_model, config.d_model, vb.pp("model.o_proj"))?;
println!(
"MHA config n_head {} n_kv_head {}",
config.n_head, config.n_kv_head
);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_head: config.n_head,
n_kv_head: config.n_kv_head,
head_dim: config.d_model / config.n_head,
})
}
fn apply_rotary_emb(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = cache.cos.i(..seq_len)?;
let sin = cache.sin.i(..seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
Ok(rope)
}
fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;

// 计算 q k v
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;

assert!(n_embd == self.n_head * self.head_dim);

let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
let k = k.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
let v = v.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;

// 对 q 和 k 做位置编码
let q = self.apply_rotary_emb(&q, cache)?;
let k = self.apply_rotary_emb(&k, cache)?;
// 复制成 n_head / n_kv_head 份
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;

// 把 seq_len 和 n_head 交换
// 这转换一下是为了做一个 cat single head 的简单操作
// 相当于 n_head 个的seq_len*seq_len的注意力。
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;

// q*k^T / sqrt(d_k) d_k = d_model
// 这里是 (bs,n_head) 个 (seq_len, seq_len) 的注意力
let attn = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;

// 这里是头内的softmax (seq_len,seq_len)的行总和为1
let attn = softmax(&attn, D::Minus1)?;

// 再乘 * (bs, n_head, seq_len, head_dim) 得到 (bs, n_head)个注意力头对应的加权的v
let y = attn.matmul(&v)?;
// 把 n_head 和 seq_len 交换回来,得到 (bs, seq_len, n_head, head_dim) 然后reshape以后
// 得到 (bs, seq_len, n_head * head_dim) 把头给cat到一起。
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}

fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
let x = x
.unsqueeze(3)?
.expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
.reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
Ok(x)
}
}
}

完整模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
struct AttentionModel {
embedding: Embedding,
mlp: Sequential,
rms_norm: RmsNorm,
config: ModelConfig,
self_attention: MultiHeadAttention,
cache: Cache,
}

impl AttentionModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {

let embeds = self.embedding.forward(x)?;

let normed_embeds = self.rms_norm.forward(&embeds)?;
let y = self.self_attention.forward(&normed_embeds, &self.cache)?;

let logits = self.mlp.forward(&y)?;

if let Some(targets) = targets {
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}

fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
let self_attention = MultiHeadAttention::load(vb.pp("model.multi_head_attention"), config)?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
rms_norm,
self_attention,
cache: Cache::new(config.context_length, config.d_model/config.n_head, vb)?,
})
}
}

1
generate(&model, &vocab, 10, device)?;
['\n\n\n\n\n\n\n\nI\n\nOOOOOOOOOFOOtOOOOOOO',
 '\nIIIIII IIIIIIIIIIIIIIIIIIIIIII',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\naaame',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n']

所以看起来很糟糕。这里发生了什么?让我们通过查看注意力来开始调试。

目前的注意力是没有masked的,任何位置的字符都在关注任何其他位置的字符。
这有什么不好呢?我们试图仅基于之前的标记来预测下一个标记,但这里我们看到模型正在关注之后的标记。
换句话说,模型在作弊,或者从未来泄露信息。这是一个问题,这就是为什么我们需要使用因果掩码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132

pub struct Cache {
// ...
// mask 也可以 cached
mask: Tensor,
}
impl Cache {
fn new(context_length: usize, n_elem: usize, vb: VarBuilder) -> Result<Cache> {
// ...
// _ 表示类型由编译器推断,
// 默认的collect 是 [_],但是这个大小是不可变的要编译期间决定
// 所以这里还是要提示编译器要 collect 成 vec.
let mask: Vec<_> = (0..context_length)
.flat_map(|i| (0..context_length).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (context_length, context_length), vb.device())?;
Ok(Cache { cos, sin, mask })
}
}

struct MultiHeadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
}

impl MultiHeadAttention {
fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
let q_proj = linear(config.d_model, config.d_model, vb.pp("model.q_proj"))?;
let k_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.k_proj"),
)?;
let v_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.v_proj"),
)?;
let o_proj = linear(config.d_model, config.d_model, vb.pp("model.o_proj"))?;
println!(
"MHA config n_head {} n_kv_head {}",
config.n_head, config.n_kv_head
);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_head: config.n_head,
n_kv_head: config.n_kv_head,
head_dim: config.d_model / config.n_head,
})
}
// ...
fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;

// 计算 q k v
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;

assert!(n_embd == self.n_head * self.head_dim);

let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
let k = k.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
let v = v.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;

// 对 q 和 k 做位置编码
let q = self.apply_rotary_emb(&q, cache)?;
let k = self.apply_rotary_emb(&k, cache)?;
// 复制成 n_head / n_kv_head 份
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;

// println!("q.shape {:?}", q.shape());
// 把 seq_len 和 n_head 交换
// 这转换一下是为了做一个 cat single head 的简单操作
// 相当于 n_head 个的seq_len*seq_len的注意力。
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;

// let tmp = q.matmul(&k.t()?)?;
// 这个结果是有负数的,但是注意力层不会有负数。
// 计算结果出了很多 NaN,感觉应该是范化没做好。
// 后面发现是没有用 sqrt 用了开方,导致负数变成了NaN。
// q*k^T / sqrt(d_k) d_k = d_model
// 这里是 (bs,n_head) 个 (seq_len, seq_len) 的注意力
let attn = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
// 在 softmax 之前,把未来的token位置变为负无穷,这样在softmax之后,这些位置的概率就会变为0
let mask = cache
.mask
.i((..seq_len, ..seq_len))?
.unsqueeze(0)?
.unsqueeze(0)?
.broadcast_as(attn.shape())?;
let on_true =
Tensor::new(f32::NEG_INFINITY, attn.device())?.broadcast_as(mask.shape().dims())?;
let attn = mask.where_cond(&on_true, &attn)?;
// 取一个例子
// 这里是头内的softmax (seq_len,seq_len)的每行 (seq_len) 总和为1
let attn = softmax(&attn, D::Minus1)?;

// 再乘 * (bs, n_head, seq_len, head_dim) 得到 (bs, n_head)个注意力头对应的加权的v
let y = attn.matmul(&v)?;
// 把 n_head 和 seq_len 交换回来,得到 (bs, seq_len, n_head, head_dim) 然后reshape以后
// 得到 (bs, seq_len, n_head * head_dim) 把头给cat到一起。
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}

fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
let x = x
.unsqueeze(3)?
.expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
.reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
Ok(x)
}
}
}

现在,我们可以让注意力激活的上三角部分(对应未来的部分)几乎被归零了。让我们看看训练时会发生什么。

SwiGLU

正如论文中所述,“我们用SwiGLU激活函数替换了ReLU非线性函数……我们使用$\frac{2}{3} 4d$的维度,而不是PaLM中的$4d$。” SwiGLU定义为:
$$
\text{SwiGLU}(x) = \text{Swish}_\beta (xW + b) \otimes (xV + c)
$$
其中$\otimes$是逐元素乘积。Swish函数定义为:
$$
\text{Swish}_\beta(x) = x \sigma(\beta x)
$$
其中$\beta$是一个可学习的参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}

struct SwiGLU {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
}
// 新的 mlp 是三层的带gate
impl SwiGLU {
// silu 的特征是允许有一点点的负数
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// 这里就是 SwiGLU SiLU(W_1 * x) * (W_2 * x) 是 element wise 的,这个可以作为gate门信号
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}

fn load(vb: VarBuilder, cfg: ModelConfig) -> Result<Self> {
let h_size = cfg.d_model;
let i_size = cfg.hidden_dim;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
Ok(Self {
c_fc1,
c_fc2,
c_proj,
})
}
}

一个llama block res 两次,一次在attention之前,一次在swiglu之前。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
struct Block {
rms_1: RmsNorm,
attn: MultiHeadAttention,
rms_2: RmsNorm,
swiglu: SwiGLU,
}

impl Block {
fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, cache)? + residual)?;
let residual = &x;
let x = (self.swiglu.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}

fn load(vb: VarBuilder, cfg: ModelConfig) -> Result<Self> {
let attn = MultiHeadAttention::load(vb.pp("self_attn"), cfg)?;
let swiglu = SwiGLU::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::new(cfg.d_model, vb.pp("input_layernorm"))?;
let rms_2 = RmsNorm::new(cfg.d_model, vb.pp("post_attention_layernorm"))?;
Ok(Self {
rms_1,
attn,
rms_2,
swiglu,
})
}
}

现在,让我们通过创建块来添加多个层的最后完整的模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
struct Llama {
embedding: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
cache: Cache,
config: ModelConfig,
}

impl Llama {
pub fn forward(
&self,
x: &Tensor,
targets: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.embedding.forward(x)?;
for block in &self.blocks {
x = block.forward(&x, &self.cache)?;
}
let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?;

if let Some(targets) = targets {
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}

pub fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
let embed_layer = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let lm_head = linear(config.d_model, config.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(config.d_model, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..config.n_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), config).unwrap())
.collect();
Ok(Self {
embedding: embed_layer,
blocks,
ln_f,
lm_head,
config,
cache: Cache::new(config.context_length, config.d_model / config.n_head, vb)?,
})
}
}

扩展

这里主要是训练的部分,在推理的过程中还涉及到一个比较重要的kv cache。
kv cache主要是缓存自回归过程中的kv,这个kv是不变的,因为这个k 和 v 只和之前的token有关系,所以可以缓存下来,
这样在推理的时候就不用重复计算了。
围绕着prompt产生第一个Token的prefill的阶段和计算完prompt之后的decode阶段也是目前业界比较关注的推理优化的方向。
源代码在这里